[quic] more h3 tests

This commit is contained in:
Manuel Meitinger 2022-11-07 01:22:59 +01:00
parent a2d962b09a
commit 201f03082a
1 changed files with 229 additions and 67 deletions

View File

@ -4,9 +4,18 @@ import pytest
import pylsqpack
from aioquic._buffer import Buffer
from aioquic.h3.connection import FrameType, StreamType, Headers, Setting, encode_frame, encode_uint_var, encode_settings, parse_settings
from aioquic.h3.connection import (
FrameType,
Headers,
Setting,
StreamType,
encode_frame,
encode_uint_var,
encode_settings,
parse_settings,
)
from mitmproxy import connection
from mitmproxy import connection, version
from mitmproxy.http import HTTPFlow
from mitmproxy.proxy import commands, context, layers
from mitmproxy.proxy.layers import http, quic
@ -20,6 +29,16 @@ example_request_headers = [
(b":authority", b"example.com"),
]
example_response_headers = [(b":status", b"200")]
example_response_trailers = [(b"resp-trailer-a", b"a"), (b"resp-trailer-b", b"b")]
def decode_frame(frame_type: int, frame_data: bytes) -> bytes:
buf = Buffer(data=frame_data)
assert buf.pull_uint_var() == frame_type
return buf.pull_bytes(buf.pull_uint_var())
class CallbackPlaceholder(tutils._Placeholder[bytes]):
"""Data placeholder that invokes a callback once its bytes get set."""
@ -77,7 +96,7 @@ class FrameFactory:
max_table_capacity=4096,
blocked_streams=16,
)
self.decoder_placeholder: Optional[tutils.Placeholder(bytes)] = None
self.decoder_placeholders: list[tutils.Placeholder(bytes)] = []
self.encoder = pylsqpack.Encoder()
self.encoder_placeholder: Optional[tutils.Placeholder(bytes)] = None
self.peer_stream_id: dict[StreamType, int] = {}
@ -137,6 +156,21 @@ class FrameFactory:
end_stream=False,
)
def send_max_push_id(self) -> quic.SendQuicStreamData:
def cb(data: bytes) -> None:
buf = Buffer(data=data)
assert buf.pull_uint_var() == FrameType.MAX_PUSH_ID
buf = Buffer(data=buf.pull_bytes(buf.pull_uint_var()))
self.max_push_id = buf.pull_uint_var()
assert buf.eof()
return quic.SendQuicStreamData(
connection=self.conn,
stream_id=self.peer_stream_id[StreamType.CONTROL],
data=CallbackPlaceholder(cb),
end_stream=False,
)
def send_settings(self) -> quic.SendQuicStreamData:
assert self.encoder_placeholder is None
placeholder = tutils.Placeholder(bytes)
@ -158,21 +192,6 @@ class FrameFactory:
end_stream=False,
)
def send_max_push_id(self) -> quic.SendQuicStreamData:
def cb(data: bytes) -> None:
buf = Buffer(data=data)
assert buf.pull_uint_var() == FrameType.MAX_PUSH_ID
buf = Buffer(data=buf.pull_bytes(buf.pull_uint_var()))
self.max_push_id = buf.pull_uint_var()
assert buf.eof()
return quic.SendQuicStreamData(
connection=self.conn,
stream_id=self.peer_stream_id[StreamType.CONTROL],
data=CallbackPlaceholder(cb),
end_stream=False,
)
def receive_settings(
self,
settings: dict[int, int] = {
@ -213,19 +232,6 @@ class FrameFactory:
end_stream=False,
)
def send_data(
self,
data: bytes,
stream_id: int = 0,
end_stream: bool = False,
) -> quic.SendQuicStreamData:
return quic.SendQuicStreamData(
self.conn,
stream_id=stream_id,
data=encode_frame(FrameType.DATA, data),
end_stream=end_stream,
)
def send_decoder(self) -> quic.SendQuicStreamData:
def cb(data: bytes) -> None:
self.encoder.feed_decoder(data)
@ -238,9 +244,8 @@ class FrameFactory:
)
def receive_decoder(self) -> quic.QuicStreamDataReceived:
assert self.decoder_placeholder is not None
placeholder = self.decoder_placeholder
self.decoder_placeholder = None
assert self.decoder_placeholders
placeholder = self.decoder_placeholders.pop(0)
return quic.QuicStreamDataReceived(
self.conn,
@ -249,6 +254,31 @@ class FrameFactory:
end_stream=False,
)
def send_headers(
self,
headers: Headers,
stream_id: int = 0,
end_stream: bool = False,
) -> Iterable[quic.SendQuicStreamData]:
placeholder = tutils.Placeholder(bytes)
self.decoder_placeholders.append(placeholder)
def decode(data: bytes) -> None:
buf = Buffer(data=data)
assert buf.pull_uint_var() == FrameType.HEADERS
frame_data = buf.pull_bytes(buf.pull_uint_var())
decoder, actual_headers = self.decoder.feed_header(stream_id, frame_data)
placeholder.setdefault(decoder)
assert headers == actual_headers
yield self.send_encoder()
yield quic.SendQuicStreamData(
connection=self.conn,
stream_id=stream_id,
data=CallbackPlaceholder(decode),
end_stream=end_stream,
)
def receive_headers(
self,
headers: Headers,
@ -275,29 +305,16 @@ class FrameFactory:
end_stream=end_stream,
)
def send_headers(
def send_data(
self,
headers: Headers,
data: bytes,
stream_id: int = 0,
end_stream: bool = False,
) -> Iterable[quic.SendQuicStreamData]:
assert self.decoder_placeholder is None
placeholder = tutils.Placeholder(bytes)
self.decoder_placeholder = placeholder
def decode(data: bytes) -> None:
buf = Buffer(data=data)
assert buf.pull_uint_var() == FrameType.HEADERS
frame_data = buf.pull_bytes(buf.pull_uint_var())
decoder, headers = self.decoder.feed_header(stream_id, frame_data)
placeholder.setdefault(decoder)
assert headers == headers
yield self.send_encoder()
yield quic.SendQuicStreamData(
connection=self.conn,
) -> quic.SendQuicStreamData:
return quic.SendQuicStreamData(
self.conn,
stream_id=stream_id,
data=CallbackPlaceholder(decode),
data=encode_frame(FrameType.DATA, data),
end_stream=end_stream,
)
@ -314,13 +331,27 @@ class FrameFactory:
end_stream=end_stream,
)
def send_server_init(self) -> Iterable[quic.SendQuicStreamData]:
def send_init(self) -> Iterable[quic.SendQuicStreamData]:
yield self.send_stream_type(StreamType.CONTROL)
yield self.send_settings()
yield self.send_max_push_id()
if not self.is_client:
yield self.send_max_push_id()
yield self.send_stream_type(StreamType.QPACK_ENCODER)
yield self.send_stream_type(StreamType.QPACK_DECODER)
def receive_init(self) -> Iterable[quic.QuicStreamDataReceived]:
yield self.receive_stream_type(StreamType.CONTROL)
yield self.receive_stream_type(StreamType.QPACK_ENCODER)
yield self.receive_stream_type(StreamType.QPACK_DECODER)
yield self.receive_settings()
@property
def is_done(self) -> bool:
return (
self.encoder_placeholder is None
and not self.decoder_placeholders
)
@pytest.fixture
def open_h3_server_conn():
@ -340,15 +371,9 @@ def start_h3_client(tctx: context.Context) -> tuple[tutils.Playbook, FrameFactor
cff = FrameFactory(conn=tctx.client, is_client=True)
assert (
playbook
<< cff.send_stream_type(StreamType.CONTROL)
<< cff.send_settings()
<< cff.send_stream_type(StreamType.QPACK_ENCODER)
<< cff.send_stream_type(StreamType.QPACK_DECODER)
>> cff.receive_stream_type(StreamType.CONTROL)
>> cff.receive_settings()
<< cff.send_init()
>> cff.receive_init()
<< cff.send_encoder()
>> cff.receive_stream_type(StreamType.QPACK_ENCODER)
>> cff.receive_stream_type(StreamType.QPACK_DECODER)
>> cff.receive_encoder()
)
return playbook, cff
@ -366,14 +391,151 @@ def test_simple(tctx: context.Context):
sff = FrameFactory(server, is_client=False)
assert (
playbook
# request client
>> cff.receive_headers(example_request_headers, end_stream=True)
<< http.HttpRequestHeadersHook(flow)
<< cff.send_decoder()
>> tutils.reply(to=http.HttpRequestHeadersHook(flow))
<< (request := http.HttpRequestHeadersHook(flow))
<< cff.send_decoder() # for receive_headers
>> tutils.reply(to=request)
<< http.HttpRequestHook(flow)
>> tutils.reply()
# request server
<< commands.OpenConnection(server)
>> tutils.reply(None, side_effect=make_h3)
<< sff.send_server_init()
<< sff.send_init()
<< sff.send_headers(example_request_headers, end_stream=True)
>> sff.receive_init()
<< sff.send_encoder()
>> sff.receive_encoder()
>> sff.receive_decoder() # for send_headers
# response server
>> sff.receive_headers(example_response_headers)
<< (response := http.HttpResponseHeadersHook(flow))
<< sff.send_decoder() # for receive_headers
>> tutils.reply(to=response)
>> sff.receive_data(b"Hello, World!", end_stream=True)
<< http.HttpResponseHook(flow)
>> tutils.reply()
# response client
<< cff.send_headers(example_response_headers)
<< cff.send_data(b"Hello, World!")
<< cff.send_data(b"", end_stream=True)
>> cff.receive_decoder() # for send_headers
)
assert cff.is_done and sff.is_done
assert flow().request.url == "http://example.com/"
assert flow().response.text == "Hello, World!"
@pytest.mark.parametrize("stream", [True, False])
def test_response_trailers(
tctx: context.Context,
open_h3_server_conn: connection.Server,
stream: bool,
):
playbook, cff = start_h3_client(tctx)
tctx.server = open_h3_server_conn
sff = FrameFactory(tctx.server, is_client=False)
def enable_streaming(flow: HTTPFlow):
flow.response.stream = stream
flow = tutils.Placeholder(HTTPFlow)
(
playbook
# request client
>> cff.receive_headers(example_request_headers, end_stream=True)
<< (request := http.HttpRequestHeadersHook(flow))
<< cff.send_decoder() # for receive_headers
>> tutils.reply(to=request)
<< http.HttpRequestHook(flow)
>> tutils.reply()
# request server
<< sff.send_init()
<< sff.send_headers(example_request_headers, end_stream=True)
>> sff.receive_init()
<< sff.send_encoder()
>> sff.receive_encoder()
>> sff.receive_decoder() # for send_headers
# response server
>> sff.receive_headers(example_response_headers)
<< (response_headers := http.HttpResponseHeadersHook(flow))
<< sff.send_decoder() # for receive_headers
>> tutils.reply(to=response_headers, side_effect=enable_streaming)
)
if stream:
(
playbook
<< cff.send_headers(example_response_headers)
>> cff.receive_decoder() # for send_headers
>> sff.receive_data(b"Hello, World!")
<< cff.send_data(b"Hello, World!")
)
else:
playbook >> sff.receive_data(b"Hello, World!")
assert (
playbook
>> sff.receive_headers(example_response_trailers, end_stream=True)
<< (response := http.HttpResponseHook(flow))
<< sff.send_decoder() # for receive_headers
)
def modify_tailers(flow: HTTPFlow) -> None:
assert flow.response.trailers
del flow.response.trailers["resp-trailer-a"]
if stream:
assert (
playbook
>> tutils.reply(to=response, side_effect=modify_tailers)
<< cff.send_headers(example_response_trailers[1:], end_stream=True)
>> cff.receive_decoder() # for send_headers
)
else:
assert (
playbook
>> tutils.reply(to=response, side_effect=modify_tailers)
<< cff.send_headers(example_response_headers)
<< cff.send_data(b"Hello, World!")
<< cff.send_headers(example_response_trailers[1:], end_stream=True)
>> cff.receive_decoder() # for send_headers
>> cff.receive_decoder() # for send_headers
)
assert cff.is_done and sff.is_done
def test_upstream_error(tctx: context.Context):
playbook, cff = start_h3_client(tctx)
flow = tutils.Placeholder(HTTPFlow)
server = tutils.Placeholder(connection.Server)
err = tutils.Placeholder(bytes)
assert (
playbook
# request client
>> cff.receive_headers(example_request_headers, end_stream=True)
<< (request := http.HttpRequestHeadersHook(flow))
<< cff.send_decoder() # for receive_headers
>> tutils.reply(to=request)
<< http.HttpRequestHook(flow)
>> tutils.reply()
# request server
<< commands.OpenConnection(server)
>> tutils.reply("oops server <> error")
<< http.HttpErrorHook(flow)
>> tutils.reply()
<< cff.send_headers([
(b":status", b"502"),
(b'server', version.MITMPROXY.encode()),
(b'content-type', b'text/html'),
])
<< quic.SendQuicStreamData(
tctx.client,
stream_id=0,
data=err,
end_stream=True,
)
>> cff.receive_decoder() # for send_headers
)
assert cff.is_done
data = decode_frame(FrameType.DATA, err())
assert b"502 Bad Gateway" in data
assert b"server &lt;&gt; error" in data