diff --git a/s3proxy/handlers/base.py b/s3proxy/handlers/base.py index 97e4e99..34740bb 100644 --- a/s3proxy/handlers/base.py +++ b/s3proxy/handlers/base.py @@ -245,38 +245,58 @@ async def _iter_multipart_plaintext( range_start: int | None = None, range_end: int | None = None, ) -> AsyncIterator[bytes]: - """Yield decrypted plaintext for each part of a multipart-encrypted object. - - Parts outside the requested byte range are skipped. Parts that partially - overlap the range are trimmed before yielding. + """Yield decrypted plaintext for a multipart-encrypted object, one frame + at a time. + + A client part may expand into several internal parts (separate S3 parts), + and each internal part is itself a sequence of independent AES-GCM frames + (nonce || ciphertext || tag). Every frame must be decrypted individually — + decrypting a whole internal part or client part as a single seal fails with + InvalidTag whenever it holds more than one frame. See crypto.FRAME_* and the + GET reader (_stream_internal_parts) which uses the same layout. + + Fetching per frame also bounds peak memory to O(frame) instead of a whole + ~50MB client part. Frames outside the requested plaintext range are skipped + (no fetch); frames that partially overlap are trimmed before yielding. 
""" sorted_parts = sorted(meta.parts, key=lambda p: p.part_number) pt_offset = 0 ct_offset = 0 for part in sorted_parts: - part_pt_end = pt_offset + part.plaintext_size - 1 + if part.internal_parts: + segments = [ + (ip.plaintext_size, ip.ciphertext_size) + for ip in sorted(part.internal_parts, key=lambda p: p.internal_part_number) + ] + else: + segments = [(part.plaintext_size, part.ciphertext_size)] - in_range = range_start is None or ( - part_pt_end >= range_start and pt_offset <= range_end - ) + for seg_pt_size, seg_ct_size in segments: + for fsize in crypto.ciphertext_frame_byte_sizes(seg_pt_size, seg_ct_size): + frame_pt_size = fsize - crypto.ENCRYPTION_OVERHEAD + frame_pt_end = pt_offset + frame_pt_size - 1 + + in_range = range_start is None or ( + frame_pt_end >= range_start and pt_offset <= range_end + ) - if in_range: - ct_end = ct_offset + part.ciphertext_size - 1 - resp = await client.get_object(bucket, key, f"bytes={ct_offset}-{ct_end}") - async with resp["Body"] as body: - ciphertext = await body.read() - part_plaintext = crypto.decrypt(ciphertext, dek) + if in_range: + ct_end = ct_offset + fsize - 1 + resp = await client.get_object(bucket, key, f"bytes={ct_offset}-{ct_end}") + async with resp["Body"] as body: + ciphertext = await body.read() + plaintext = crypto.decrypt(ciphertext, dek) - if range_start is not None: - trim_start = max(0, range_start - pt_offset) - trim_end = min(part.plaintext_size, range_end - pt_offset + 1) - part_plaintext = part_plaintext[trim_start:trim_end] + if range_start is not None: + trim_start = max(0, range_start - pt_offset) + trim_end = min(frame_pt_size, range_end - pt_offset + 1) + plaintext = plaintext[trim_start:trim_end] - yield part_plaintext + yield plaintext - pt_offset = part_pt_end + 1 - ct_offset += part.ciphertext_size + pt_offset += frame_pt_size + ct_offset += fsize async def _download_encrypted_multipart( self, diff --git a/tests/unit/test_streaming_copy.py b/tests/unit/test_streaming_copy.py index 
21269ff..f023d89 100644 --- a/tests/unit/test_streaming_copy.py +++ b/tests/unit/test_streaming_copy.py @@ -14,6 +14,7 @@ from s3proxy.handlers.multipart import MultipartHandlerMixin from s3proxy.handlers.objects.misc import MiscObjectMixin from s3proxy.state import ( + InternalPartMetadata, MultipartMetadata, PartMetadata, load_multipart_metadata, @@ -601,3 +602,120 @@ async def test_streaming_respects_metadata_replace_directive( # Content is still correct recovered = await handler._download_encrypted_multipart(mock_s3, "bucket", "dst", dst_meta) assert recovered == src_plaintext + + +# --------------------------------------------------------------------------- +# _iter_multipart_plaintext — framed / multi-internal-part source (issue: copy +# of ScyllaDB backups failed with InvalidTag because client parts hold multiple +# internal parts, each of which is multiple AES-GCM frames) +# --------------------------------------------------------------------------- + + +class TestIterMultipartPlaintextFramed: + """Reproduces the production layout: client parts whose internal parts each + span more than one frame. The old reader decrypted a whole client part as a + single seal and raised cryptography.exceptions.InvalidTag.""" + + def _build_framed_source(self, mock_s3_dek, frame_size): + """Return (ciphertext_blob, parts, full_plaintext) for a source object shaped like a real + multipart-encrypted backup: 2 client parts × 2 internal parts, and each + internal part holds 3 frames of `frame_size` plaintext (last frame short). + + Frames are sized by crypto.FRAME_PLAINTEXT_SIZE (patched small in the test), + so building the ciphertext must use that same boundary. 
+ """ + dek = mock_s3_dek + blob = bytearray() + parts = [] + # deterministic plaintext, sliced into internal parts as we go + internal_pt = frame_size * 3 - 7 # 3 frames, last one short + pt_seed = bytes((i * 37) % 256 for i in range(internal_pt)) + ct_offset = 0 + internal_no = 1 + full_plaintext = bytearray() + for client_no in range(1, 3): + ips = [] + client_pt = 0 + client_ct = 0 + for _ in range(2): # 2 internal parts per client part + # frame the internal part exactly as the writer does + ip_ct = bytearray() + for off in range(0, internal_pt, frame_size): + frame_pt = pt_seed[off : off + frame_size] + ip_ct += crypto.encrypt(frame_pt, dek) + blob += ip_ct + full_plaintext += pt_seed + ips.append( + InternalPartMetadata( + internal_part_number=internal_no, + plaintext_size=internal_pt, + ciphertext_size=len(ip_ct), + etag=hashlib.md5(bytes(ip_ct)).hexdigest(), + ) + ) + client_pt += internal_pt + client_ct += len(ip_ct) + internal_no += 1 + ct_offset += len(ip_ct) + parts.append( + PartMetadata( + part_number=client_no, + plaintext_size=client_pt, + ciphertext_size=client_ct, + etag="x", + internal_parts=ips, + ) + ) + return bytes(blob), parts, bytes(full_plaintext) + + @pytest.mark.asyncio + async def test_full_roundtrip(self, mock_s3, settings, manager, monkeypatch): + # Small frame so multi-frame internal parts stay tiny and fast. 
+ monkeypatch.setattr(crypto, "FRAME_PLAINTEXT_SIZE", 100) + handler = _make_misc_handler(settings, manager) + await mock_s3.create_bucket("b") + + dek = crypto.generate_dek() + blob, parts, plaintext = self._build_framed_source(dek, 100) + await mock_s3.put_object("b", "src", blob) + + meta = type("M", (), {"parts": parts})() + recovered = bytearray() + async for chunk in handler._iter_multipart_plaintext(mock_s3, "b", "src", meta, dek): + recovered += chunk + assert bytes(recovered) == plaintext + + @pytest.mark.asyncio + async def test_range_roundtrip(self, mock_s3, settings, manager, monkeypatch): + monkeypatch.setattr(crypto, "FRAME_PLAINTEXT_SIZE", 100) + handler = _make_misc_handler(settings, manager) + await mock_s3.create_bucket("b") + + dek = crypto.generate_dek() + blob, parts, plaintext = self._build_framed_source(dek, 100) + await mock_s3.put_object("b", "src", blob) + + meta = type("M", (), {"parts": parts})() + # a range that starts mid-frame in the first internal part and ends + # mid-frame several internal parts later + start, end = 137, len(plaintext) - 211 + recovered = bytearray() + async for chunk in handler._iter_multipart_plaintext( + mock_s3, "b", "src", meta, dek, range_start=start, range_end=end + ): + recovered += chunk + assert bytes(recovered) == plaintext[start : end + 1] + + @pytest.mark.asyncio + async def test_whole_client_part_seal_would_fail(self, mock_s3, settings, manager, monkeypatch): + """Guard: decrypting a whole client part as one seal (the old behavior) + must raise — this is exactly the production InvalidTag.""" + from cryptography.exceptions import InvalidTag + + monkeypatch.setattr(crypto, "FRAME_PLAINTEXT_SIZE", 100) + dek = crypto.generate_dek() + blob, parts, _ = self._build_framed_source(dek, 100) + # first client part ciphertext = its internal parts concatenated + client0_ct = blob[: parts[0].ciphertext_size] + with pytest.raises(InvalidTag): + crypto.decrypt(client0_ct, dek)