From bde577aabc16384dd9ab37f3d845e6ab2e0dcbcc Mon Sep 17 00:00:00 2001 From: Andras Spitzer Date: Thu, 13 Jun 2024 05:23:33 -0400 Subject: [PATCH] fix zstd decompression (#6921) * fix zstd decompression (issue #6914) * add our fix to CHANGELOG * add explicit read_across_frames=True + move zstd test to test_encoding.py --------- Co-authored-by: Maximilian Hils --- CHANGELOG.md | 3 ++- mitmproxy/net/encoding.py | 7 +------ test/mitmproxy/net/test_encoding.py | 19 +++++++++++++++++++ 3 files changed, 22 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e7c689da1..aa2562ef2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,7 +7,8 @@ ## Unreleased: mitmproxy next - +* Fix zstd decompression to read across frames. + ([#6921](https://github.com/mitmproxy/mitmproxy/pull/6921), @zendai) ## 12 June 2024: mitmproxy 10.3.1 diff --git a/mitmproxy/net/encoding.py b/mitmproxy/net/encoding.py index 1b0b07185..95fd23bf4 100644 --- a/mitmproxy/net/encoding.py +++ b/mitmproxy/net/encoding.py @@ -174,12 +174,7 @@ def decode_zstd(content: bytes) -> bytes: if not content: return b"" zstd_ctx = zstd.ZstdDecompressor() - try: - return zstd_ctx.decompress(content) - except zstd.ZstdError: - # If the zstd stream is streamed without a size header, - # try decoding with a 10MiB output buffer - return zstd_ctx.decompress(content, max_output_size=10 * 2**20) + return zstd_ctx.stream_reader(BytesIO(content), read_across_frames=True).read() def encode_zstd(content: bytes) -> bytes: diff --git a/test/mitmproxy/net/test_encoding.py b/test/mitmproxy/net/test_encoding.py index 640d318ae..f3b79523a 100644 --- a/test/mitmproxy/net/test_encoding.py +++ b/test/mitmproxy/net/test_encoding.py @@ -93,3 +93,22 @@ def test_cache(): # This is not in the cache anymore assert encoding.encode(b"decoded", "gzip") == b"encoded" assert encode_gzip.call_count == 1 + + +def test_zstd(): + FRAME_SIZE = 1024 + + # Create payload of 1024b + test_content = "a" * FRAME_SIZE + + # Compress it, will result a single frame + single_frame = encoding.encode_zstd(test_content.encode()) + + # Concat compressed frame, it'll result two frames, total size of 2048b payload + two_frames = single_frame + single_frame + + # Uncompressed single frame should have the size of FRAME_SIZE + assert len(encoding.decode_zstd(single_frame)) == FRAME_SIZE + + # Uncompressed two frames should have the size of FRAME_SIZE * 2 + assert len(encoding.decode_zstd(two_frames)) == FRAME_SIZE * 2