[quic] full quic.py coverage

This commit is contained in:
Manuel Meitinger 2022-10-29 20:11:53 +02:00
parent 8a78191a5f
commit 26b2545dc2
2 changed files with 267 additions and 30 deletions

View File

@ -4,7 +4,7 @@ from dataclasses import dataclass, field
from logging import DEBUG, ERROR, WARNING
from ssl import VerifyMode
import time
from typing import TYPE_CHECKING, Callable
from typing import Callable
from aioquic.buffer import Buffer as QuicBuffer
from aioquic.h3.connection import ErrorCode as H3ErrorCode
@ -41,9 +41,6 @@ from mitmproxy.proxy.layers.tls import (
from mitmproxy.proxy.layers.udp import UDPLayer
from mitmproxy.tls import ClientHello, ClientHelloData, TlsData
if TYPE_CHECKING:
from mitmproxy.proxy.server import ConnectionHandler
@dataclass
class QuicTlsSettings:
@ -423,7 +420,7 @@ class QuicStreamLayer(layer.Layer):
elif isinstance(child_layer, tunnel.TunnelLayer):
child_layer = child_layer.child_layer
else:
break
break # pragma: no cover
if isinstance(child_layer, (UDPLayer, TCPLayer)) and child_layer.flow:
child_layer.flow.metadata["quic_is_unidirectional"] = stream_is_unidirectional(self._client_stream_id)
child_layer.flow.metadata["quic_initiator"] = "client" if stream_is_client_initiated(self._client_stream_id) else "server"
@ -701,7 +698,6 @@ class QuicLayer(tunnel.TunnelLayer):
self.child_layer = layer.NextLayer(self.context, ask_on_start=True)
self._time = time or asyncio.get_running_loop().time
self._wakeup_commands: dict[commands.RequestWakeup, float] = dict()
self._routes: dict[connection.Address, ConnectionHandler | None] = dict()
conn.tls = True
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
@ -826,8 +822,7 @@ class QuicLayer(tunnel.TunnelLayer):
all_certs: list[x509.Certificate] = []
if self.quic.tls._peer_certificate:
all_certs.append(self.quic.tls._peer_certificate)
if self.quic.tls._peer_certificate_chain:
all_certs.extend(self.quic.tls._peer_certificate_chain)
all_certs.extend(self.quic.tls._peer_certificate_chain)
# set the connection's TLS properties
self.conn.timestamp_tls_setup = time.time()
@ -983,7 +978,7 @@ class ServerQuicLayer(QuicLayer):
class ClientQuicLayer(QuicLayer):
"""
This layer establishes QUIC on a single client connection or roams to another connection.
This layer establishes QUIC on a single client connection.
"""
server_tls_available: bool
@ -1091,6 +1086,7 @@ class ClientQuicLayer(QuicLayer):
# start the client QUIC connection
yield from self.start_tls(header.destination_cid)
# XXX copied from TLS, we assume that `CloseConnection` in `start_tls` takes effect immediately
if not self.conn.connected:
return False, "connection closed early"

View File

@ -1,4 +1,4 @@
from logging import DEBUG, WARNING
from logging import DEBUG, ERROR, WARNING
import ssl
import time
from aioquic.buffer import Buffer as QuicBuffer
@ -35,16 +35,12 @@ def tctx() -> context.Context:
)
class InvalidStreamEvent(quic.QuicStreamEvent):
pass
class DummyLayer(layer.Layer):
child_layer: Optional[layer.Layer]
class InvalidEvent(events.Event):
pass
class InvalidConnectionCommand(commands.ConnectionCommand):
pass
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
assert self.child_layer
return self.child_layer.handle_event(event)
class TlsEchoLayer(tutils.EchoLayer):
@ -58,10 +54,26 @@ class TlsEchoLayer(tutils.EchoLayer):
yield commands.SendData(
event.connection, f"open-connection failed: {err}".encode()
)
elif isinstance(event, events.DataReceived) and event.data == b"close-connection":
yield commands.CloseConnection(event.connection)
elif isinstance(event, events.DataReceived) and event.data == b"close-connection-error":
yield quic.CloseQuicConnection(event.connection, ~0, None, "error")
elif isinstance(event, events.DataReceived) and event.data == b"stop-stream":
yield quic.StopQuicStream(event.connection, 24, 123)
elif isinstance(event, events.DataReceived) and event.data == b"invalid-command":
class InvalidConnectionCommand(commands.ConnectionCommand):
pass
yield InvalidConnectionCommand(event.connection)
elif isinstance(event, events.DataReceived) and event.data == b"invalid-stream-command":
class InvalidStreamCommand(quic.QuicStreamCommand):
pass
yield InvalidStreamCommand(event.connection, 42)
elif isinstance(event, quic.QuicConnectionClosed):
self.closed = event
elif isinstance(event, quic.QuicStreamDataReceived):
yield quic.SendQuicStreamData(event.connection, event.stream_id, event.data, event.end_stream)
elif isinstance(event, quic.QuicStreamReset):
yield quic.ResetQuicStream(event.connection, event.stream_id, event.error_code)
else:
yield from super()._handle_event(event)
@ -124,7 +136,6 @@ def test_secrets_logger(value: str):
class TestParseClientHello:
def test_input(self):
assert quic.quic_parse_client_hello(client_hello).sni == "example.com"
with pytest.raises(ValueError):
@ -154,13 +165,9 @@ class TestParseClientHello:
with pytest.raises(ValueError, match="Conn err"):
quic.quic_parse_client_hello(client_hello)
def test_no_return(self, monkeypatch):
def do_nothing(self, data, addr, now):
pass
monkeypatch.setattr(QuicConnection, "receive_datagram", do_nothing)
def test_no_return(self):
with pytest.raises(ValueError, match="No ClientHello"):
quic.quic_parse_client_hello(client_hello)
quic.quic_parse_client_hello(client_hello[0:1200] + b'\x00' + client_hello[1200:])
class TestQuicStreamLayer:
@ -320,6 +327,8 @@ class TestRawQuicLayer:
>> tutils.reply(None)
)
with pytest.raises(AssertionError, match="Unexpected stream event"):
class InvalidStreamEvent(quic.QuicStreamEvent):
pass
playbook >> InvalidStreamEvent(tctx.client, 0)
assert playbook
@ -331,6 +340,8 @@ class TestRawQuicLayer:
>> tutils.reply(None)
)
with pytest.raises(AssertionError, match="Unexpected event"):
class InvalidEvent(events.Event):
pass
playbook >> InvalidEvent()
assert playbook
@ -383,6 +394,125 @@ class TestRawQuicLayer:
assert playbook
class MockQuic(QuicConnection):
def __init__(self, event) -> None:
super().__init__(configuration=QuicConfiguration(is_client=True))
self.event = event
def next_event(self):
event = self.event
self.event = None
return event
def datagrams_to_send(self, now: float):
return []
def get_timer(self):
return None
def make_mock_quic(
tctx: context.Context,
event: Optional[quic_events.QuicEvent] = None,
established: bool = True
) -> tuple[tutils.Playbook, MockQuic]:
tctx.client.state = connection.ConnectionState.CLOSED
quic_layer = quic.QuicLayer(tctx, tctx.client, time=lambda: 0)
quic_layer.child_layer = TlsEchoLayer(tctx)
mock = MockQuic(event)
quic_layer.quic = mock
quic_layer.tunnel_state = (
tls.tunnel.TunnelState.OPEN
if established else
tls.tunnel.TunnelState.ESTABLISHING
)
return tutils.Playbook(quic_layer), mock
class TestQuicLayer:
@pytest.mark.parametrize("established", [True, False])
def test_invalid_event(self, tctx: context.Context, established: bool):
class InvalidEvent(quic_events.QuicEvent):
pass
playbook, conn = make_mock_quic(
tctx, event=InvalidEvent(), established=established
)
with pytest.raises(AssertionError, match="Unexpected event"):
assert (
playbook
>> events.DataReceived(tctx.client, b"")
)
def test_invalid_stream_command(self, tctx: context.Context):
playbook, conn = make_mock_quic(
tctx, quic_events.DatagramFrameReceived(b"invalid-stream-command")
)
with pytest.raises(AssertionError, match="Unexpected stream command"):
assert (playbook >> events.DataReceived(tctx.client, b""))
def test_close(self, tctx: context.Context):
playbook, conn = make_mock_quic(
tctx, quic_events.DatagramFrameReceived(b"close-connection")
)
assert not conn._close_event
assert (
playbook
>> events.DataReceived(tctx.client, b"")
<< commands.CloseConnection(tctx.client)
)
assert conn._close_event
assert conn._close_event.error_code == 0
def test_close_error(self, tctx: context.Context):
playbook, conn = make_mock_quic(
tctx, quic_events.DatagramFrameReceived(b"close-connection-error")
)
assert not conn._close_event
assert (
playbook
>> events.DataReceived(tctx.client, b"")
<< quic.CloseQuicConnection(tctx.client, ~0, None, "error")
)
assert conn._close_event
assert conn._close_event.error_code == ~0
def test_datagram(self, tctx: context.Context):
playbook, conn = make_mock_quic(
tctx, quic_events.DatagramFrameReceived(b"packet")
)
assert not conn._datagrams_pending
assert (playbook >> events.DataReceived(tctx.client, b""))
assert len(conn._datagrams_pending) == 1
assert conn._datagrams_pending[0] == b"packet"
def test_stream_data(self, tctx: context.Context):
playbook, conn = make_mock_quic(
tctx, quic_events.StreamDataReceived(b"packet", False, 42)
)
assert 42 not in conn._streams
assert (playbook >> events.DataReceived(tctx.client, b""))
assert b"packet" == conn._streams[42].sender._buffer
def test_stream_reset(self, tctx: context.Context):
playbook, conn = make_mock_quic(
tctx, quic_events.StreamReset(123, 42)
)
assert 42 not in conn._streams
assert (playbook >> events.DataReceived(tctx.client, b""))
assert conn._streams[42].sender.reset_pending
assert conn._streams[42].sender._reset_error_code == 123
def test_stream_stop(self, tctx: context.Context):
playbook, conn = make_mock_quic(
tctx, quic_events.DatagramFrameReceived(b"stop-stream")
)
assert 24 not in conn._streams
conn._get_or_create_stream_for_send(24)
assert (playbook >> events.DataReceived(tctx.client, b""))
assert conn._streams[24].receiver.stop_pending
assert conn._streams[24].receiver._stop_error_code == 123
class SSLTest:
"""Helper container for Python's builtin SSL object."""
@ -391,6 +521,7 @@ class SSLTest:
server_side: bool = False,
alpn: Optional[list[str]] = None,
sni: Optional[str] = "example.mitmproxy.org",
version: Optional[int] = None,
):
self.ctx = QuicConfiguration(
is_client=not server_side,
@ -420,6 +551,9 @@ class SSLTest:
self.ctx.server_name = None if server_side else sni
if version is not None:
self.ctx.supported_versions = [version]
self.now = 0.0
self.address = (sni, 443)
self.quic = None if server_side else QuicConnection(configuration=self.ctx)
@ -681,14 +815,13 @@ class TestServerTLS:
def make_client_tls_layer(
tctx: context.Context, **kwargs
) -> tuple[tutils.Playbook, tls.ClientTLSLayer, SSLTest]:
tctx: context.Context, no_server: bool = False, **kwargs
) -> tuple[tutils.Playbook, quic.ClientQuicLayer, SSLTest]:
tssl_client = SSLTest(**kwargs)
tssl_client.quic.connect(tssl_client.address, 0)
# This is a bit contrived as the client layer expects a server layer as parent.
# We also set child layers manually to avoid NextLayer noise.
server_layer = quic.ServerQuicLayer(tctx, time=lambda: tssl_client.now)
server_layer = DummyLayer(tctx) if no_server else quic.ServerQuicLayer(tctx, time=lambda: tssl_client.now)
client_layer = quic.ClientQuicLayer(tctx, time=lambda: tssl_client.now)
server_layer.child_layer = client_layer
playbook = tutils.Playbook(server_layer)
@ -701,6 +834,7 @@ def make_client_tls_layer(
tctx.server.sni = "example.mitmproxy.org"
# Start handshake.
tssl_client.quic.connect(tssl_client.address, now=tssl_client.now)
assert not tssl_client.handshake_completed()
return playbook, client_layer, tssl_client
@ -742,6 +876,16 @@ class TestClientTLS:
<< commands.SendData(other_server, b"plaintext")
)
# test the close log
tssl_client.now = tssl_client.now + 60
assert (
playbook
>> events.Wakeup(playbook.actual[16])
<< commands.Log(" >> Wakeup(command=RequestWakeup({'delay': 0.20000000000000004}))", DEBUG)
<< commands.Log(" [quic] close_notify Client(client:1234, state=open, tls) (reason=Idle timeout)", DEBUG)
<< commands.CloseConnection(tctx.client)
)
@pytest.mark.parametrize("server_state", ["open", "closed"])
def test_server_required(self, tctx: context.Context, server_state: Literal["open", "closed"]):
"""
@ -921,3 +1065,100 @@ class TestClientTLS:
)
assert not tctx.client.tls_established
assert tls_hook_data().conn.error
def test_server_unavailable_and_no_settings(self, tctx: context.Context):
playbook, client_layer, tssl_client = make_client_tls_layer(tctx)
def require_server_conn(client_hello: tls.ClientHelloData) -> None:
client_hello.establish_server_tls_first = True
assert (
playbook
>> events.DataReceived(tctx.client, tssl_client.read())
<< tls.TlsClienthelloHook(tutils.Placeholder())
>> tutils.reply(side_effect=require_server_conn)
<< commands.OpenConnection(tctx.server)
>> tutils.reply("I cannot open the server, Dave")
<< commands.Log(
f"Unable to establish QUIC connection with server (I cannot open the server, Dave). "
f"Trying to establish QUIC with client anyway. "
f"If you plan to redirect requests away from this server, "
f"consider setting `connection_strategy` to `lazy` to suppress early connections."
)
<< quic.QuicStartClientHook(tutils.Placeholder())
)
tctx.client.state = connection.ConnectionState.CLOSED
assert (
playbook
>> tutils.reply()
<< commands.Log(f"No QUIC context was provided, failing connection.", ERROR)
<< commands.CloseConnection(tctx.client)
<< commands.Log("Client QUIC handshake failed. connection closed early", WARNING)
<< tls.TlsFailedClientHook(tutils.Placeholder())
)
def test_no_server_tls(self, tctx: context.Context):
playbook, client_layer, tssl_client = make_client_tls_layer(tctx, no_server=True)
def require_server_conn(client_hello: tls.ClientHelloData) -> None:
client_hello.establish_server_tls_first = True
assert (
playbook
>> events.DataReceived(tctx.client, tssl_client.read())
<< tls.TlsClienthelloHook(tutils.Placeholder())
>> tutils.reply(side_effect=require_server_conn)
<< commands.Log(
f"Unable to establish QUIC connection with server (No server QUIC available.). "
f"Trying to establish QUIC with client anyway. "
f"If you plan to redirect requests away from this server, "
f"consider setting `connection_strategy` to `lazy` to suppress early connections."
)
<< quic.QuicStartClientHook(tutils.Placeholder())
)
def test_version_negotiation(self, tctx: context.Context):
playbook, client_layer, tssl_client = make_client_tls_layer(tctx, version=0)
assert (
playbook
>> events.DataReceived(tctx.client, tssl_client.read())
<< commands.SendData(tctx.client, tutils.Placeholder())
)
assert client_layer.tunnel_state == tls.tunnel.TunnelState.ESTABLISHING
def test_non_init_clienthello(self, tctx: context.Context):
playbook, client_layer, tssl_client = make_client_tls_layer(tctx)
data = (
b'\xc2\x00\x00\x00\x01\x08q\xda\x98\x03X-\x13o\x08y\xa5RQv\xbe\xe3\xeb\x00@a\x98\x19\xf95t\xad-\x1c\\a\xdd\x8c\xd0\x15F'
b'\xdf\xdc\x87cb\x1eu\xb0\x95*\xac\xa8\xf7a \xb8\nQ\xbd=\xf5x\xca\r\xe6\x8b\x05 w\x9f\xcd\x8d\xcb\xa0\x06\x1e \x8d.\x8f'
b'T\xda\x12et\xe4\x83\x93X<o\xad\xd5%&\x8f7\xa6>\x8aa\xd1\xb2\x18\xb6\xa7\xf50y\x9b\xc5T\xe1\x87\xdd\x9fqv\xb0\x90\xa7s'
b'\xee\x00\x00\x00\x01\x08q\xda\x98\x03X-\x13o\x08y\xa5RQv\xbe\xe3\xeb@a*.\xa8j\x90\x1b\x1a\x7fZ\x04\x0b\\\xc7\x00\x03'
b'\xd7sC\xf8G\x84\x1e\xba\xcf\x08Z\xdd\x98+\xaa\x98J\xca\xe3\xb7u1\x89\x00\xdf\x8e\x16`\xd9^\xc0@i\x1a\x10\x99\r\xd8'
b'\x1dv3\xc6\xb8"\xb9\xa8F\x95K\x9a/\xbc\'\xd8\xd8\x94\x8f\xe7B/\x05\x9d\xfb\x80\xa9\xda@\xe6\xb0J\xfe\xe0\x0f\x02L}'
b'\xd9\xed\xd2L\xa7\xcf'
)
assert (
playbook
>> events.DataReceived(tctx.client, data)
<< commands.Log(f"Client QUIC handshake failed. Invalid handshake received, roaming not supported. ({data.hex()})", WARNING)
<< tls.TlsFailedClientHook(tutils.Placeholder())
)
assert client_layer.tunnel_state == tls.tunnel.TunnelState.ESTABLISHING
def test_invalid_clienthello(self, tctx: context.Context):
playbook, client_layer, tssl_client = make_client_tls_layer(tctx)
data = client_hello[0:1200] + b'\x00' + client_hello[1200:]
assert (
playbook
>> events.DataReceived(tctx.client, data)
<< commands.Log(f"Client QUIC handshake failed. Cannot parse ClientHello: No ClientHello returned. ({data.hex()})", WARNING)
<< tls.TlsFailedClientHook(tutils.Placeholder())
)
assert client_layer.tunnel_state == tls.tunnel.TunnelState.ESTABLISHING
def test_tls_reset(self, tctx: context.Context):
tctx.client.tls = True
tctx.client.sni = "some"
DummyLayer(tctx)
quic.ClientQuicLayer(tctx, time=lambda: 0)
assert tctx.client.sni is None