diff --git a/s3proxy/handlers/multipart/upload_part.py b/s3proxy/handlers/multipart/upload_part.py index 7c49814..5beae20 100644 --- a/s3proxy/handlers/multipart/upload_part.py +++ b/s3proxy/handlers/multipart/upload_part.py @@ -7,7 +7,7 @@ import time from collections import deque from collections.abc import AsyncIterator -from typing import NoReturn +from typing import NamedTuple, NoReturn import structlog from botocore.exceptions import ClientError @@ -32,6 +32,34 @@ MAX_PARALLEL_INTERNAL_UPLOADS = 2 +class _UploadClass(NamedTuple): + is_unsigned: bool + is_streaming_sig: bool + needs_chunked_decode: bool + is_large_signed: bool + use_framed: bool + + +def classify_upload(content_sha: str, content_encoding: str, content_length: int) -> _UploadClass: + """Decide how an UploadPart body is read and encrypted. + + Any signed, known-length, non-chunked body streams frame-by-frame (framed + path, O(frame) memory). Only aws-chunked / streaming-signature bodies, whose + length is unknown up front, keep the buffered path. ``is_large_signed`` means + "signed and known-length"; it gates late SHA256 verification. + """ + is_unsigned = content_sha == "UNSIGNED-PAYLOAD" + is_streaming_sig = content_sha.startswith("STREAMING-") + needs_chunked_decode = "aws-chunked" in content_encoding or is_streaming_sig + is_large_signed = not is_unsigned and not is_streaming_sig and content_length > 0 + use_framed = ( + (is_unsigned or is_large_signed) and not needs_chunked_decode and content_length > 0 + ) + return _UploadClass( + is_unsigned, is_streaming_sig, needs_chunked_decode, is_large_signed, use_framed + ) + + class _PlaintextReader: """Pulls exactly-sized plaintext slices from a chunked byte stream. @@ -86,25 +114,19 @@ async def handle_upload_part(self, request: Request, creds: S3Credentials) -> Re content_length_mb=f"{content_length / 1024 / 1024:.2f}MB", ) - # Determine encoding type - is_unsigned = content_sha == "UNSIGNED-PAYLOAD" - is_streaming_sig = content_sha.startswith("STREAMING-") - needs_chunked_decode = "aws-chunked" in content_encoding or is_streaming_sig - is_large_signed = ( - not is_unsigned - and not is_streaming_sig - and content_length > crypto.STREAMING_THRESHOLD - ) + # Determine encoding type and upload path. + cls = classify_upload(content_sha, content_encoding, content_length) + is_unsigned = cls.is_unsigned + is_streaming_sig = cls.is_streaming_sig + needs_chunked_decode = cls.needs_chunked_decode + is_large_signed = cls.is_large_signed + use_framed = cls.use_framed # Smallest internal part that bounds memory while staying within the # per-client part-number allocation range (so we never collide and # never buffer more than necessary). internal_part_size = crypto.memory_bounded_part_size(content_length) estimated_parts = max(1, -(-content_length // internal_part_size)) - - use_framed = ( - (is_unsigned or is_large_signed) and not needs_chunked_decode and content_length > 0 - ) logger.info( "UPLOAD_PART_CONFIG", bucket=bucket, diff --git a/s3proxy/request_handler.py b/s3proxy/request_handler.py index c71fbe5..5ac7388 100644 --- a/s3proxy/request_handler.py +++ b/s3proxy/request_handler.py @@ -12,7 +12,7 @@ from fastapi.responses import PlainTextResponse from structlog.stdlib import BoundLogger -from . import concurrency, crypto +from . import concurrency from .client import ParsedRequest, SigV4Verifier from .dashboard import record_request from .errors import S3Error, raise_for_client_error, raise_for_exception @@ -29,10 +29,6 @@ pod_name = os.environ.get("HOSTNAME", "unknown") logger: BoundLogger = structlog.get_logger(__name__).bind(pod=pod_name) -# Signature verification constants -UNSIGNED_PAYLOAD = "UNSIGNED-PAYLOAD" -STREAMING_PAYLOAD_PREFIX = "STREAMING-" - def _is_dashboard_path(request: Request, path: str) -> bool: """True if the path targets the dashboard (so it's excluded from stats). @@ -51,23 +47,14 @@ def _is_dashboard_path(request: Request, path: str) -> bool: return path == prefix or path.startswith(prefix + "/") -def _needs_body_for_signature(headers: dict[str, str], max_size: int) -> bool: - """Check if body is needed for signature verification. +def _needs_body_for_signature(headers: dict[str, str]) -> bool: + """Body is needed only when x-amz-content-sha256 is absent. - Returns False for unsigned payloads, streaming signatures, or large bodies. + The verifier uses that header as the payload hash verbatim and only rehashes + the body as a fallback when it is missing. Buffering it otherwise just pins + the whole part in memory. """ - content_sha = headers.get("x-amz-content-sha256", "") - if content_sha == UNSIGNED_PAYLOAD or content_sha.startswith(STREAMING_PAYLOAD_PREFIX): - return False - - content_length = headers.get("content-length", "0") - try: - if int(content_length) > max_size: - return False - except ValueError: - pass - - return True + return headers.get("x-amz-content-sha256", "") == "" async def handle_proxy_request( @@ -190,9 +177,7 @@ async def _handle_proxy_request_impl( headers = {k.lower(): v for k, v in request.headers.items()} query = parse_qs(str(request.url.query), keep_blank_values=True) - needs_body = request.method in ("PUT", "POST") and _needs_body_for_signature( - headers, crypto.STREAMING_THRESHOLD - ) + needs_body = request.method in ("PUT", "POST") and _needs_body_for_signature(headers) content_length = headers.get("content-length", "0") body = await request.body() if needs_body else b"" if needs_body and len(body) > 0: diff --git a/tests/unit/test_routing.py b/tests/unit/test_routing.py index 0531cb0..7dc985f 100644 --- a/tests/unit/test_routing.py +++ b/tests/unit/test_routing.py @@ -42,45 +42,22 @@ def test_empty_path(self): class TestNeedsBodyForSignature: - """Test body requirement for signature verification.""" - - MAX_SIZE = 16 * 1024 * 1024 # 16MB default + """Body is buffered for signature only when x-amz-content-sha256 is absent.""" def test_unsigned_payload(self): - """Test UNSIGNED-PAYLOAD doesn't need body.""" - headers = {"x-amz-content-sha256": "UNSIGNED-PAYLOAD"} - assert _needs_body_for_signature(headers, self.MAX_SIZE) is False + assert _needs_body_for_signature({"x-amz-content-sha256": "UNSIGNED-PAYLOAD"}) is False def test_streaming_payload(self): - """Test streaming payload doesn't need body.""" headers = {"x-amz-content-sha256": "STREAMING-AWS4-HMAC-SHA256-PAYLOAD"} - assert _needs_body_for_signature(headers, self.MAX_SIZE) is False - - def test_regular_payload(self): - """Test regular payload needs body.""" - headers = {"x-amz-content-sha256": "abc123def456"} - assert _needs_body_for_signature(headers, self.MAX_SIZE) is True - - def test_missing_header(self): - """Test missing header needs body.""" - headers = {} - assert _needs_body_for_signature(headers, self.MAX_SIZE) is True - - def test_large_content_length_skips_body(self): - """Test large content-length skips body buffering to avoid OOM.""" - headers = { - "x-amz-content-sha256": "abc123def456", - "content-length": str(self.MAX_SIZE + 1), - } - assert _needs_body_for_signature(headers, self.MAX_SIZE) is False - - def test_small_content_length_needs_body(self): - """Test small content-length still needs body.""" - headers = { - "x-amz-content-sha256": "abc123def456", - "content-length": str(self.MAX_SIZE - 1), - } - assert _needs_body_for_signature(headers, self.MAX_SIZE) is True + assert _needs_body_for_signature(headers) is False + + def test_signed_payload_skips_body_regardless_of_size(self): + headers = {"x-amz-content-sha256": "abc123def456", "content-length": str(16 * 1024 * 1024)} + assert _needs_body_for_signature(headers) is False + + def test_missing_header_needs_body(self): + assert _needs_body_for_signature({}) is True + assert _needs_body_for_signature({"x-amz-content-sha256": ""}) is True class TestQueryConstants: diff --git a/tests/unit/test_upload_path.py b/tests/unit/test_upload_path.py new file mode 100644 index 0000000..9676fcc --- /dev/null +++ b/tests/unit/test_upload_path.py @@ -0,0 +1,38 @@ +"""UploadPart path selection: signed known-length parts must stream, not buffer.""" + +from s3proxy.handlers.multipart.upload_part import classify_upload + +MB = 1024 * 1024 +SIGNED_SHA = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" + + +def test_signed_16mb_part_uses_framed_path(): + c = classify_upload(SIGNED_SHA, "", 16 * MB) + assert c.use_framed is True + assert c.is_large_signed is True + assert c.needs_chunked_decode is False + assert c.is_unsigned is False + + +def test_small_signed_part_uses_framed_path(): + assert classify_upload(SIGNED_SHA, "", 5 * MB).use_framed is True + + +def test_unsigned_large_part_uses_framed_path(): + c = classify_upload("UNSIGNED-PAYLOAD", "", 256 * MB) + assert c.use_framed is True and c.is_unsigned is True + + +def test_streaming_sig_uses_buffered_path(): + c = classify_upload("STREAMING-AWS4-HMAC-SHA256-PAYLOAD", "", 16 * MB) + assert c.use_framed is False + assert c.needs_chunked_decode is True + + +def test_aws_chunked_uses_buffered_path(): + c = classify_upload(SIGNED_SHA, "aws-chunked", 16 * MB) + assert c.use_framed is False and c.needs_chunked_decode is True + + +def test_zero_length_never_framed(): + assert classify_upload(SIGNED_SHA, "", 0).use_framed is False