[quic] first test for H3
This commit is contained in:
parent
78c1a23bf3
commit
a308d3dabc
|
@ -59,7 +59,7 @@ class Http3Connection(HttpConnection):
|
|||
|
||||
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
if isinstance(event, events.Start):
|
||||
pass
|
||||
yield from self.h3_conn.transmit()
|
||||
|
||||
# send mitmproxy HTTP events over the H3 connection
|
||||
elif isinstance(event, HttpEvent):
|
||||
|
@ -145,7 +145,6 @@ class Http3Connection(HttpConnection):
|
|||
error_code=H3ErrorCode.H3_GENERAL_PROTOCOL_ERROR,
|
||||
reason_phrase=f"Invalid HTTP/3 request headers: {e}",
|
||||
)
|
||||
yield from self.h3_conn.transmit()
|
||||
else:
|
||||
yield ReceiveHttp(receive_event)
|
||||
if h3_event.stream_ended:
|
||||
|
@ -160,6 +159,7 @@ class Http3Connection(HttpConnection):
|
|||
pass
|
||||
else:
|
||||
raise AssertionError(f"Unexpected event: {event!r}")
|
||||
yield from self.h3_conn.transmit()
|
||||
|
||||
# report a protocol error for all remaining open streams when a connection is closed
|
||||
elif isinstance(event, events.ConnectionClosed):
|
||||
|
|
|
@ -0,0 +1,379 @@
|
|||
import collections.abc
|
||||
from typing import Callable, Iterable, Optional
|
||||
import pytest
|
||||
import pylsqpack
|
||||
|
||||
from aioquic._buffer import Buffer
|
||||
from aioquic.h3.connection import FrameType, StreamType, Headers, Setting, encode_frame, encode_uint_var, encode_settings, parse_settings
|
||||
|
||||
from mitmproxy import connection
|
||||
from mitmproxy.http import HTTPFlow
|
||||
from mitmproxy.proxy import commands, context, layers
|
||||
from mitmproxy.proxy.layers import http, quic
|
||||
from test.mitmproxy.proxy import tutils
|
||||
|
||||
|
||||
example_request_headers = [
|
||||
(b":method", b"GET"),
|
||||
(b":scheme", b"http"),
|
||||
(b":path", b"/"),
|
||||
(b":authority", b"example.com"),
|
||||
]
|
||||
|
||||
|
||||
class CallbackPlaceholder(tutils._Placeholder[bytes]):
|
||||
"""Data placeholder that invokes a callback once its bytes get set."""
|
||||
def __init__(self, cb: Callable[[bytes], None]):
|
||||
super().__init__(bytes)
|
||||
self._cb = cb
|
||||
|
||||
def setdefault(self, value: bytes) -> None:
|
||||
if self._obj is None:
|
||||
self._cb(value)
|
||||
return super().setdefault(value)
|
||||
|
||||
|
||||
class DelayedPlaceholder(tutils._Placeholder[bytes]):
|
||||
"""Data placeholder that resolves its bytes when needed."""
|
||||
def __init__(self, resolve: Callable[[], bytes]):
|
||||
super().__init__(bytes)
|
||||
self._resolve = resolve
|
||||
|
||||
def __call__(self) -> bytes:
|
||||
if self._obj is None:
|
||||
self._obj = self._resolve()
|
||||
return super().__call__()
|
||||
|
||||
|
||||
class MultiPlaybook(tutils.Playbook):
|
||||
"""Playbook that allows multiple events and commands to be registered at once."""
|
||||
def __lshift__(self, c):
|
||||
if isinstance(c, collections.abc.Iterable):
|
||||
for c_i in c:
|
||||
super().__lshift__(c_i)
|
||||
else:
|
||||
super().__lshift__(c)
|
||||
return self
|
||||
|
||||
def __rshift__(self, e):
|
||||
if isinstance(e, collections.abc.Iterable):
|
||||
for e_i in e:
|
||||
super().__rshift__(e_i)
|
||||
else:
|
||||
super().__rshift__(e)
|
||||
return self
|
||||
|
||||
|
||||
class FrameFactory:
|
||||
"""Helper class for generating QUIC stream events and commands."""
|
||||
def __init__(
|
||||
self,
|
||||
conn: connection.Connection,
|
||||
is_client: bool
|
||||
) -> None:
|
||||
self.conn = conn
|
||||
self.is_client = is_client
|
||||
self.decoder = pylsqpack.Decoder(
|
||||
max_table_capacity=4096,
|
||||
blocked_streams=16,
|
||||
)
|
||||
self.decoder_placeholder: Optional[tutils.Placeholder(bytes)] = None
|
||||
self.encoder = pylsqpack.Encoder()
|
||||
self.encoder_placeholder: Optional[tutils.Placeholder(bytes)] = None
|
||||
self.peer_stream_id: dict[StreamType, int] = {}
|
||||
self.local_stream_id: dict[StreamType, int] = {}
|
||||
self.max_push_id: Optional[int] = None
|
||||
|
||||
def get_default_stream_id(
|
||||
self,
|
||||
stream_type: StreamType,
|
||||
for_local: bool
|
||||
) -> int:
|
||||
if stream_type == StreamType.CONTROL:
|
||||
stream_id = 2
|
||||
elif stream_type == StreamType.QPACK_ENCODER:
|
||||
stream_id = 6
|
||||
elif stream_type == StreamType.QPACK_DECODER:
|
||||
stream_id = 10
|
||||
else:
|
||||
raise AssertionError(stream_type)
|
||||
if self.is_client is not for_local:
|
||||
stream_id = stream_id + 1
|
||||
return stream_id
|
||||
|
||||
def send_stream_type(
|
||||
self,
|
||||
stream_type: StreamType,
|
||||
stream_id: Optional[int] = None,
|
||||
) -> quic.SendQuicStreamData:
|
||||
assert stream_type not in self.peer_stream_id
|
||||
if stream_id is None:
|
||||
stream_id = self.get_default_stream_id(
|
||||
stream_type, for_local=False
|
||||
)
|
||||
self.peer_stream_id[stream_type] = stream_id
|
||||
return quic.SendQuicStreamData(
|
||||
connection=self.conn,
|
||||
stream_id=stream_id,
|
||||
data=encode_uint_var(stream_type),
|
||||
end_stream=False,
|
||||
)
|
||||
|
||||
def receive_stream_type(
|
||||
self,
|
||||
stream_type: StreamType,
|
||||
stream_id: Optional[int] = None,
|
||||
) -> quic.QuicStreamDataReceived:
|
||||
assert stream_type not in self.local_stream_id
|
||||
if stream_id is None:
|
||||
stream_id = self.get_default_stream_id(
|
||||
stream_type, for_local=True
|
||||
)
|
||||
self.local_stream_id[stream_type] = stream_id
|
||||
return quic.QuicStreamDataReceived(
|
||||
connection=self.conn,
|
||||
stream_id=stream_id,
|
||||
data=encode_uint_var(stream_type),
|
||||
end_stream=False,
|
||||
)
|
||||
|
||||
def send_settings(self) -> quic.SendQuicStreamData:
|
||||
assert self.encoder_placeholder is None
|
||||
placeholder = tutils.Placeholder(bytes)
|
||||
self.encoder_placeholder = placeholder
|
||||
|
||||
def cb(data: bytes) -> None:
|
||||
buf = Buffer(data=data)
|
||||
assert buf.pull_uint_var() == FrameType.SETTINGS
|
||||
settings = parse_settings(buf.pull_bytes(buf.pull_uint_var()))
|
||||
placeholder.setdefault(self.encoder.apply_settings(
|
||||
max_table_capacity=settings[Setting.QPACK_MAX_TABLE_CAPACITY],
|
||||
blocked_streams=settings[Setting.QPACK_BLOCKED_STREAMS],
|
||||
))
|
||||
|
||||
return quic.SendQuicStreamData(
|
||||
connection=self.conn,
|
||||
stream_id=self.peer_stream_id[StreamType.CONTROL],
|
||||
data=CallbackPlaceholder(cb),
|
||||
end_stream=False,
|
||||
)
|
||||
|
||||
def send_max_push_id(self) -> quic.SendQuicStreamData:
|
||||
def cb(data: bytes) -> None:
|
||||
buf = Buffer(data=data)
|
||||
assert buf.pull_uint_var() == FrameType.MAX_PUSH_ID
|
||||
buf = Buffer(data=buf.pull_bytes(buf.pull_uint_var()))
|
||||
self.max_push_id = buf.pull_uint_var()
|
||||
assert buf.eof()
|
||||
|
||||
return quic.SendQuicStreamData(
|
||||
connection=self.conn,
|
||||
stream_id=self.peer_stream_id[StreamType.CONTROL],
|
||||
data=CallbackPlaceholder(cb),
|
||||
end_stream=False,
|
||||
)
|
||||
|
||||
def receive_settings(
|
||||
self,
|
||||
settings: dict[int, int] = {
|
||||
Setting.QPACK_MAX_TABLE_CAPACITY: 4096,
|
||||
Setting.QPACK_BLOCKED_STREAMS: 16,
|
||||
Setting.ENABLE_CONNECT_PROTOCOL: 1,
|
||||
Setting.DUMMY: 1,
|
||||
},
|
||||
) -> quic.QuicStreamDataReceived:
|
||||
return quic.QuicStreamDataReceived(
|
||||
connection=self.conn,
|
||||
stream_id=self.local_stream_id[StreamType.CONTROL],
|
||||
data=encode_frame(FrameType.SETTINGS, encode_settings(settings)),
|
||||
end_stream=False,
|
||||
)
|
||||
|
||||
def send_encoder(self) -> quic.SendQuicStreamData:
|
||||
def cb(data: bytes) -> bytes:
|
||||
self.decoder.feed_encoder(data)
|
||||
return data
|
||||
|
||||
return quic.SendQuicStreamData(
|
||||
connection=self.conn,
|
||||
stream_id=self.peer_stream_id[StreamType.QPACK_ENCODER],
|
||||
data=CallbackPlaceholder(cb),
|
||||
end_stream=False,
|
||||
)
|
||||
|
||||
def receive_encoder(self) -> quic.QuicStreamDataReceived:
|
||||
assert self.encoder_placeholder is not None
|
||||
placeholder = self.encoder_placeholder
|
||||
self.encoder_placeholder = None
|
||||
|
||||
return quic.QuicStreamDataReceived(
|
||||
connection=self.conn,
|
||||
stream_id=self.local_stream_id[StreamType.QPACK_ENCODER],
|
||||
data=placeholder,
|
||||
end_stream=False,
|
||||
)
|
||||
|
||||
def send_data(
|
||||
self,
|
||||
data: bytes,
|
||||
stream_id: int = 0,
|
||||
end_stream: bool = False,
|
||||
) -> quic.SendQuicStreamData:
|
||||
return quic.SendQuicStreamData(
|
||||
self.conn,
|
||||
stream_id=stream_id,
|
||||
data=encode_frame(FrameType.DATA, data),
|
||||
end_stream=end_stream,
|
||||
)
|
||||
|
||||
def send_decoder(self) -> quic.SendQuicStreamData:
|
||||
def cb(data: bytes) -> None:
|
||||
self.encoder.feed_decoder(data)
|
||||
|
||||
return quic.SendQuicStreamData(
|
||||
self.conn,
|
||||
stream_id=self.peer_stream_id[StreamType.QPACK_DECODER],
|
||||
data=CallbackPlaceholder(cb),
|
||||
end_stream=False,
|
||||
)
|
||||
|
||||
def receive_decoder(self) -> quic.QuicStreamDataReceived:
|
||||
assert self.decoder_placeholder is not None
|
||||
placeholder = self.decoder_placeholder
|
||||
self.decoder_placeholder = None
|
||||
|
||||
return quic.QuicStreamDataReceived(
|
||||
self.conn,
|
||||
stream_id=self.local_stream_id[StreamType.QPACK_DECODER],
|
||||
data=placeholder,
|
||||
end_stream=False,
|
||||
)
|
||||
|
||||
def receive_headers(
|
||||
self,
|
||||
headers: Headers,
|
||||
stream_id: int = 0,
|
||||
end_stream: bool = False,
|
||||
) -> Iterable[quic.QuicStreamDataReceived]:
|
||||
data = tutils.Placeholder(bytes)
|
||||
|
||||
def encode() -> bytes:
|
||||
encoder, frame_data = self.encoder.encode(stream_id, headers)
|
||||
data.setdefault(encode_frame(FrameType.HEADERS, frame_data))
|
||||
return encoder
|
||||
|
||||
yield quic.QuicStreamDataReceived(
|
||||
connection=self.conn,
|
||||
stream_id=self.local_stream_id[StreamType.QPACK_ENCODER],
|
||||
data=DelayedPlaceholder(encode),
|
||||
end_stream=False,
|
||||
)
|
||||
yield quic.QuicStreamDataReceived(
|
||||
connection=self.conn,
|
||||
stream_id=stream_id,
|
||||
data=data,
|
||||
end_stream=end_stream,
|
||||
)
|
||||
|
||||
def send_headers(
|
||||
self,
|
||||
headers: Headers,
|
||||
stream_id: int = 0,
|
||||
end_stream: bool = False,
|
||||
) -> Iterable[quic.SendQuicStreamData]:
|
||||
assert self.decoder_placeholder is None
|
||||
placeholder = tutils.Placeholder(bytes)
|
||||
self.decoder_placeholder = placeholder
|
||||
|
||||
def decode(data: bytes) -> None:
|
||||
buf = Buffer(data=data)
|
||||
assert buf.pull_uint_var() == FrameType.HEADERS
|
||||
frame_data = buf.pull_bytes(buf.pull_uint_var())
|
||||
decoder, headers = self.decoder.feed_header(stream_id, frame_data)
|
||||
placeholder.setdefault(decoder)
|
||||
assert headers == headers
|
||||
|
||||
yield self.send_encoder()
|
||||
yield quic.SendQuicStreamData(
|
||||
connection=self.conn,
|
||||
stream_id=stream_id,
|
||||
data=CallbackPlaceholder(decode),
|
||||
end_stream=end_stream,
|
||||
)
|
||||
|
||||
def receive_data(
|
||||
self,
|
||||
data: bytes,
|
||||
stream_id: int = 0,
|
||||
end_stream: bool = False,
|
||||
) -> quic.QuicStreamDataReceived:
|
||||
return quic.QuicStreamDataReceived(
|
||||
connection=self.conn,
|
||||
stream_id=stream_id,
|
||||
data=encode_frame(FrameType.DATA, data),
|
||||
end_stream=end_stream,
|
||||
)
|
||||
|
||||
def send_server_init(self) -> Iterable[quic.SendQuicStreamData]:
|
||||
yield self.send_stream_type(StreamType.CONTROL)
|
||||
yield self.send_settings()
|
||||
yield self.send_max_push_id()
|
||||
yield self.send_stream_type(StreamType.QPACK_ENCODER)
|
||||
yield self.send_stream_type(StreamType.QPACK_DECODER)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def open_h3_server_conn():
|
||||
# this is a bit fake here (port 80, with alpn, but no tls - c'mon),
|
||||
# but we don't want to pollute our tests with TLS handshakes.
|
||||
server = connection.Server(("example.com", 80), transport_protocol="udp")
|
||||
server.state = connection.ConnectionState.OPEN
|
||||
server.alpn = b"h3"
|
||||
return server
|
||||
|
||||
|
||||
def start_h3_client(tctx: context.Context) -> tuple[tutils.Playbook, FrameFactory]:
|
||||
tctx.client.alpn = b"h3"
|
||||
tctx.client.transport_protocol = "udp"
|
||||
|
||||
playbook = MultiPlaybook(layers.HttpLayer(tctx, layers.http.HTTPMode.regular))
|
||||
cff = FrameFactory(conn=tctx.client, is_client=True)
|
||||
assert (
|
||||
playbook
|
||||
<< cff.send_stream_type(StreamType.CONTROL)
|
||||
<< cff.send_settings()
|
||||
<< cff.send_stream_type(StreamType.QPACK_ENCODER)
|
||||
<< cff.send_stream_type(StreamType.QPACK_DECODER)
|
||||
>> cff.receive_stream_type(StreamType.CONTROL)
|
||||
>> cff.receive_settings()
|
||||
<< cff.send_encoder()
|
||||
>> cff.receive_stream_type(StreamType.QPACK_ENCODER)
|
||||
>> cff.receive_stream_type(StreamType.QPACK_DECODER)
|
||||
>> cff.receive_encoder()
|
||||
)
|
||||
return playbook, cff
|
||||
|
||||
|
||||
def make_h3(open_connection: commands.OpenConnection) -> None:
|
||||
open_connection.connection.alpn = b"h3"
|
||||
open_connection.connection.transport_protocol = "udp"
|
||||
|
||||
|
||||
def test_simple(tctx: context.Context):
|
||||
playbook, cff = start_h3_client(tctx)
|
||||
flow = tutils.Placeholder(HTTPFlow)
|
||||
server = tutils.Placeholder(connection.Server)
|
||||
sff = FrameFactory(server, is_client=False)
|
||||
assert (
|
||||
playbook
|
||||
>> cff.receive_headers(example_request_headers, end_stream=True)
|
||||
<< http.HttpRequestHeadersHook(flow)
|
||||
<< cff.send_decoder()
|
||||
>> tutils.reply(to=http.HttpRequestHeadersHook(flow))
|
||||
<< http.HttpRequestHook(flow)
|
||||
>> tutils.reply()
|
||||
<< commands.OpenConnection(server)
|
||||
>> tutils.reply(None, side_effect=make_h3)
|
||||
<< sff.send_server_init()
|
||||
<< sff.send_headers(example_request_headers, end_stream=True)
|
||||
)
|
Loading…
Reference in New Issue