[quic] more h3 tests

This commit is contained in:
Manuel Meitinger 2022-11-07 01:22:59 +01:00
parent a2d962b09a
commit 201f03082a
1 changed files with 229 additions and 67 deletions

View File

@ -4,9 +4,18 @@ import pytest
import pylsqpack import pylsqpack
from aioquic._buffer import Buffer from aioquic._buffer import Buffer
from aioquic.h3.connection import FrameType, StreamType, Headers, Setting, encode_frame, encode_uint_var, encode_settings, parse_settings from aioquic.h3.connection import (
FrameType,
Headers,
Setting,
StreamType,
encode_frame,
encode_uint_var,
encode_settings,
parse_settings,
)
from mitmproxy import connection from mitmproxy import connection, version
from mitmproxy.http import HTTPFlow from mitmproxy.http import HTTPFlow
from mitmproxy.proxy import commands, context, layers from mitmproxy.proxy import commands, context, layers
from mitmproxy.proxy.layers import http, quic from mitmproxy.proxy.layers import http, quic
@ -20,6 +29,16 @@ example_request_headers = [
(b":authority", b"example.com"), (b":authority", b"example.com"),
] ]
example_response_headers = [(b":status", b"200")]
example_response_trailers = [(b"resp-trailer-a", b"a"), (b"resp-trailer-b", b"b")]
def decode_frame(frame_type: int, frame_data: bytes) -> bytes:
buf = Buffer(data=frame_data)
assert buf.pull_uint_var() == frame_type
return buf.pull_bytes(buf.pull_uint_var())
class CallbackPlaceholder(tutils._Placeholder[bytes]): class CallbackPlaceholder(tutils._Placeholder[bytes]):
"""Data placeholder that invokes a callback once its bytes get set.""" """Data placeholder that invokes a callback once its bytes get set."""
@ -77,7 +96,7 @@ class FrameFactory:
max_table_capacity=4096, max_table_capacity=4096,
blocked_streams=16, blocked_streams=16,
) )
self.decoder_placeholder: Optional[tutils.Placeholder(bytes)] = None self.decoder_placeholders: list[tutils.Placeholder(bytes)] = []
self.encoder = pylsqpack.Encoder() self.encoder = pylsqpack.Encoder()
self.encoder_placeholder: Optional[tutils.Placeholder(bytes)] = None self.encoder_placeholder: Optional[tutils.Placeholder(bytes)] = None
self.peer_stream_id: dict[StreamType, int] = {} self.peer_stream_id: dict[StreamType, int] = {}
@ -137,6 +156,21 @@ class FrameFactory:
end_stream=False, end_stream=False,
) )
def send_max_push_id(self) -> quic.SendQuicStreamData:
def cb(data: bytes) -> None:
buf = Buffer(data=data)
assert buf.pull_uint_var() == FrameType.MAX_PUSH_ID
buf = Buffer(data=buf.pull_bytes(buf.pull_uint_var()))
self.max_push_id = buf.pull_uint_var()
assert buf.eof()
return quic.SendQuicStreamData(
connection=self.conn,
stream_id=self.peer_stream_id[StreamType.CONTROL],
data=CallbackPlaceholder(cb),
end_stream=False,
)
def send_settings(self) -> quic.SendQuicStreamData: def send_settings(self) -> quic.SendQuicStreamData:
assert self.encoder_placeholder is None assert self.encoder_placeholder is None
placeholder = tutils.Placeholder(bytes) placeholder = tutils.Placeholder(bytes)
@ -158,21 +192,6 @@ class FrameFactory:
end_stream=False, end_stream=False,
) )
def send_max_push_id(self) -> quic.SendQuicStreamData:
def cb(data: bytes) -> None:
buf = Buffer(data=data)
assert buf.pull_uint_var() == FrameType.MAX_PUSH_ID
buf = Buffer(data=buf.pull_bytes(buf.pull_uint_var()))
self.max_push_id = buf.pull_uint_var()
assert buf.eof()
return quic.SendQuicStreamData(
connection=self.conn,
stream_id=self.peer_stream_id[StreamType.CONTROL],
data=CallbackPlaceholder(cb),
end_stream=False,
)
def receive_settings( def receive_settings(
self, self,
settings: dict[int, int] = { settings: dict[int, int] = {
@ -213,19 +232,6 @@ class FrameFactory:
end_stream=False, end_stream=False,
) )
def send_data(
self,
data: bytes,
stream_id: int = 0,
end_stream: bool = False,
) -> quic.SendQuicStreamData:
return quic.SendQuicStreamData(
self.conn,
stream_id=stream_id,
data=encode_frame(FrameType.DATA, data),
end_stream=end_stream,
)
def send_decoder(self) -> quic.SendQuicStreamData: def send_decoder(self) -> quic.SendQuicStreamData:
def cb(data: bytes) -> None: def cb(data: bytes) -> None:
self.encoder.feed_decoder(data) self.encoder.feed_decoder(data)
@ -238,9 +244,8 @@ class FrameFactory:
) )
def receive_decoder(self) -> quic.QuicStreamDataReceived: def receive_decoder(self) -> quic.QuicStreamDataReceived:
assert self.decoder_placeholder is not None assert self.decoder_placeholders
placeholder = self.decoder_placeholder placeholder = self.decoder_placeholders.pop(0)
self.decoder_placeholder = None
return quic.QuicStreamDataReceived( return quic.QuicStreamDataReceived(
self.conn, self.conn,
@ -249,6 +254,31 @@ class FrameFactory:
end_stream=False, end_stream=False,
) )
def send_headers(
self,
headers: Headers,
stream_id: int = 0,
end_stream: bool = False,
) -> Iterable[quic.SendQuicStreamData]:
placeholder = tutils.Placeholder(bytes)
self.decoder_placeholders.append(placeholder)
def decode(data: bytes) -> None:
buf = Buffer(data=data)
assert buf.pull_uint_var() == FrameType.HEADERS
frame_data = buf.pull_bytes(buf.pull_uint_var())
decoder, actual_headers = self.decoder.feed_header(stream_id, frame_data)
placeholder.setdefault(decoder)
assert headers == actual_headers
yield self.send_encoder()
yield quic.SendQuicStreamData(
connection=self.conn,
stream_id=stream_id,
data=CallbackPlaceholder(decode),
end_stream=end_stream,
)
def receive_headers( def receive_headers(
self, self,
headers: Headers, headers: Headers,
@ -275,29 +305,16 @@ class FrameFactory:
end_stream=end_stream, end_stream=end_stream,
) )
def send_headers( def send_data(
self, self,
headers: Headers, data: bytes,
stream_id: int = 0, stream_id: int = 0,
end_stream: bool = False, end_stream: bool = False,
) -> Iterable[quic.SendQuicStreamData]: ) -> quic.SendQuicStreamData:
assert self.decoder_placeholder is None return quic.SendQuicStreamData(
placeholder = tutils.Placeholder(bytes) self.conn,
self.decoder_placeholder = placeholder
def decode(data: bytes) -> None:
buf = Buffer(data=data)
assert buf.pull_uint_var() == FrameType.HEADERS
frame_data = buf.pull_bytes(buf.pull_uint_var())
decoder, headers = self.decoder.feed_header(stream_id, frame_data)
placeholder.setdefault(decoder)
assert headers == headers
yield self.send_encoder()
yield quic.SendQuicStreamData(
connection=self.conn,
stream_id=stream_id, stream_id=stream_id,
data=CallbackPlaceholder(decode), data=encode_frame(FrameType.DATA, data),
end_stream=end_stream, end_stream=end_stream,
) )
@ -314,13 +331,27 @@ class FrameFactory:
end_stream=end_stream, end_stream=end_stream,
) )
def send_server_init(self) -> Iterable[quic.SendQuicStreamData]: def send_init(self) -> Iterable[quic.SendQuicStreamData]:
yield self.send_stream_type(StreamType.CONTROL) yield self.send_stream_type(StreamType.CONTROL)
yield self.send_settings() yield self.send_settings()
if not self.is_client:
yield self.send_max_push_id() yield self.send_max_push_id()
yield self.send_stream_type(StreamType.QPACK_ENCODER) yield self.send_stream_type(StreamType.QPACK_ENCODER)
yield self.send_stream_type(StreamType.QPACK_DECODER) yield self.send_stream_type(StreamType.QPACK_DECODER)
def receive_init(self) -> Iterable[quic.QuicStreamDataReceived]:
yield self.receive_stream_type(StreamType.CONTROL)
yield self.receive_stream_type(StreamType.QPACK_ENCODER)
yield self.receive_stream_type(StreamType.QPACK_DECODER)
yield self.receive_settings()
@property
def is_done(self) -> bool:
return (
self.encoder_placeholder is None
and not self.decoder_placeholders
)
@pytest.fixture @pytest.fixture
def open_h3_server_conn(): def open_h3_server_conn():
@ -340,15 +371,9 @@ def start_h3_client(tctx: context.Context) -> tuple[tutils.Playbook, FrameFactor
cff = FrameFactory(conn=tctx.client, is_client=True) cff = FrameFactory(conn=tctx.client, is_client=True)
assert ( assert (
playbook playbook
<< cff.send_stream_type(StreamType.CONTROL) << cff.send_init()
<< cff.send_settings() >> cff.receive_init()
<< cff.send_stream_type(StreamType.QPACK_ENCODER)
<< cff.send_stream_type(StreamType.QPACK_DECODER)
>> cff.receive_stream_type(StreamType.CONTROL)
>> cff.receive_settings()
<< cff.send_encoder() << cff.send_encoder()
>> cff.receive_stream_type(StreamType.QPACK_ENCODER)
>> cff.receive_stream_type(StreamType.QPACK_DECODER)
>> cff.receive_encoder() >> cff.receive_encoder()
) )
return playbook, cff return playbook, cff
@ -366,14 +391,151 @@ def test_simple(tctx: context.Context):
sff = FrameFactory(server, is_client=False) sff = FrameFactory(server, is_client=False)
assert ( assert (
playbook playbook
# request client
>> cff.receive_headers(example_request_headers, end_stream=True) >> cff.receive_headers(example_request_headers, end_stream=True)
<< http.HttpRequestHeadersHook(flow) << (request := http.HttpRequestHeadersHook(flow))
<< cff.send_decoder() << cff.send_decoder() # for receive_headers
>> tutils.reply(to=http.HttpRequestHeadersHook(flow)) >> tutils.reply(to=request)
<< http.HttpRequestHook(flow) << http.HttpRequestHook(flow)
>> tutils.reply() >> tutils.reply()
# request server
<< commands.OpenConnection(server) << commands.OpenConnection(server)
>> tutils.reply(None, side_effect=make_h3) >> tutils.reply(None, side_effect=make_h3)
<< sff.send_server_init() << sff.send_init()
<< sff.send_headers(example_request_headers, end_stream=True) << sff.send_headers(example_request_headers, end_stream=True)
>> sff.receive_init()
<< sff.send_encoder()
>> sff.receive_encoder()
>> sff.receive_decoder() # for send_headers
# response server
>> sff.receive_headers(example_response_headers)
<< (response := http.HttpResponseHeadersHook(flow))
<< sff.send_decoder() # for receive_headers
>> tutils.reply(to=response)
>> sff.receive_data(b"Hello, World!", end_stream=True)
<< http.HttpResponseHook(flow)
>> tutils.reply()
# response client
<< cff.send_headers(example_response_headers)
<< cff.send_data(b"Hello, World!")
<< cff.send_data(b"", end_stream=True)
>> cff.receive_decoder() # for send_headers
) )
assert cff.is_done and sff.is_done
assert flow().request.url == "http://example.com/"
assert flow().response.text == "Hello, World!"
@pytest.mark.parametrize("stream", [True, False])
def test_response_trailers(
tctx: context.Context,
open_h3_server_conn: connection.Server,
stream: bool,
):
playbook, cff = start_h3_client(tctx)
tctx.server = open_h3_server_conn
sff = FrameFactory(tctx.server, is_client=False)
def enable_streaming(flow: HTTPFlow):
flow.response.stream = stream
flow = tutils.Placeholder(HTTPFlow)
(
playbook
# request client
>> cff.receive_headers(example_request_headers, end_stream=True)
<< (request := http.HttpRequestHeadersHook(flow))
<< cff.send_decoder() # for receive_headers
>> tutils.reply(to=request)
<< http.HttpRequestHook(flow)
>> tutils.reply()
# request server
<< sff.send_init()
<< sff.send_headers(example_request_headers, end_stream=True)
>> sff.receive_init()
<< sff.send_encoder()
>> sff.receive_encoder()
>> sff.receive_decoder() # for send_headers
# response server
>> sff.receive_headers(example_response_headers)
<< (response_headers := http.HttpResponseHeadersHook(flow))
<< sff.send_decoder() # for receive_headers
>> tutils.reply(to=response_headers, side_effect=enable_streaming)
)
if stream:
(
playbook
<< cff.send_headers(example_response_headers)
>> cff.receive_decoder() # for send_headers
>> sff.receive_data(b"Hello, World!")
<< cff.send_data(b"Hello, World!")
)
else:
playbook >> sff.receive_data(b"Hello, World!")
assert (
playbook
>> sff.receive_headers(example_response_trailers, end_stream=True)
<< (response := http.HttpResponseHook(flow))
<< sff.send_decoder() # for receive_headers
)
def modify_tailers(flow: HTTPFlow) -> None:
assert flow.response.trailers
del flow.response.trailers["resp-trailer-a"]
if stream:
assert (
playbook
>> tutils.reply(to=response, side_effect=modify_tailers)
<< cff.send_headers(example_response_trailers[1:], end_stream=True)
>> cff.receive_decoder() # for send_headers
)
else:
assert (
playbook
>> tutils.reply(to=response, side_effect=modify_tailers)
<< cff.send_headers(example_response_headers)
<< cff.send_data(b"Hello, World!")
<< cff.send_headers(example_response_trailers[1:], end_stream=True)
>> cff.receive_decoder() # for send_headers
>> cff.receive_decoder() # for send_headers
)
assert cff.is_done and sff.is_done
def test_upstream_error(tctx: context.Context):
playbook, cff = start_h3_client(tctx)
flow = tutils.Placeholder(HTTPFlow)
server = tutils.Placeholder(connection.Server)
err = tutils.Placeholder(bytes)
assert (
playbook
# request client
>> cff.receive_headers(example_request_headers, end_stream=True)
<< (request := http.HttpRequestHeadersHook(flow))
<< cff.send_decoder() # for receive_headers
>> tutils.reply(to=request)
<< http.HttpRequestHook(flow)
>> tutils.reply()
# request server
<< commands.OpenConnection(server)
>> tutils.reply("oops server <> error")
<< http.HttpErrorHook(flow)
>> tutils.reply()
<< cff.send_headers([
(b":status", b"502"),
(b'server', version.MITMPROXY.encode()),
(b'content-type', b'text/html'),
])
<< quic.SendQuicStreamData(
tctx.client,
stream_id=0,
data=err,
end_stream=True,
)
>> cff.receive_decoder() # for send_headers
)
assert cff.is_done
data = decode_frame(FrameType.DATA, err())
assert b"502 Bad Gateway" in data
assert b"server &lt;&gt; error" in data