diff --git a/test/mitmproxy/proxy/layers/http/test_http3.py b/test/mitmproxy/proxy/layers/http/test_http3.py index ece395971..00770a5c1 100644 --- a/test/mitmproxy/proxy/layers/http/test_http3.py +++ b/test/mitmproxy/proxy/layers/http/test_http3.py @@ -4,9 +4,18 @@ import pytest import pylsqpack from aioquic._buffer import Buffer -from aioquic.h3.connection import FrameType, StreamType, Headers, Setting, encode_frame, encode_uint_var, encode_settings, parse_settings +from aioquic.h3.connection import ( + FrameType, + Headers, + Setting, + StreamType, + encode_frame, + encode_uint_var, + encode_settings, + parse_settings, +) -from mitmproxy import connection +from mitmproxy import connection, version from mitmproxy.http import HTTPFlow from mitmproxy.proxy import commands, context, layers from mitmproxy.proxy.layers import http, quic @@ -20,6 +29,16 @@ example_request_headers = [ (b":authority", b"example.com"), ] +example_response_headers = [(b":status", b"200")] + +example_response_trailers = [(b"resp-trailer-a", b"a"), (b"resp-trailer-b", b"b")] + + +def decode_frame(frame_type: int, frame_data: bytes) -> bytes: + buf = Buffer(data=frame_data) + assert buf.pull_uint_var() == frame_type + return buf.pull_bytes(buf.pull_uint_var()) + class CallbackPlaceholder(tutils._Placeholder[bytes]): """Data placeholder that invokes a callback once its bytes get set.""" @@ -77,7 +96,7 @@ class FrameFactory: max_table_capacity=4096, blocked_streams=16, ) - self.decoder_placeholder: Optional[tutils.Placeholder(bytes)] = None + self.decoder_placeholders: list[tutils.Placeholder(bytes)] = [] self.encoder = pylsqpack.Encoder() self.encoder_placeholder: Optional[tutils.Placeholder(bytes)] = None self.peer_stream_id: dict[StreamType, int] = {} @@ -137,6 +156,21 @@ class FrameFactory: end_stream=False, ) + def send_max_push_id(self) -> quic.SendQuicStreamData: + def cb(data: bytes) -> None: + buf = Buffer(data=data) + assert buf.pull_uint_var() == FrameType.MAX_PUSH_ID + buf = Buffer(data=buf.pull_bytes(buf.pull_uint_var())) + self.max_push_id = buf.pull_uint_var() + assert buf.eof() + + return quic.SendQuicStreamData( + connection=self.conn, + stream_id=self.peer_stream_id[StreamType.CONTROL], + data=CallbackPlaceholder(cb), + end_stream=False, + ) + def send_settings(self) -> quic.SendQuicStreamData: assert self.encoder_placeholder is None placeholder = tutils.Placeholder(bytes) @@ -158,21 +192,6 @@ class FrameFactory: end_stream=False, ) - def send_max_push_id(self) -> quic.SendQuicStreamData: - def cb(data: bytes) -> None: - buf = Buffer(data=data) - assert buf.pull_uint_var() == FrameType.MAX_PUSH_ID - buf = Buffer(data=buf.pull_bytes(buf.pull_uint_var())) - self.max_push_id = buf.pull_uint_var() - assert buf.eof() - - return quic.SendQuicStreamData( - connection=self.conn, - stream_id=self.peer_stream_id[StreamType.CONTROL], - data=CallbackPlaceholder(cb), - end_stream=False, - ) - def receive_settings( self, settings: dict[int, int] = { @@ -213,19 +232,6 @@ class FrameFactory: end_stream=False, ) - def send_data( - self, - data: bytes, - stream_id: int = 0, - end_stream: bool = False, - ) -> quic.SendQuicStreamData: - return quic.SendQuicStreamData( - self.conn, - stream_id=stream_id, - data=encode_frame(FrameType.DATA, data), - end_stream=end_stream, - ) - def send_decoder(self) -> quic.SendQuicStreamData: def cb(data: bytes) -> None: self.encoder.feed_decoder(data) @@ -238,9 +244,8 @@ class FrameFactory: ) def receive_decoder(self) -> quic.QuicStreamDataReceived: - assert self.decoder_placeholder is not None - placeholder = self.decoder_placeholder - self.decoder_placeholder = None + assert self.decoder_placeholders + placeholder = self.decoder_placeholders.pop(0) return quic.QuicStreamDataReceived( self.conn, @@ -249,6 +254,31 @@ class FrameFactory: end_stream=False, ) + def send_headers( + self, + headers: Headers, + stream_id: int = 0, + end_stream: bool = False, + ) -> Iterable[quic.SendQuicStreamData]: + placeholder = tutils.Placeholder(bytes) + self.decoder_placeholders.append(placeholder) + + def decode(data: bytes) -> None: + buf = Buffer(data=data) + assert buf.pull_uint_var() == FrameType.HEADERS + frame_data = buf.pull_bytes(buf.pull_uint_var()) + decoder, actual_headers = self.decoder.feed_header(stream_id, frame_data) + placeholder.setdefault(decoder) + assert headers == actual_headers + + yield self.send_encoder() + yield quic.SendQuicStreamData( + connection=self.conn, + stream_id=stream_id, + data=CallbackPlaceholder(decode), + end_stream=end_stream, + ) + def receive_headers( self, headers: Headers, @@ -275,29 +305,16 @@ class FrameFactory: end_stream=end_stream, ) - def send_headers( + def send_data( self, - headers: Headers, + data: bytes, stream_id: int = 0, end_stream: bool = False, - ) -> Iterable[quic.SendQuicStreamData]: - assert self.decoder_placeholder is None - placeholder = tutils.Placeholder(bytes) - self.decoder_placeholder = placeholder - - def decode(data: bytes) -> None: - buf = Buffer(data=data) - assert buf.pull_uint_var() == FrameType.HEADERS - frame_data = buf.pull_bytes(buf.pull_uint_var()) - decoder, headers = self.decoder.feed_header(stream_id, frame_data) - placeholder.setdefault(decoder) - assert headers == headers - - yield self.send_encoder() - yield quic.SendQuicStreamData( - connection=self.conn, + ) -> quic.SendQuicStreamData: + return quic.SendQuicStreamData( + self.conn, stream_id=stream_id, - data=CallbackPlaceholder(decode), + data=encode_frame(FrameType.DATA, data), end_stream=end_stream, ) @@ -314,13 +331,27 @@ class FrameFactory: end_stream=end_stream, ) - def send_server_init(self) -> Iterable[quic.SendQuicStreamData]: + def send_init(self) -> Iterable[quic.SendQuicStreamData]: yield self.send_stream_type(StreamType.CONTROL) yield self.send_settings() - yield self.send_max_push_id() + if not self.is_client: + yield self.send_max_push_id() yield self.send_stream_type(StreamType.QPACK_ENCODER) yield self.send_stream_type(StreamType.QPACK_DECODER) + def receive_init(self) -> Iterable[quic.QuicStreamDataReceived]: + yield self.receive_stream_type(StreamType.CONTROL) + yield self.receive_stream_type(StreamType.QPACK_ENCODER) + yield self.receive_stream_type(StreamType.QPACK_DECODER) + yield self.receive_settings() + + @property + def is_done(self) -> bool: + return ( + self.encoder_placeholder is None + and not self.decoder_placeholders + ) + @pytest.fixture def open_h3_server_conn(): @@ -340,15 +371,9 @@ def start_h3_client(tctx: context.Context) -> tuple[tutils.Playbook, FrameFactor cff = FrameFactory(conn=tctx.client, is_client=True) assert ( playbook - << cff.send_stream_type(StreamType.CONTROL) - << cff.send_settings() - << cff.send_stream_type(StreamType.QPACK_ENCODER) - << cff.send_stream_type(StreamType.QPACK_DECODER) - >> cff.receive_stream_type(StreamType.CONTROL) - >> cff.receive_settings() + << cff.send_init() + >> cff.receive_init() << cff.send_encoder() - >> cff.receive_stream_type(StreamType.QPACK_ENCODER) - >> cff.receive_stream_type(StreamType.QPACK_DECODER) >> cff.receive_encoder() ) return playbook, cff @@ -366,14 +391,151 @@ def test_simple(tctx: context.Context): sff = FrameFactory(server, is_client=False) assert ( playbook + # request client >> cff.receive_headers(example_request_headers, end_stream=True) - << http.HttpRequestHeadersHook(flow) - << cff.send_decoder() - >> tutils.reply(to=http.HttpRequestHeadersHook(flow)) + << (request := http.HttpRequestHeadersHook(flow)) + << cff.send_decoder() # for receive_headers + >> tutils.reply(to=request) << http.HttpRequestHook(flow) >> tutils.reply() + # request server << commands.OpenConnection(server) >> tutils.reply(None, side_effect=make_h3) - << sff.send_server_init() + << sff.send_init() << sff.send_headers(example_request_headers, end_stream=True) + >> sff.receive_init() + << sff.send_encoder() + >> sff.receive_encoder() + >> sff.receive_decoder() # for send_headers + # response server + >> sff.receive_headers(example_response_headers) + << (response := http.HttpResponseHeadersHook(flow)) + << sff.send_decoder() # for receive_headers + >> tutils.reply(to=response) + >> sff.receive_data(b"Hello, World!", end_stream=True) + << http.HttpResponseHook(flow) + >> tutils.reply() + # response client + << cff.send_headers(example_response_headers) + << cff.send_data(b"Hello, World!") + << cff.send_data(b"", end_stream=True) + >> cff.receive_decoder() # for send_headers ) + assert cff.is_done and sff.is_done + assert flow().request.url == "http://example.com/" + assert flow().response.text == "Hello, World!" + + +@pytest.mark.parametrize("stream", [True, False]) +def test_response_trailers( + tctx: context.Context, + open_h3_server_conn: connection.Server, + stream: bool, +): + playbook, cff = start_h3_client(tctx) + tctx.server = open_h3_server_conn + sff = FrameFactory(tctx.server, is_client=False) + + def enable_streaming(flow: HTTPFlow): + flow.response.stream = stream + + flow = tutils.Placeholder(HTTPFlow) + ( + playbook + # request client + >> cff.receive_headers(example_request_headers, end_stream=True) + << (request := http.HttpRequestHeadersHook(flow)) + << cff.send_decoder() # for receive_headers + >> tutils.reply(to=request) + << http.HttpRequestHook(flow) + >> tutils.reply() + # request server + << sff.send_init() + << sff.send_headers(example_request_headers, end_stream=True) + >> sff.receive_init() + << sff.send_encoder() + >> sff.receive_encoder() + >> sff.receive_decoder() # for send_headers + # response server + >> sff.receive_headers(example_response_headers) + << (response_headers := http.HttpResponseHeadersHook(flow)) + << sff.send_decoder() # for receive_headers + >> tutils.reply(to=response_headers, side_effect=enable_streaming) + ) + if stream: + ( + playbook + << cff.send_headers(example_response_headers) + >> cff.receive_decoder() # for send_headers + >> sff.receive_data(b"Hello, World!") + << cff.send_data(b"Hello, World!") + ) + else: + playbook >> sff.receive_data(b"Hello, World!") + assert ( + playbook + >> sff.receive_headers(example_response_trailers, end_stream=True) + << (response := http.HttpResponseHook(flow)) + << sff.send_decoder() # for receive_headers + ) + + def modify_tailers(flow: HTTPFlow) -> None: + assert flow.response.trailers + del flow.response.trailers["resp-trailer-a"] + + if stream: + assert ( + playbook + >> tutils.reply(to=response, side_effect=modify_tailers) + << cff.send_headers(example_response_trailers[1:], end_stream=True) + >> cff.receive_decoder() # for send_headers + ) + else: + assert ( + playbook + >> tutils.reply(to=response, side_effect=modify_tailers) + << cff.send_headers(example_response_headers) + << cff.send_data(b"Hello, World!") + << cff.send_headers(example_response_trailers[1:], end_stream=True) + >> cff.receive_decoder() # for send_headers + >> cff.receive_decoder() # for send_headers + ) + assert cff.is_done and sff.is_done + + +def test_upstream_error(tctx: context.Context): + playbook, cff = start_h3_client(tctx) + flow = tutils.Placeholder(HTTPFlow) + server = tutils.Placeholder(connection.Server) + err = tutils.Placeholder(bytes) + assert ( + playbook + # request client + >> cff.receive_headers(example_request_headers, end_stream=True) + << (request := http.HttpRequestHeadersHook(flow)) + << cff.send_decoder() # for receive_headers + >> tutils.reply(to=request) + << http.HttpRequestHook(flow) + >> tutils.reply() + # request server + << commands.OpenConnection(server) + >> tutils.reply("oops server <> error") + << http.HttpErrorHook(flow) + >> tutils.reply() + << cff.send_headers([ + (b":status", b"502"), + (b'server', version.MITMPROXY.encode()), + (b'content-type', b'text/html'), + ]) + << quic.SendQuicStreamData( + tctx.client, + stream_id=0, + data=err, + end_stream=True, + ) + >> cff.receive_decoder() # for send_headers + ) + assert cff.is_done + data = decode_frame(FrameType.DATA, err()) + assert b"502 Bad Gateway" in data + assert b"server <> error" in data