[quic] full quic.py coverage
This commit is contained in:
parent
8a78191a5f
commit
26b2545dc2
|
@ -4,7 +4,7 @@ from dataclasses import dataclass, field
|
|||
from logging import DEBUG, ERROR, WARNING
|
||||
from ssl import VerifyMode
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
from typing import Callable
|
||||
|
||||
from aioquic.buffer import Buffer as QuicBuffer
|
||||
from aioquic.h3.connection import ErrorCode as H3ErrorCode
|
||||
|
@ -41,9 +41,6 @@ from mitmproxy.proxy.layers.tls import (
|
|||
from mitmproxy.proxy.layers.udp import UDPLayer
|
||||
from mitmproxy.tls import ClientHello, ClientHelloData, TlsData
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mitmproxy.proxy.server import ConnectionHandler
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuicTlsSettings:
|
||||
|
@ -423,7 +420,7 @@ class QuicStreamLayer(layer.Layer):
|
|||
elif isinstance(child_layer, tunnel.TunnelLayer):
|
||||
child_layer = child_layer.child_layer
|
||||
else:
|
||||
break
|
||||
break # pragma: no cover
|
||||
if isinstance(child_layer, (UDPLayer, TCPLayer)) and child_layer.flow:
|
||||
child_layer.flow.metadata["quic_is_unidirectional"] = stream_is_unidirectional(self._client_stream_id)
|
||||
child_layer.flow.metadata["quic_initiator"] = "client" if stream_is_client_initiated(self._client_stream_id) else "server"
|
||||
|
@ -701,7 +698,6 @@ class QuicLayer(tunnel.TunnelLayer):
|
|||
self.child_layer = layer.NextLayer(self.context, ask_on_start=True)
|
||||
self._time = time or asyncio.get_running_loop().time
|
||||
self._wakeup_commands: dict[commands.RequestWakeup, float] = dict()
|
||||
self._routes: dict[connection.Address, ConnectionHandler | None] = dict()
|
||||
conn.tls = True
|
||||
|
||||
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
|
@ -826,8 +822,7 @@ class QuicLayer(tunnel.TunnelLayer):
|
|||
all_certs: list[x509.Certificate] = []
|
||||
if self.quic.tls._peer_certificate:
|
||||
all_certs.append(self.quic.tls._peer_certificate)
|
||||
if self.quic.tls._peer_certificate_chain:
|
||||
all_certs.extend(self.quic.tls._peer_certificate_chain)
|
||||
all_certs.extend(self.quic.tls._peer_certificate_chain)
|
||||
|
||||
# set the connection's TLS properties
|
||||
self.conn.timestamp_tls_setup = time.time()
|
||||
|
@ -983,7 +978,7 @@ class ServerQuicLayer(QuicLayer):
|
|||
|
||||
class ClientQuicLayer(QuicLayer):
|
||||
"""
|
||||
This layer establishes QUIC on a single client connection or roams to another connection.
|
||||
This layer establishes QUIC on a single client connection.
|
||||
"""
|
||||
|
||||
server_tls_available: bool
|
||||
|
@ -1091,6 +1086,7 @@ class ClientQuicLayer(QuicLayer):
|
|||
|
||||
# start the client QUIC connection
|
||||
yield from self.start_tls(header.destination_cid)
|
||||
# XXX copied from TLS, we assume that `CloseConnection` in `start_tls` takes effect immediately
|
||||
if not self.conn.connected:
|
||||
return False, "connection closed early"
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from logging import DEBUG, WARNING
|
||||
from logging import DEBUG, ERROR, WARNING
|
||||
import ssl
|
||||
import time
|
||||
from aioquic.buffer import Buffer as QuicBuffer
|
||||
|
@ -35,16 +35,12 @@ def tctx() -> context.Context:
|
|||
)
|
||||
|
||||
|
||||
class InvalidStreamEvent(quic.QuicStreamEvent):
|
||||
pass
|
||||
class DummyLayer(layer.Layer):
|
||||
child_layer: Optional[layer.Layer]
|
||||
|
||||
|
||||
class InvalidEvent(events.Event):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidConnectionCommand(commands.ConnectionCommand):
|
||||
pass
|
||||
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
assert self.child_layer
|
||||
return self.child_layer.handle_event(event)
|
||||
|
||||
|
||||
class TlsEchoLayer(tutils.EchoLayer):
|
||||
|
@ -58,10 +54,26 @@ class TlsEchoLayer(tutils.EchoLayer):
|
|||
yield commands.SendData(
|
||||
event.connection, f"open-connection failed: {err}".encode()
|
||||
)
|
||||
elif isinstance(event, events.DataReceived) and event.data == b"close-connection":
|
||||
yield commands.CloseConnection(event.connection)
|
||||
elif isinstance(event, events.DataReceived) and event.data == b"close-connection-error":
|
||||
yield quic.CloseQuicConnection(event.connection, ~0, None, "error")
|
||||
elif isinstance(event, events.DataReceived) and event.data == b"stop-stream":
|
||||
yield quic.StopQuicStream(event.connection, 24, 123)
|
||||
elif isinstance(event, events.DataReceived) and event.data == b"invalid-command":
|
||||
class InvalidConnectionCommand(commands.ConnectionCommand):
|
||||
pass
|
||||
yield InvalidConnectionCommand(event.connection)
|
||||
elif isinstance(event, events.DataReceived) and event.data == b"invalid-stream-command":
|
||||
class InvalidStreamCommand(quic.QuicStreamCommand):
|
||||
pass
|
||||
yield InvalidStreamCommand(event.connection, 42)
|
||||
elif isinstance(event, quic.QuicConnectionClosed):
|
||||
self.closed = event
|
||||
elif isinstance(event, quic.QuicStreamDataReceived):
|
||||
yield quic.SendQuicStreamData(event.connection, event.stream_id, event.data, event.end_stream)
|
||||
elif isinstance(event, quic.QuicStreamReset):
|
||||
yield quic.ResetQuicStream(event.connection, event.stream_id, event.error_code)
|
||||
else:
|
||||
yield from super()._handle_event(event)
|
||||
|
||||
|
@ -124,7 +136,6 @@ def test_secrets_logger(value: str):
|
|||
|
||||
|
||||
class TestParseClientHello:
|
||||
|
||||
def test_input(self):
|
||||
assert quic.quic_parse_client_hello(client_hello).sni == "example.com"
|
||||
with pytest.raises(ValueError):
|
||||
|
@ -154,13 +165,9 @@ class TestParseClientHello:
|
|||
with pytest.raises(ValueError, match="Conn err"):
|
||||
quic.quic_parse_client_hello(client_hello)
|
||||
|
||||
def test_no_return(self, monkeypatch):
|
||||
def do_nothing(self, data, addr, now):
|
||||
pass
|
||||
|
||||
monkeypatch.setattr(QuicConnection, "receive_datagram", do_nothing)
|
||||
def test_no_return(self):
|
||||
with pytest.raises(ValueError, match="No ClientHello"):
|
||||
quic.quic_parse_client_hello(client_hello)
|
||||
quic.quic_parse_client_hello(client_hello[0:1200] + b'\x00' + client_hello[1200:])
|
||||
|
||||
|
||||
class TestQuicStreamLayer:
|
||||
|
@ -320,6 +327,8 @@ class TestRawQuicLayer:
|
|||
>> tutils.reply(None)
|
||||
)
|
||||
with pytest.raises(AssertionError, match="Unexpected stream event"):
|
||||
class InvalidStreamEvent(quic.QuicStreamEvent):
|
||||
pass
|
||||
playbook >> InvalidStreamEvent(tctx.client, 0)
|
||||
assert playbook
|
||||
|
||||
|
@ -331,6 +340,8 @@ class TestRawQuicLayer:
|
|||
>> tutils.reply(None)
|
||||
)
|
||||
with pytest.raises(AssertionError, match="Unexpected event"):
|
||||
class InvalidEvent(events.Event):
|
||||
pass
|
||||
playbook >> InvalidEvent()
|
||||
assert playbook
|
||||
|
||||
|
@ -383,6 +394,125 @@ class TestRawQuicLayer:
|
|||
assert playbook
|
||||
|
||||
|
||||
class MockQuic(QuicConnection):
|
||||
def __init__(self, event) -> None:
|
||||
super().__init__(configuration=QuicConfiguration(is_client=True))
|
||||
self.event = event
|
||||
|
||||
def next_event(self):
|
||||
event = self.event
|
||||
self.event = None
|
||||
return event
|
||||
|
||||
def datagrams_to_send(self, now: float):
|
||||
return []
|
||||
|
||||
def get_timer(self):
|
||||
return None
|
||||
|
||||
|
||||
def make_mock_quic(
|
||||
tctx: context.Context,
|
||||
event: Optional[quic_events.QuicEvent] = None,
|
||||
established: bool = True
|
||||
) -> tuple[tutils.Playbook, MockQuic]:
|
||||
tctx.client.state = connection.ConnectionState.CLOSED
|
||||
quic_layer = quic.QuicLayer(tctx, tctx.client, time=lambda: 0)
|
||||
quic_layer.child_layer = TlsEchoLayer(tctx)
|
||||
mock = MockQuic(event)
|
||||
quic_layer.quic = mock
|
||||
quic_layer.tunnel_state = (
|
||||
tls.tunnel.TunnelState.OPEN
|
||||
if established else
|
||||
tls.tunnel.TunnelState.ESTABLISHING
|
||||
)
|
||||
return tutils.Playbook(quic_layer), mock
|
||||
|
||||
|
||||
class TestQuicLayer:
|
||||
@pytest.mark.parametrize("established", [True, False])
|
||||
def test_invalid_event(self, tctx: context.Context, established: bool):
|
||||
class InvalidEvent(quic_events.QuicEvent):
|
||||
pass
|
||||
playbook, conn = make_mock_quic(
|
||||
tctx, event=InvalidEvent(), established=established
|
||||
)
|
||||
with pytest.raises(AssertionError, match="Unexpected event"):
|
||||
assert (
|
||||
playbook
|
||||
>> events.DataReceived(tctx.client, b"")
|
||||
)
|
||||
|
||||
def test_invalid_stream_command(self, tctx: context.Context):
|
||||
playbook, conn = make_mock_quic(
|
||||
tctx, quic_events.DatagramFrameReceived(b"invalid-stream-command")
|
||||
)
|
||||
with pytest.raises(AssertionError, match="Unexpected stream command"):
|
||||
assert (playbook >> events.DataReceived(tctx.client, b""))
|
||||
|
||||
def test_close(self, tctx: context.Context):
|
||||
playbook, conn = make_mock_quic(
|
||||
tctx, quic_events.DatagramFrameReceived(b"close-connection")
|
||||
)
|
||||
assert not conn._close_event
|
||||
assert (
|
||||
playbook
|
||||
>> events.DataReceived(tctx.client, b"")
|
||||
<< commands.CloseConnection(tctx.client)
|
||||
)
|
||||
assert conn._close_event
|
||||
assert conn._close_event.error_code == 0
|
||||
|
||||
def test_close_error(self, tctx: context.Context):
|
||||
playbook, conn = make_mock_quic(
|
||||
tctx, quic_events.DatagramFrameReceived(b"close-connection-error")
|
||||
)
|
||||
assert not conn._close_event
|
||||
assert (
|
||||
playbook
|
||||
>> events.DataReceived(tctx.client, b"")
|
||||
<< quic.CloseQuicConnection(tctx.client, ~0, None, "error")
|
||||
)
|
||||
assert conn._close_event
|
||||
assert conn._close_event.error_code == ~0
|
||||
|
||||
def test_datagram(self, tctx: context.Context):
|
||||
playbook, conn = make_mock_quic(
|
||||
tctx, quic_events.DatagramFrameReceived(b"packet")
|
||||
)
|
||||
assert not conn._datagrams_pending
|
||||
assert (playbook >> events.DataReceived(tctx.client, b""))
|
||||
assert len(conn._datagrams_pending) == 1
|
||||
assert conn._datagrams_pending[0] == b"packet"
|
||||
|
||||
def test_stream_data(self, tctx: context.Context):
|
||||
playbook, conn = make_mock_quic(
|
||||
tctx, quic_events.StreamDataReceived(b"packet", False, 42)
|
||||
)
|
||||
assert 42 not in conn._streams
|
||||
assert (playbook >> events.DataReceived(tctx.client, b""))
|
||||
assert b"packet" == conn._streams[42].sender._buffer
|
||||
|
||||
def test_stream_reset(self, tctx: context.Context):
|
||||
playbook, conn = make_mock_quic(
|
||||
tctx, quic_events.StreamReset(123, 42)
|
||||
)
|
||||
assert 42 not in conn._streams
|
||||
assert (playbook >> events.DataReceived(tctx.client, b""))
|
||||
assert conn._streams[42].sender.reset_pending
|
||||
assert conn._streams[42].sender._reset_error_code == 123
|
||||
|
||||
def test_stream_stop(self, tctx: context.Context):
|
||||
playbook, conn = make_mock_quic(
|
||||
tctx, quic_events.DatagramFrameReceived(b"stop-stream")
|
||||
)
|
||||
assert 24 not in conn._streams
|
||||
conn._get_or_create_stream_for_send(24)
|
||||
assert (playbook >> events.DataReceived(tctx.client, b""))
|
||||
assert conn._streams[24].receiver.stop_pending
|
||||
assert conn._streams[24].receiver._stop_error_code == 123
|
||||
|
||||
|
||||
class SSLTest:
|
||||
"""Helper container for Python's builtin SSL object."""
|
||||
|
||||
|
@ -391,6 +521,7 @@ class SSLTest:
|
|||
server_side: bool = False,
|
||||
alpn: Optional[list[str]] = None,
|
||||
sni: Optional[str] = "example.mitmproxy.org",
|
||||
version: Optional[int] = None,
|
||||
):
|
||||
self.ctx = QuicConfiguration(
|
||||
is_client=not server_side,
|
||||
|
@ -420,6 +551,9 @@ class SSLTest:
|
|||
|
||||
self.ctx.server_name = None if server_side else sni
|
||||
|
||||
if version is not None:
|
||||
self.ctx.supported_versions = [version]
|
||||
|
||||
self.now = 0.0
|
||||
self.address = (sni, 443)
|
||||
self.quic = None if server_side else QuicConnection(configuration=self.ctx)
|
||||
|
@ -681,14 +815,13 @@ class TestServerTLS:
|
|||
|
||||
|
||||
def make_client_tls_layer(
|
||||
tctx: context.Context, **kwargs
|
||||
) -> tuple[tutils.Playbook, tls.ClientTLSLayer, SSLTest]:
|
||||
tctx: context.Context, no_server: bool = False, **kwargs
|
||||
) -> tuple[tutils.Playbook, quic.ClientQuicLayer, SSLTest]:
|
||||
tssl_client = SSLTest(**kwargs)
|
||||
tssl_client.quic.connect(tssl_client.address, 0)
|
||||
|
||||
# This is a bit contrived as the client layer expects a server layer as parent.
|
||||
# We also set child layers manually to avoid NextLayer noise.
|
||||
server_layer = quic.ServerQuicLayer(tctx, time=lambda: tssl_client.now)
|
||||
server_layer = DummyLayer(tctx) if no_server else quic.ServerQuicLayer(tctx, time=lambda: tssl_client.now)
|
||||
client_layer = quic.ClientQuicLayer(tctx, time=lambda: tssl_client.now)
|
||||
server_layer.child_layer = client_layer
|
||||
playbook = tutils.Playbook(server_layer)
|
||||
|
@ -701,6 +834,7 @@ def make_client_tls_layer(
|
|||
tctx.server.sni = "example.mitmproxy.org"
|
||||
|
||||
# Start handshake.
|
||||
tssl_client.quic.connect(tssl_client.address, now=tssl_client.now)
|
||||
assert not tssl_client.handshake_completed()
|
||||
|
||||
return playbook, client_layer, tssl_client
|
||||
|
@ -742,6 +876,16 @@ class TestClientTLS:
|
|||
<< commands.SendData(other_server, b"plaintext")
|
||||
)
|
||||
|
||||
# test the close log
|
||||
tssl_client.now = tssl_client.now + 60
|
||||
assert (
|
||||
playbook
|
||||
>> events.Wakeup(playbook.actual[16])
|
||||
<< commands.Log(" >> Wakeup(command=RequestWakeup({'delay': 0.20000000000000004}))", DEBUG)
|
||||
<< commands.Log(" [quic] close_notify Client(client:1234, state=open, tls) (reason=Idle timeout)", DEBUG)
|
||||
<< commands.CloseConnection(tctx.client)
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("server_state", ["open", "closed"])
|
||||
def test_server_required(self, tctx: context.Context, server_state: Literal["open", "closed"]):
|
||||
"""
|
||||
|
@ -921,3 +1065,100 @@ class TestClientTLS:
|
|||
)
|
||||
assert not tctx.client.tls_established
|
||||
assert tls_hook_data().conn.error
|
||||
|
||||
def test_server_unavailable_and_no_settings(self, tctx: context.Context):
|
||||
playbook, client_layer, tssl_client = make_client_tls_layer(tctx)
|
||||
|
||||
def require_server_conn(client_hello: tls.ClientHelloData) -> None:
|
||||
client_hello.establish_server_tls_first = True
|
||||
|
||||
assert (
|
||||
playbook
|
||||
>> events.DataReceived(tctx.client, tssl_client.read())
|
||||
<< tls.TlsClienthelloHook(tutils.Placeholder())
|
||||
>> tutils.reply(side_effect=require_server_conn)
|
||||
<< commands.OpenConnection(tctx.server)
|
||||
>> tutils.reply("I cannot open the server, Dave")
|
||||
<< commands.Log(
|
||||
f"Unable to establish QUIC connection with server (I cannot open the server, Dave). "
|
||||
f"Trying to establish QUIC with client anyway. "
|
||||
f"If you plan to redirect requests away from this server, "
|
||||
f"consider setting `connection_strategy` to `lazy` to suppress early connections."
|
||||
)
|
||||
<< quic.QuicStartClientHook(tutils.Placeholder())
|
||||
)
|
||||
tctx.client.state = connection.ConnectionState.CLOSED
|
||||
assert (
|
||||
playbook
|
||||
>> tutils.reply()
|
||||
<< commands.Log(f"No QUIC context was provided, failing connection.", ERROR)
|
||||
<< commands.CloseConnection(tctx.client)
|
||||
<< commands.Log("Client QUIC handshake failed. connection closed early", WARNING)
|
||||
<< tls.TlsFailedClientHook(tutils.Placeholder())
|
||||
)
|
||||
|
||||
def test_no_server_tls(self, tctx: context.Context):
|
||||
playbook, client_layer, tssl_client = make_client_tls_layer(tctx, no_server=True)
|
||||
|
||||
def require_server_conn(client_hello: tls.ClientHelloData) -> None:
|
||||
client_hello.establish_server_tls_first = True
|
||||
|
||||
assert (
|
||||
playbook
|
||||
>> events.DataReceived(tctx.client, tssl_client.read())
|
||||
<< tls.TlsClienthelloHook(tutils.Placeholder())
|
||||
>> tutils.reply(side_effect=require_server_conn)
|
||||
<< commands.Log(
|
||||
f"Unable to establish QUIC connection with server (No server QUIC available.). "
|
||||
f"Trying to establish QUIC with client anyway. "
|
||||
f"If you plan to redirect requests away from this server, "
|
||||
f"consider setting `connection_strategy` to `lazy` to suppress early connections."
|
||||
)
|
||||
<< quic.QuicStartClientHook(tutils.Placeholder())
|
||||
)
|
||||
|
||||
def test_version_negotiation(self, tctx: context.Context):
|
||||
playbook, client_layer, tssl_client = make_client_tls_layer(tctx, version=0)
|
||||
assert (
|
||||
playbook
|
||||
>> events.DataReceived(tctx.client, tssl_client.read())
|
||||
<< commands.SendData(tctx.client, tutils.Placeholder())
|
||||
)
|
||||
assert client_layer.tunnel_state == tls.tunnel.TunnelState.ESTABLISHING
|
||||
|
||||
def test_non_init_clienthello(self, tctx: context.Context):
|
||||
playbook, client_layer, tssl_client = make_client_tls_layer(tctx)
|
||||
data = (
|
||||
b'\xc2\x00\x00\x00\x01\x08q\xda\x98\x03X-\x13o\x08y\xa5RQv\xbe\xe3\xeb\x00@a\x98\x19\xf95t\xad-\x1c\\a\xdd\x8c\xd0\x15F'
|
||||
b'\xdf\xdc\x87cb\x1eu\xb0\x95*\xac\xa8\xf7a \xb8\nQ\xbd=\xf5x\xca\r\xe6\x8b\x05 w\x9f\xcd\x8d\xcb\xa0\x06\x1e \x8d.\x8f'
|
||||
b'T\xda\x12et\xe4\x83\x93X<o\xad\xd5%&\x8f7\xa6>\x8aa\xd1\xb2\x18\xb6\xa7\xf50y\x9b\xc5T\xe1\x87\xdd\x9fqv\xb0\x90\xa7s'
|
||||
b'\xee\x00\x00\x00\x01\x08q\xda\x98\x03X-\x13o\x08y\xa5RQv\xbe\xe3\xeb@a*.\xa8j\x90\x1b\x1a\x7fZ\x04\x0b\\\xc7\x00\x03'
|
||||
b'\xd7sC\xf8G\x84\x1e\xba\xcf\x08Z\xdd\x98+\xaa\x98J\xca\xe3\xb7u1\x89\x00\xdf\x8e\x16`\xd9^\xc0@i\x1a\x10\x99\r\xd8'
|
||||
b'\x1dv3\xc6\xb8"\xb9\xa8F\x95K\x9a/\xbc\'\xd8\xd8\x94\x8f\xe7B/\x05\x9d\xfb\x80\xa9\xda@\xe6\xb0J\xfe\xe0\x0f\x02L}'
|
||||
b'\xd9\xed\xd2L\xa7\xcf'
|
||||
)
|
||||
assert (
|
||||
playbook
|
||||
>> events.DataReceived(tctx.client, data)
|
||||
<< commands.Log(f"Client QUIC handshake failed. Invalid handshake received, roaming not supported. ({data.hex()})", WARNING)
|
||||
<< tls.TlsFailedClientHook(tutils.Placeholder())
|
||||
)
|
||||
assert client_layer.tunnel_state == tls.tunnel.TunnelState.ESTABLISHING
|
||||
|
||||
def test_invalid_clienthello(self, tctx: context.Context):
|
||||
playbook, client_layer, tssl_client = make_client_tls_layer(tctx)
|
||||
data = client_hello[0:1200] + b'\x00' + client_hello[1200:]
|
||||
assert (
|
||||
playbook
|
||||
>> events.DataReceived(tctx.client, data)
|
||||
<< commands.Log(f"Client QUIC handshake failed. Cannot parse ClientHello: No ClientHello returned. ({data.hex()})", WARNING)
|
||||
<< tls.TlsFailedClientHook(tutils.Placeholder())
|
||||
)
|
||||
assert client_layer.tunnel_state == tls.tunnel.TunnelState.ESTABLISHING
|
||||
|
||||
def test_tls_reset(self, tctx: context.Context):
|
||||
tctx.client.tls = True
|
||||
tctx.client.sni = "some"
|
||||
DummyLayer(tctx)
|
||||
quic.ClientQuicLayer(tctx, time=lambda: 0)
|
||||
assert tctx.client.sni is None
|
||||
|
|
Loading…
Reference in New Issue