[quic] more h3 tests
This commit is contained in:
parent
a2d962b09a
commit
201f03082a
|
@ -4,9 +4,18 @@ import pytest
|
|||
import pylsqpack
|
||||
|
||||
from aioquic._buffer import Buffer
|
||||
from aioquic.h3.connection import FrameType, StreamType, Headers, Setting, encode_frame, encode_uint_var, encode_settings, parse_settings
|
||||
from aioquic.h3.connection import (
|
||||
FrameType,
|
||||
Headers,
|
||||
Setting,
|
||||
StreamType,
|
||||
encode_frame,
|
||||
encode_uint_var,
|
||||
encode_settings,
|
||||
parse_settings,
|
||||
)
|
||||
|
||||
from mitmproxy import connection
|
||||
from mitmproxy import connection, version
|
||||
from mitmproxy.http import HTTPFlow
|
||||
from mitmproxy.proxy import commands, context, layers
|
||||
from mitmproxy.proxy.layers import http, quic
|
||||
|
@ -20,6 +29,16 @@ example_request_headers = [
|
|||
(b":authority", b"example.com"),
|
||||
]
|
||||
|
||||
example_response_headers = [(b":status", b"200")]
|
||||
|
||||
example_response_trailers = [(b"resp-trailer-a", b"a"), (b"resp-trailer-b", b"b")]
|
||||
|
||||
|
||||
def decode_frame(frame_type: int, frame_data: bytes) -> bytes:
|
||||
buf = Buffer(data=frame_data)
|
||||
assert buf.pull_uint_var() == frame_type
|
||||
return buf.pull_bytes(buf.pull_uint_var())
|
||||
|
||||
|
||||
class CallbackPlaceholder(tutils._Placeholder[bytes]):
|
||||
"""Data placeholder that invokes a callback once its bytes get set."""
|
||||
|
@ -77,7 +96,7 @@ class FrameFactory:
|
|||
max_table_capacity=4096,
|
||||
blocked_streams=16,
|
||||
)
|
||||
self.decoder_placeholder: Optional[tutils.Placeholder(bytes)] = None
|
||||
self.decoder_placeholders: list[tutils.Placeholder(bytes)] = []
|
||||
self.encoder = pylsqpack.Encoder()
|
||||
self.encoder_placeholder: Optional[tutils.Placeholder(bytes)] = None
|
||||
self.peer_stream_id: dict[StreamType, int] = {}
|
||||
|
@ -137,6 +156,21 @@ class FrameFactory:
|
|||
end_stream=False,
|
||||
)
|
||||
|
||||
def send_max_push_id(self) -> quic.SendQuicStreamData:
|
||||
def cb(data: bytes) -> None:
|
||||
buf = Buffer(data=data)
|
||||
assert buf.pull_uint_var() == FrameType.MAX_PUSH_ID
|
||||
buf = Buffer(data=buf.pull_bytes(buf.pull_uint_var()))
|
||||
self.max_push_id = buf.pull_uint_var()
|
||||
assert buf.eof()
|
||||
|
||||
return quic.SendQuicStreamData(
|
||||
connection=self.conn,
|
||||
stream_id=self.peer_stream_id[StreamType.CONTROL],
|
||||
data=CallbackPlaceholder(cb),
|
||||
end_stream=False,
|
||||
)
|
||||
|
||||
def send_settings(self) -> quic.SendQuicStreamData:
|
||||
assert self.encoder_placeholder is None
|
||||
placeholder = tutils.Placeholder(bytes)
|
||||
|
@ -158,21 +192,6 @@ class FrameFactory:
|
|||
end_stream=False,
|
||||
)
|
||||
|
||||
def send_max_push_id(self) -> quic.SendQuicStreamData:
|
||||
def cb(data: bytes) -> None:
|
||||
buf = Buffer(data=data)
|
||||
assert buf.pull_uint_var() == FrameType.MAX_PUSH_ID
|
||||
buf = Buffer(data=buf.pull_bytes(buf.pull_uint_var()))
|
||||
self.max_push_id = buf.pull_uint_var()
|
||||
assert buf.eof()
|
||||
|
||||
return quic.SendQuicStreamData(
|
||||
connection=self.conn,
|
||||
stream_id=self.peer_stream_id[StreamType.CONTROL],
|
||||
data=CallbackPlaceholder(cb),
|
||||
end_stream=False,
|
||||
)
|
||||
|
||||
def receive_settings(
|
||||
self,
|
||||
settings: dict[int, int] = {
|
||||
|
@ -213,19 +232,6 @@ class FrameFactory:
|
|||
end_stream=False,
|
||||
)
|
||||
|
||||
def send_data(
|
||||
self,
|
||||
data: bytes,
|
||||
stream_id: int = 0,
|
||||
end_stream: bool = False,
|
||||
) -> quic.SendQuicStreamData:
|
||||
return quic.SendQuicStreamData(
|
||||
self.conn,
|
||||
stream_id=stream_id,
|
||||
data=encode_frame(FrameType.DATA, data),
|
||||
end_stream=end_stream,
|
||||
)
|
||||
|
||||
def send_decoder(self) -> quic.SendQuicStreamData:
|
||||
def cb(data: bytes) -> None:
|
||||
self.encoder.feed_decoder(data)
|
||||
|
@ -238,9 +244,8 @@ class FrameFactory:
|
|||
)
|
||||
|
||||
def receive_decoder(self) -> quic.QuicStreamDataReceived:
|
||||
assert self.decoder_placeholder is not None
|
||||
placeholder = self.decoder_placeholder
|
||||
self.decoder_placeholder = None
|
||||
assert self.decoder_placeholders
|
||||
placeholder = self.decoder_placeholders.pop(0)
|
||||
|
||||
return quic.QuicStreamDataReceived(
|
||||
self.conn,
|
||||
|
@ -249,6 +254,31 @@ class FrameFactory:
|
|||
end_stream=False,
|
||||
)
|
||||
|
||||
def send_headers(
|
||||
self,
|
||||
headers: Headers,
|
||||
stream_id: int = 0,
|
||||
end_stream: bool = False,
|
||||
) -> Iterable[quic.SendQuicStreamData]:
|
||||
placeholder = tutils.Placeholder(bytes)
|
||||
self.decoder_placeholders.append(placeholder)
|
||||
|
||||
def decode(data: bytes) -> None:
|
||||
buf = Buffer(data=data)
|
||||
assert buf.pull_uint_var() == FrameType.HEADERS
|
||||
frame_data = buf.pull_bytes(buf.pull_uint_var())
|
||||
decoder, actual_headers = self.decoder.feed_header(stream_id, frame_data)
|
||||
placeholder.setdefault(decoder)
|
||||
assert headers == actual_headers
|
||||
|
||||
yield self.send_encoder()
|
||||
yield quic.SendQuicStreamData(
|
||||
connection=self.conn,
|
||||
stream_id=stream_id,
|
||||
data=CallbackPlaceholder(decode),
|
||||
end_stream=end_stream,
|
||||
)
|
||||
|
||||
def receive_headers(
|
||||
self,
|
||||
headers: Headers,
|
||||
|
@ -275,29 +305,16 @@ class FrameFactory:
|
|||
end_stream=end_stream,
|
||||
)
|
||||
|
||||
def send_headers(
|
||||
def send_data(
|
||||
self,
|
||||
headers: Headers,
|
||||
data: bytes,
|
||||
stream_id: int = 0,
|
||||
end_stream: bool = False,
|
||||
) -> Iterable[quic.SendQuicStreamData]:
|
||||
assert self.decoder_placeholder is None
|
||||
placeholder = tutils.Placeholder(bytes)
|
||||
self.decoder_placeholder = placeholder
|
||||
|
||||
def decode(data: bytes) -> None:
|
||||
buf = Buffer(data=data)
|
||||
assert buf.pull_uint_var() == FrameType.HEADERS
|
||||
frame_data = buf.pull_bytes(buf.pull_uint_var())
|
||||
decoder, headers = self.decoder.feed_header(stream_id, frame_data)
|
||||
placeholder.setdefault(decoder)
|
||||
assert headers == headers
|
||||
|
||||
yield self.send_encoder()
|
||||
yield quic.SendQuicStreamData(
|
||||
connection=self.conn,
|
||||
) -> quic.SendQuicStreamData:
|
||||
return quic.SendQuicStreamData(
|
||||
self.conn,
|
||||
stream_id=stream_id,
|
||||
data=CallbackPlaceholder(decode),
|
||||
data=encode_frame(FrameType.DATA, data),
|
||||
end_stream=end_stream,
|
||||
)
|
||||
|
||||
|
@ -314,13 +331,27 @@ class FrameFactory:
|
|||
end_stream=end_stream,
|
||||
)
|
||||
|
||||
def send_server_init(self) -> Iterable[quic.SendQuicStreamData]:
|
||||
def send_init(self) -> Iterable[quic.SendQuicStreamData]:
|
||||
yield self.send_stream_type(StreamType.CONTROL)
|
||||
yield self.send_settings()
|
||||
yield self.send_max_push_id()
|
||||
if not self.is_client:
|
||||
yield self.send_max_push_id()
|
||||
yield self.send_stream_type(StreamType.QPACK_ENCODER)
|
||||
yield self.send_stream_type(StreamType.QPACK_DECODER)
|
||||
|
||||
def receive_init(self) -> Iterable[quic.QuicStreamDataReceived]:
|
||||
yield self.receive_stream_type(StreamType.CONTROL)
|
||||
yield self.receive_stream_type(StreamType.QPACK_ENCODER)
|
||||
yield self.receive_stream_type(StreamType.QPACK_DECODER)
|
||||
yield self.receive_settings()
|
||||
|
||||
@property
|
||||
def is_done(self) -> bool:
|
||||
return (
|
||||
self.encoder_placeholder is None
|
||||
and not self.decoder_placeholders
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def open_h3_server_conn():
|
||||
|
@ -340,15 +371,9 @@ def start_h3_client(tctx: context.Context) -> tuple[tutils.Playbook, FrameFactor
|
|||
cff = FrameFactory(conn=tctx.client, is_client=True)
|
||||
assert (
|
||||
playbook
|
||||
<< cff.send_stream_type(StreamType.CONTROL)
|
||||
<< cff.send_settings()
|
||||
<< cff.send_stream_type(StreamType.QPACK_ENCODER)
|
||||
<< cff.send_stream_type(StreamType.QPACK_DECODER)
|
||||
>> cff.receive_stream_type(StreamType.CONTROL)
|
||||
>> cff.receive_settings()
|
||||
<< cff.send_init()
|
||||
>> cff.receive_init()
|
||||
<< cff.send_encoder()
|
||||
>> cff.receive_stream_type(StreamType.QPACK_ENCODER)
|
||||
>> cff.receive_stream_type(StreamType.QPACK_DECODER)
|
||||
>> cff.receive_encoder()
|
||||
)
|
||||
return playbook, cff
|
||||
|
@ -366,14 +391,151 @@ def test_simple(tctx: context.Context):
|
|||
sff = FrameFactory(server, is_client=False)
|
||||
assert (
|
||||
playbook
|
||||
# request client
|
||||
>> cff.receive_headers(example_request_headers, end_stream=True)
|
||||
<< http.HttpRequestHeadersHook(flow)
|
||||
<< cff.send_decoder()
|
||||
>> tutils.reply(to=http.HttpRequestHeadersHook(flow))
|
||||
<< (request := http.HttpRequestHeadersHook(flow))
|
||||
<< cff.send_decoder() # for receive_headers
|
||||
>> tutils.reply(to=request)
|
||||
<< http.HttpRequestHook(flow)
|
||||
>> tutils.reply()
|
||||
# request server
|
||||
<< commands.OpenConnection(server)
|
||||
>> tutils.reply(None, side_effect=make_h3)
|
||||
<< sff.send_server_init()
|
||||
<< sff.send_init()
|
||||
<< sff.send_headers(example_request_headers, end_stream=True)
|
||||
>> sff.receive_init()
|
||||
<< sff.send_encoder()
|
||||
>> sff.receive_encoder()
|
||||
>> sff.receive_decoder() # for send_headers
|
||||
# response server
|
||||
>> sff.receive_headers(example_response_headers)
|
||||
<< (response := http.HttpResponseHeadersHook(flow))
|
||||
<< sff.send_decoder() # for receive_headers
|
||||
>> tutils.reply(to=response)
|
||||
>> sff.receive_data(b"Hello, World!", end_stream=True)
|
||||
<< http.HttpResponseHook(flow)
|
||||
>> tutils.reply()
|
||||
# response client
|
||||
<< cff.send_headers(example_response_headers)
|
||||
<< cff.send_data(b"Hello, World!")
|
||||
<< cff.send_data(b"", end_stream=True)
|
||||
>> cff.receive_decoder() # for send_headers
|
||||
)
|
||||
assert cff.is_done and sff.is_done
|
||||
assert flow().request.url == "http://example.com/"
|
||||
assert flow().response.text == "Hello, World!"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("stream", [True, False])
|
||||
def test_response_trailers(
|
||||
tctx: context.Context,
|
||||
open_h3_server_conn: connection.Server,
|
||||
stream: bool,
|
||||
):
|
||||
playbook, cff = start_h3_client(tctx)
|
||||
tctx.server = open_h3_server_conn
|
||||
sff = FrameFactory(tctx.server, is_client=False)
|
||||
|
||||
def enable_streaming(flow: HTTPFlow):
|
||||
flow.response.stream = stream
|
||||
|
||||
flow = tutils.Placeholder(HTTPFlow)
|
||||
(
|
||||
playbook
|
||||
# request client
|
||||
>> cff.receive_headers(example_request_headers, end_stream=True)
|
||||
<< (request := http.HttpRequestHeadersHook(flow))
|
||||
<< cff.send_decoder() # for receive_headers
|
||||
>> tutils.reply(to=request)
|
||||
<< http.HttpRequestHook(flow)
|
||||
>> tutils.reply()
|
||||
# request server
|
||||
<< sff.send_init()
|
||||
<< sff.send_headers(example_request_headers, end_stream=True)
|
||||
>> sff.receive_init()
|
||||
<< sff.send_encoder()
|
||||
>> sff.receive_encoder()
|
||||
>> sff.receive_decoder() # for send_headers
|
||||
# response server
|
||||
>> sff.receive_headers(example_response_headers)
|
||||
<< (response_headers := http.HttpResponseHeadersHook(flow))
|
||||
<< sff.send_decoder() # for receive_headers
|
||||
>> tutils.reply(to=response_headers, side_effect=enable_streaming)
|
||||
)
|
||||
if stream:
|
||||
(
|
||||
playbook
|
||||
<< cff.send_headers(example_response_headers)
|
||||
>> cff.receive_decoder() # for send_headers
|
||||
>> sff.receive_data(b"Hello, World!")
|
||||
<< cff.send_data(b"Hello, World!")
|
||||
)
|
||||
else:
|
||||
playbook >> sff.receive_data(b"Hello, World!")
|
||||
assert (
|
||||
playbook
|
||||
>> sff.receive_headers(example_response_trailers, end_stream=True)
|
||||
<< (response := http.HttpResponseHook(flow))
|
||||
<< sff.send_decoder() # for receive_headers
|
||||
)
|
||||
|
||||
def modify_tailers(flow: HTTPFlow) -> None:
|
||||
assert flow.response.trailers
|
||||
del flow.response.trailers["resp-trailer-a"]
|
||||
|
||||
if stream:
|
||||
assert (
|
||||
playbook
|
||||
>> tutils.reply(to=response, side_effect=modify_tailers)
|
||||
<< cff.send_headers(example_response_trailers[1:], end_stream=True)
|
||||
>> cff.receive_decoder() # for send_headers
|
||||
)
|
||||
else:
|
||||
assert (
|
||||
playbook
|
||||
>> tutils.reply(to=response, side_effect=modify_tailers)
|
||||
<< cff.send_headers(example_response_headers)
|
||||
<< cff.send_data(b"Hello, World!")
|
||||
<< cff.send_headers(example_response_trailers[1:], end_stream=True)
|
||||
>> cff.receive_decoder() # for send_headers
|
||||
>> cff.receive_decoder() # for send_headers
|
||||
)
|
||||
assert cff.is_done and sff.is_done
|
||||
|
||||
|
||||
def test_upstream_error(tctx: context.Context):
|
||||
playbook, cff = start_h3_client(tctx)
|
||||
flow = tutils.Placeholder(HTTPFlow)
|
||||
server = tutils.Placeholder(connection.Server)
|
||||
err = tutils.Placeholder(bytes)
|
||||
assert (
|
||||
playbook
|
||||
# request client
|
||||
>> cff.receive_headers(example_request_headers, end_stream=True)
|
||||
<< (request := http.HttpRequestHeadersHook(flow))
|
||||
<< cff.send_decoder() # for receive_headers
|
||||
>> tutils.reply(to=request)
|
||||
<< http.HttpRequestHook(flow)
|
||||
>> tutils.reply()
|
||||
# request server
|
||||
<< commands.OpenConnection(server)
|
||||
>> tutils.reply("oops server <> error")
|
||||
<< http.HttpErrorHook(flow)
|
||||
>> tutils.reply()
|
||||
<< cff.send_headers([
|
||||
(b":status", b"502"),
|
||||
(b'server', version.MITMPROXY.encode()),
|
||||
(b'content-type', b'text/html'),
|
||||
])
|
||||
<< quic.SendQuicStreamData(
|
||||
tctx.client,
|
||||
stream_id=0,
|
||||
data=err,
|
||||
end_stream=True,
|
||||
)
|
||||
>> cff.receive_decoder() # for send_headers
|
||||
)
|
||||
assert cff.is_done
|
||||
data = decode_frame(FrameType.DATA, err())
|
||||
assert b"502 Bad Gateway" in data
|
||||
assert b"server <> error" in data
|
||||
|
|
Loading…
Reference in New Issue