diff --git a/mitmproxy/proxy/layers/quic.py b/mitmproxy/proxy/layers/quic.py index eee197f14..46a0b1304 100644 --- a/mitmproxy/proxy/layers/quic.py +++ b/mitmproxy/proxy/layers/quic.py @@ -4,7 +4,7 @@ from dataclasses import dataclass, field from logging import DEBUG, ERROR, WARNING from ssl import VerifyMode import time -from typing import TYPE_CHECKING, Callable +from typing import Callable from aioquic.buffer import Buffer as QuicBuffer from aioquic.h3.connection import ErrorCode as H3ErrorCode @@ -41,9 +41,6 @@ from mitmproxy.proxy.layers.tls import ( from mitmproxy.proxy.layers.udp import UDPLayer from mitmproxy.tls import ClientHello, ClientHelloData, TlsData -if TYPE_CHECKING: - from mitmproxy.proxy.server import ConnectionHandler - @dataclass class QuicTlsSettings: @@ -423,7 +420,7 @@ class QuicStreamLayer(layer.Layer): elif isinstance(child_layer, tunnel.TunnelLayer): child_layer = child_layer.child_layer else: - break + break # pragma: no cover if isinstance(child_layer, (UDPLayer, TCPLayer)) and child_layer.flow: child_layer.flow.metadata["quic_is_unidirectional"] = stream_is_unidirectional(self._client_stream_id) child_layer.flow.metadata["quic_initiator"] = "client" if stream_is_client_initiated(self._client_stream_id) else "server" @@ -701,7 +698,6 @@ class QuicLayer(tunnel.TunnelLayer): self.child_layer = layer.NextLayer(self.context, ask_on_start=True) self._time = time or asyncio.get_running_loop().time self._wakeup_commands: dict[commands.RequestWakeup, float] = dict() - self._routes: dict[connection.Address, ConnectionHandler | None] = dict() conn.tls = True def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: @@ -826,8 +822,7 @@ class QuicLayer(tunnel.TunnelLayer): all_certs: list[x509.Certificate] = [] if self.quic.tls._peer_certificate: all_certs.append(self.quic.tls._peer_certificate) - if self.quic.tls._peer_certificate_chain: - all_certs.extend(self.quic.tls._peer_certificate_chain) + all_certs.extend(self.quic.tls._peer_certificate_chain) # set the connection's TLS properties self.conn.timestamp_tls_setup = time.time() @@ -983,7 +978,7 @@ class ServerQuicLayer(QuicLayer): class ClientQuicLayer(QuicLayer): """ - This layer establishes QUIC on a single client connection or roams to another connection. + This layer establishes QUIC on a single client connection. """ server_tls_available: bool @@ -1091,6 +1086,7 @@ class ClientQuicLayer(QuicLayer): # start the client QUIC connection yield from self.start_tls(header.destination_cid) + # XXX copied from TLS, we assume that `CloseConnection` in `start_tls` takes effect immediately if not self.conn.connected: return False, "connection closed early" diff --git a/test/mitmproxy/proxy/layers/test_quic.py b/test/mitmproxy/proxy/layers/test_quic.py index 4f65a5a1c..f6fb30d7d 100644 --- a/test/mitmproxy/proxy/layers/test_quic.py +++ b/test/mitmproxy/proxy/layers/test_quic.py @@ -1,4 +1,4 @@ -from logging import DEBUG, WARNING +from logging import DEBUG, ERROR, WARNING import ssl import time from aioquic.buffer import Buffer as QuicBuffer @@ -35,16 +35,12 @@ def tctx() -> context.Context: ) -class InvalidStreamEvent(quic.QuicStreamEvent): - pass +class DummyLayer(layer.Layer): + child_layer: Optional[layer.Layer] - -class InvalidEvent(events.Event): - pass - - -class InvalidConnectionCommand(commands.ConnectionCommand): - pass + def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: + assert self.child_layer + return self.child_layer.handle_event(event) class TlsEchoLayer(tutils.EchoLayer): @@ -58,10 +54,26 @@ class TlsEchoLayer(tutils.EchoLayer): yield commands.SendData( event.connection, f"open-connection failed: {err}".encode() ) + elif isinstance(event, events.DataReceived) and event.data == b"close-connection": + yield commands.CloseConnection(event.connection) + elif isinstance(event, events.DataReceived) and event.data == b"close-connection-error": + yield quic.CloseQuicConnection(event.connection, ~0, None, "error") + elif isinstance(event, events.DataReceived) and event.data == b"stop-stream": + yield quic.StopQuicStream(event.connection, 24, 123) elif isinstance(event, events.DataReceived) and event.data == b"invalid-command": + class InvalidConnectionCommand(commands.ConnectionCommand): + pass yield InvalidConnectionCommand(event.connection) + elif isinstance(event, events.DataReceived) and event.data == b"invalid-stream-command": + class InvalidStreamCommand(quic.QuicStreamCommand): + pass + yield InvalidStreamCommand(event.connection, 42) elif isinstance(event, quic.QuicConnectionClosed): self.closed = event + elif isinstance(event, quic.QuicStreamDataReceived): + yield quic.SendQuicStreamData(event.connection, event.stream_id, event.data, event.end_stream) + elif isinstance(event, quic.QuicStreamReset): + yield quic.ResetQuicStream(event.connection, event.stream_id, event.error_code) else: yield from super()._handle_event(event) @@ -124,7 +136,6 @@ def test_secrets_logger(value: str): class TestParseClientHello: - def test_input(self): assert quic.quic_parse_client_hello(client_hello).sni == "example.com" with pytest.raises(ValueError): @@ -154,13 +165,9 @@ class TestParseClientHello: with pytest.raises(ValueError, match="Conn err"): quic.quic_parse_client_hello(client_hello) - def test_no_return(self, monkeypatch): - def do_nothing(self, data, addr, now): - pass - - monkeypatch.setattr(QuicConnection, "receive_datagram", do_nothing) + def test_no_return(self): with pytest.raises(ValueError, match="No ClientHello"): - quic.quic_parse_client_hello(client_hello) + quic.quic_parse_client_hello(client_hello[0:1200] + b'\x00' + client_hello[1200:]) class TestQuicStreamLayer: @@ -320,6 +327,8 @@ class TestRawQuicLayer: >> tutils.reply(None) ) with pytest.raises(AssertionError, match="Unexpected stream event"): + class InvalidStreamEvent(quic.QuicStreamEvent): + pass playbook >> InvalidStreamEvent(tctx.client, 0) assert playbook @@ -331,6 +340,8 @@ class TestRawQuicLayer: >> tutils.reply(None) ) with pytest.raises(AssertionError, match="Unexpected event"): + class InvalidEvent(events.Event): + pass playbook >> InvalidEvent() assert playbook @@ -383,6 +394,125 @@ class TestRawQuicLayer: assert playbook +class MockQuic(QuicConnection): + def __init__(self, event) -> None: + super().__init__(configuration=QuicConfiguration(is_client=True)) + self.event = event + + def next_event(self): + event = self.event + self.event = None + return event + + def datagrams_to_send(self, now: float): + return [] + + def get_timer(self): + return None + + +def make_mock_quic( + tctx: context.Context, + event: Optional[quic_events.QuicEvent] = None, + established: bool = True +) -> tuple[tutils.Playbook, MockQuic]: + tctx.client.state = connection.ConnectionState.CLOSED + quic_layer = quic.QuicLayer(tctx, tctx.client, time=lambda: 0) + quic_layer.child_layer = TlsEchoLayer(tctx) + mock = MockQuic(event) + quic_layer.quic = mock + quic_layer.tunnel_state = ( + tls.tunnel.TunnelState.OPEN + if established else + tls.tunnel.TunnelState.ESTABLISHING + ) + return tutils.Playbook(quic_layer), mock + + +class TestQuicLayer: + @pytest.mark.parametrize("established", [True, False]) + def test_invalid_event(self, tctx: context.Context, established: bool): + class InvalidEvent(quic_events.QuicEvent): + pass + playbook, conn = make_mock_quic( + tctx, event=InvalidEvent(), established=established + ) + with pytest.raises(AssertionError, match="Unexpected event"): + assert ( + playbook + >> events.DataReceived(tctx.client, b"") + ) + + def test_invalid_stream_command(self, tctx: context.Context): + playbook, conn = make_mock_quic( + tctx, quic_events.DatagramFrameReceived(b"invalid-stream-command") + ) + with pytest.raises(AssertionError, match="Unexpected stream command"): + assert (playbook >> events.DataReceived(tctx.client, b"")) + + def test_close(self, tctx: context.Context): + playbook, conn = make_mock_quic( + tctx, quic_events.DatagramFrameReceived(b"close-connection") + ) + assert not conn._close_event + assert ( + playbook + >> events.DataReceived(tctx.client, b"") + << commands.CloseConnection(tctx.client) + ) + assert conn._close_event + assert conn._close_event.error_code == 0 + + def test_close_error(self, tctx: context.Context): + playbook, conn = make_mock_quic( + tctx, quic_events.DatagramFrameReceived(b"close-connection-error") + ) + assert not conn._close_event + assert ( + playbook + >> events.DataReceived(tctx.client, b"") + << quic.CloseQuicConnection(tctx.client, ~0, None, "error") + ) + assert conn._close_event + assert conn._close_event.error_code == ~0 + + def test_datagram(self, tctx: context.Context): + playbook, conn = make_mock_quic( + tctx, quic_events.DatagramFrameReceived(b"packet") + ) + assert not conn._datagrams_pending + assert (playbook >> events.DataReceived(tctx.client, b"")) + assert len(conn._datagrams_pending) == 1 + assert conn._datagrams_pending[0] == b"packet" + + def test_stream_data(self, tctx: context.Context): + playbook, conn = make_mock_quic( + tctx, quic_events.StreamDataReceived(b"packet", False, 42) + ) + assert 42 not in conn._streams + assert (playbook >> events.DataReceived(tctx.client, b"")) + assert b"packet" == conn._streams[42].sender._buffer + + def test_stream_reset(self, tctx: context.Context): + playbook, conn = make_mock_quic( + tctx, quic_events.StreamReset(123, 42) + ) + assert 42 not in conn._streams + assert (playbook >> events.DataReceived(tctx.client, b"")) + assert conn._streams[42].sender.reset_pending + assert conn._streams[42].sender._reset_error_code == 123 + + def test_stream_stop(self, tctx: context.Context): + playbook, conn = make_mock_quic( + tctx, quic_events.DatagramFrameReceived(b"stop-stream") + ) + assert 24 not in conn._streams + conn._get_or_create_stream_for_send(24) + assert (playbook >> events.DataReceived(tctx.client, b"")) + assert conn._streams[24].receiver.stop_pending + assert conn._streams[24].receiver._stop_error_code == 123 + + class SSLTest: """Helper container for Python's builtin SSL object.""" @@ -391,6 +521,7 @@ class SSLTest: server_side: bool = False, alpn: Optional[list[str]] = None, sni: Optional[str] = "example.mitmproxy.org", + version: Optional[int] = None, ): self.ctx = QuicConfiguration( is_client=not server_side, @@ -420,6 +551,9 @@ class SSLTest: self.ctx.server_name = None if server_side else sni + if version is not None: + self.ctx.supported_versions = [version] + self.now = 0.0 self.address = (sni, 443) self.quic = None if server_side else QuicConnection(configuration=self.ctx) @@ -681,14 +815,13 @@ class TestServerTLS: def make_client_tls_layer( - tctx: context.Context, **kwargs -) -> tuple[tutils.Playbook, tls.ClientTLSLayer, SSLTest]: + tctx: context.Context, no_server: bool = False, **kwargs +) -> tuple[tutils.Playbook, quic.ClientQuicLayer, SSLTest]: tssl_client = SSLTest(**kwargs) - tssl_client.quic.connect(tssl_client.address, 0) # This is a bit contrived as the client layer expects a server layer as parent. # We also set child layers manually to avoid NextLayer noise. - server_layer = quic.ServerQuicLayer(tctx, time=lambda: tssl_client.now) + server_layer = DummyLayer(tctx) if no_server else quic.ServerQuicLayer(tctx, time=lambda: tssl_client.now) client_layer = quic.ClientQuicLayer(tctx, time=lambda: tssl_client.now) server_layer.child_layer = client_layer playbook = tutils.Playbook(server_layer) @@ -701,6 +834,7 @@ def make_client_tls_layer( tctx.server.sni = "example.mitmproxy.org" # Start handshake. + tssl_client.quic.connect(tssl_client.address, now=tssl_client.now) assert not tssl_client.handshake_completed() return playbook, client_layer, tssl_client @@ -742,6 +876,16 @@ class TestClientTLS: << commands.SendData(other_server, b"plaintext") ) + # test the close log + tssl_client.now = tssl_client.now + 60 + assert ( + playbook + >> events.Wakeup(playbook.actual[16]) + << commands.Log(" >> Wakeup(command=RequestWakeup({'delay': 0.20000000000000004}))", DEBUG) + << commands.Log(" [quic] close_notify Client(client:1234, state=open, tls) (reason=Idle timeout)", DEBUG) + << commands.CloseConnection(tctx.client) + ) + @pytest.mark.parametrize("server_state", ["open", "closed"]) def test_server_required(self, tctx: context.Context, server_state: Literal["open", "closed"]): """ @@ -921,3 +1065,100 @@ class TestClientTLS: ) assert not tctx.client.tls_established assert tls_hook_data().conn.error + + def test_server_unavailable_and_no_settings(self, tctx: context.Context): + playbook, client_layer, tssl_client = make_client_tls_layer(tctx) + + def require_server_conn(client_hello: tls.ClientHelloData) -> None: + client_hello.establish_server_tls_first = True + + assert ( + playbook + >> events.DataReceived(tctx.client, tssl_client.read()) + << tls.TlsClienthelloHook(tutils.Placeholder()) + >> tutils.reply(side_effect=require_server_conn) + << commands.OpenConnection(tctx.server) + >> tutils.reply("I cannot open the server, Dave") + << commands.Log( + f"Unable to establish QUIC connection with server (I cannot open the server, Dave). " + f"Trying to establish QUIC with client anyway. " + f"If you plan to redirect requests away from this server, " + f"consider setting `connection_strategy` to `lazy` to suppress early connections." + ) + << quic.QuicStartClientHook(tutils.Placeholder()) + ) + tctx.client.state = connection.ConnectionState.CLOSED + assert ( + playbook + >> tutils.reply() + << commands.Log(f"No QUIC context was provided, failing connection.", ERROR) + << commands.CloseConnection(tctx.client) + << commands.Log("Client QUIC handshake failed. connection closed early", WARNING) + << tls.TlsFailedClientHook(tutils.Placeholder()) + ) + + def test_no_server_tls(self, tctx: context.Context): + playbook, client_layer, tssl_client = make_client_tls_layer(tctx, no_server=True) + + def require_server_conn(client_hello: tls.ClientHelloData) -> None: + client_hello.establish_server_tls_first = True + + assert ( + playbook + >> events.DataReceived(tctx.client, tssl_client.read()) + << tls.TlsClienthelloHook(tutils.Placeholder()) + >> tutils.reply(side_effect=require_server_conn) + << commands.Log( + f"Unable to establish QUIC connection with server (No server QUIC available.). " + f"Trying to establish QUIC with client anyway. " + f"If you plan to redirect requests away from this server, " + f"consider setting `connection_strategy` to `lazy` to suppress early connections." + ) + << quic.QuicStartClientHook(tutils.Placeholder()) + ) + + def test_version_negotiation(self, tctx: context.Context): + playbook, client_layer, tssl_client = make_client_tls_layer(tctx, version=0) + assert ( + playbook + >> events.DataReceived(tctx.client, tssl_client.read()) + << commands.SendData(tctx.client, tutils.Placeholder()) + ) + assert client_layer.tunnel_state == tls.tunnel.TunnelState.ESTABLISHING + + def test_non_init_clienthello(self, tctx: context.Context): + playbook, client_layer, tssl_client = make_client_tls_layer(tctx) + data = ( + b'\xc2\x00\x00\x00\x01\x08q\xda\x98\x03X-\x13o\x08y\xa5RQv\xbe\xe3\xeb\x00@a\x98\x19\xf95t\xad-\x1c\\a\xdd\x8c\xd0\x15F' + b'\xdf\xdc\x87cb\x1eu\xb0\x95*\xac\xa8\xf7a \xb8\nQ\xbd=\xf5x\xca\r\xe6\x8b\x05 w\x9f\xcd\x8d\xcb\xa0\x06\x1e \x8d.\x8f' + b'T\xda\x12et\xe4\x83\x93X\x8aa\xd1\xb2\x18\xb6\xa7\xf50y\x9b\xc5T\xe1\x87\xdd\x9fqv\xb0\x90\xa7s' + b'\xee\x00\x00\x00\x01\x08q\xda\x98\x03X-\x13o\x08y\xa5RQv\xbe\xe3\xeb@a*.\xa8j\x90\x1b\x1a\x7fZ\x04\x0b\\\xc7\x00\x03' + b'\xd7sC\xf8G\x84\x1e\xba\xcf\x08Z\xdd\x98+\xaa\x98J\xca\xe3\xb7u1\x89\x00\xdf\x8e\x16`\xd9^\xc0@i\x1a\x10\x99\r\xd8' + b'\x1dv3\xc6\xb8"\xb9\xa8F\x95K\x9a/\xbc\'\xd8\xd8\x94\x8f\xe7B/\x05\x9d\xfb\x80\xa9\xda@\xe6\xb0J\xfe\xe0\x0f\x02L}' + b'\xd9\xed\xd2L\xa7\xcf' + ) + assert ( + playbook + >> events.DataReceived(tctx.client, data) + << commands.Log(f"Client QUIC handshake failed. Invalid handshake received, roaming not supported. ({data.hex()})", WARNING) + << tls.TlsFailedClientHook(tutils.Placeholder()) + ) + assert client_layer.tunnel_state == tls.tunnel.TunnelState.ESTABLISHING + + def test_invalid_clienthello(self, tctx: context.Context): + playbook, client_layer, tssl_client = make_client_tls_layer(tctx) + data = client_hello[0:1200] + b'\x00' + client_hello[1200:] + assert ( + playbook + >> events.DataReceived(tctx.client, data) + << commands.Log(f"Client QUIC handshake failed. Cannot parse ClientHello: No ClientHello returned. ({data.hex()})", WARNING) + << tls.TlsFailedClientHook(tutils.Placeholder()) + ) + assert client_layer.tunnel_state == tls.tunnel.TunnelState.ESTABLISHING + + def test_tls_reset(self, tctx: context.Context): + tctx.client.tls = True + tctx.client.sni = "some" + DummyLayer(tctx) + quic.ClientQuicLayer(tctx, time=lambda: 0) + assert tctx.client.sni is None