Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 41 additions & 21 deletions s3proxy/handlers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,38 +245,58 @@ async def _iter_multipart_plaintext(
range_start: int | None = None,
range_end: int | None = None,
) -> AsyncIterator[bytes]:
"""Yield decrypted plaintext for each part of a multipart-encrypted object.

Parts outside the requested byte range are skipped. Parts that partially
overlap the range are trimmed before yielding.
"""Yield decrypted plaintext for a multipart-encrypted object, one frame
at a time.

A client part may expand into several internal parts (separate S3 parts),
and each internal part is itself a sequence of independent AES-GCM frames
(nonce || ciphertext || tag). Every frame must be decrypted individually —
decrypting a whole internal part or client part as a single seal fails with
InvalidTag whenever it holds more than one frame. See crypto.FRAME_* and the
GET reader (_stream_internal_parts) which uses the same layout.

Fetching per frame also bounds peak memory to O(frame) instead of a whole
~50MB client part. Frames outside the requested plaintext range are skipped
(no fetch); frames that partially overlap are trimmed before yielding.
"""
sorted_parts = sorted(meta.parts, key=lambda p: p.part_number)
pt_offset = 0
ct_offset = 0

for part in sorted_parts:
part_pt_end = pt_offset + part.plaintext_size - 1
if part.internal_parts:
segments = [
(ip.plaintext_size, ip.ciphertext_size)
for ip in sorted(part.internal_parts, key=lambda p: p.internal_part_number)
]
else:
segments = [(part.plaintext_size, part.ciphertext_size)]

in_range = range_start is None or (
part_pt_end >= range_start and pt_offset <= range_end
)
for seg_pt_size, seg_ct_size in segments:
for fsize in crypto.ciphertext_frame_byte_sizes(seg_pt_size, seg_ct_size):
frame_pt_size = fsize - crypto.ENCRYPTION_OVERHEAD
frame_pt_end = pt_offset + frame_pt_size - 1

in_range = range_start is None or (
frame_pt_end >= range_start and pt_offset <= range_end
)

if in_range:
ct_end = ct_offset + part.ciphertext_size - 1
resp = await client.get_object(bucket, key, f"bytes={ct_offset}-{ct_end}")
async with resp["Body"] as body:
ciphertext = await body.read()
part_plaintext = crypto.decrypt(ciphertext, dek)
if in_range:
ct_end = ct_offset + fsize - 1
resp = await client.get_object(bucket, key, f"bytes={ct_offset}-{ct_end}")
async with resp["Body"] as body:
ciphertext = await body.read()
plaintext = crypto.decrypt(ciphertext, dek)

if range_start is not None:
trim_start = max(0, range_start - pt_offset)
trim_end = min(part.plaintext_size, range_end - pt_offset + 1)
part_plaintext = part_plaintext[trim_start:trim_end]
if range_start is not None:
trim_start = max(0, range_start - pt_offset)
trim_end = min(frame_pt_size, range_end - pt_offset + 1)
plaintext = plaintext[trim_start:trim_end]

yield part_plaintext
yield plaintext

pt_offset = part_pt_end + 1
ct_offset += part.ciphertext_size
pt_offset += frame_pt_size
ct_offset += fsize

async def _download_encrypted_multipart(
self,
Expand Down
118 changes: 118 additions & 0 deletions tests/unit/test_streaming_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from s3proxy.handlers.multipart import MultipartHandlerMixin
from s3proxy.handlers.objects.misc import MiscObjectMixin
from s3proxy.state import (
InternalPartMetadata,
MultipartMetadata,
PartMetadata,
load_multipart_metadata,
Expand Down Expand Up @@ -601,3 +602,120 @@ async def test_streaming_respects_metadata_replace_directive(
# Content is still correct
recovered = await handler._download_encrypted_multipart(mock_s3, "bucket", "dst", dst_meta)
assert recovered == src_plaintext


# ---------------------------------------------------------------------------
# _iter_multipart_plaintext — framed / multi-internal-part source (issue: copy
# of ScyllaDB backups failed with InvalidTag because client parts hold multiple
# internal parts, each of which is multiple AES-GCM frames)
# ---------------------------------------------------------------------------


class TestIterMultipartPlaintextFramed:
    """Reproduces the production layout: client parts whose internal parts each
    span more than one frame. The old reader decrypted a whole client part as a
    single seal and raised cryptography.exceptions.InvalidTag."""

    def _build_framed_source(self, mock_s3_dek, frame_size):
        """Build a framed, multipart-encrypted source object.

        The object is laid out as 2 client parts x 2 internal parts. Every
        internal part carries the same ``frame_size * 3 - 7`` bytes of
        deterministic plaintext, encrypted in consecutive ``frame_size`` slices
        with one ``crypto.encrypt`` call per slice (so, for ``frame_size > 7``,
        three frames with the last one short).

        The caller is expected to patch ``crypto.FRAME_PLAINTEXT_SIZE`` to the
        same value as ``frame_size`` so the reader under test splits the
        ciphertext on the boundaries used here.

        Args:
            mock_s3_dek: Data-encryption key passed to ``crypto.encrypt``.
            frame_size: Plaintext bytes per frame.

        Returns:
            A 3-tuple ``(ciphertext_blob, parts, full_plaintext)``: the
            concatenated ciphertext of all internal parts, the list of
            ``PartMetadata`` describing it, and the concatenated plaintext.
        """
        dek = mock_s3_dek
        blob = bytearray()
        parts = []
        # deterministic plaintext, reused for every internal part
        internal_pt = frame_size * 3 - 7  # 3 frames, last one short
        pt_seed = bytes((i * 37) % 256 for i in range(internal_pt))
        internal_no = 1
        full_plaintext = bytearray()
        for client_no in range(1, 3):
            ips = []
            client_pt = 0
            client_ct = 0
            for _ in range(2):  # 2 internal parts per client part
                # frame the internal part exactly as the writer does
                ip_ct = bytearray()
                for off in range(0, internal_pt, frame_size):
                    frame_pt = pt_seed[off : off + frame_size]
                    ip_ct += crypto.encrypt(frame_pt, dek)
                blob += ip_ct
                full_plaintext += pt_seed
                ips.append(
                    InternalPartMetadata(
                        internal_part_number=internal_no,
                        plaintext_size=internal_pt,
                        ciphertext_size=len(ip_ct),
                        etag=hashlib.md5(bytes(ip_ct)).hexdigest(),
                    )
                )
                client_pt += internal_pt
                client_ct += len(ip_ct)
                internal_no += 1
            # client-part sizes are the sums over its internal parts
            parts.append(
                PartMetadata(
                    part_number=client_no,
                    plaintext_size=client_pt,
                    ciphertext_size=client_ct,
                    etag="x",
                    internal_parts=ips,
                )
            )
        return bytes(blob), parts, bytes(full_plaintext)

    @pytest.mark.asyncio
    async def test_full_roundtrip(self, mock_s3, settings, manager, monkeypatch):
        """Reading without a range yields the complete original plaintext."""
        # Small frame so multi-frame internal parts stay tiny and fast.
        monkeypatch.setattr(crypto, "FRAME_PLAINTEXT_SIZE", 100)
        handler = _make_misc_handler(settings, manager)
        await mock_s3.create_bucket("b")

        dek = crypto.generate_dek()
        blob, parts, plaintext = self._build_framed_source(dek, 100)
        await mock_s3.put_object("b", "src", blob)

        # minimal stand-in for the metadata object: only `.parts` is supplied
        meta = type("M", (), {"parts": parts})()
        recovered = bytearray()
        async for chunk in handler._iter_multipart_plaintext(mock_s3, "b", "src", meta, dek):
            recovered += chunk
        assert bytes(recovered) == plaintext

    @pytest.mark.asyncio
    async def test_range_roundtrip(self, mock_s3, settings, manager, monkeypatch):
        """Reading an inclusive byte range yields exactly that plaintext slice."""
        monkeypatch.setattr(crypto, "FRAME_PLAINTEXT_SIZE", 100)
        handler = _make_misc_handler(settings, manager)
        await mock_s3.create_bucket("b")

        dek = crypto.generate_dek()
        blob, parts, plaintext = self._build_framed_source(dek, 100)
        await mock_s3.put_object("b", "src", blob)

        # minimal stand-in for the metadata object: only `.parts` is supplied
        meta = type("M", (), {"parts": parts})()
        # a range that starts mid-frame in the first internal part and ends
        # mid-frame several internal parts later
        start, end = 137, len(plaintext) - 211
        recovered = bytearray()
        async for chunk in handler._iter_multipart_plaintext(
            mock_s3, "b", "src", meta, dek, range_start=start, range_end=end
        ):
            recovered += chunk
        # range_end is inclusive, hence the +1 on the slice bound
        assert bytes(recovered) == plaintext[start : end + 1]

    @pytest.mark.asyncio
    async def test_whole_client_part_seal_would_fail(self, mock_s3, settings, manager, monkeypatch):
        """Guard: decrypting a whole client part as one seal (the old behavior)
        must raise — this is exactly the production InvalidTag."""
        from cryptography.exceptions import InvalidTag

        monkeypatch.setattr(crypto, "FRAME_PLAINTEXT_SIZE", 100)
        dek = crypto.generate_dek()
        blob, parts, _ = self._build_framed_source(dek, 100)
        # first client part ciphertext = its internal parts concatenated
        client0_ct = blob[: parts[0].ciphertext_size]
        with pytest.raises(InvalidTag):
            crypto.decrypt(client0_ct, dek)