udp: listen on both ipv4 and ipv6 by default (#6206)
* udp: listen on both ipv4 and ipv6 by default * [autofix.ci] apply automated fixes --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
ff0155b1f7
commit
06ba039e4f
|
@ -252,10 +252,7 @@ async def start_server(
|
|||
) -> UdpServer:
|
||||
"""UDP variant of asyncio.start_server."""
|
||||
|
||||
if host == "":
|
||||
# binding to an empty string does not work on Windows or Ubuntu.
|
||||
host = "0.0.0.0"
|
||||
|
||||
assert host, "Cannot bind to an empty host for UDP sockets on Windows or Ubuntu."
|
||||
loop = asyncio.get_running_loop()
|
||||
_, protocol = await loop.create_datagram_endpoint(
|
||||
lambda: UdpServer(datagram_received_cb, loop),
|
||||
|
|
|
@ -251,26 +251,28 @@ class ServerInstance(Generic[M], metaclass=ABCMeta):
|
|||
|
||||
|
||||
class AsyncioServerInstance(ServerInstance[M], metaclass=ABCMeta):
|
||||
_server: asyncio.Server | udp.UdpServer | None = None
|
||||
_servers: list[asyncio.Server | udp.UdpServer]
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
self._servers = []
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
return self._server is not None
|
||||
return bool(self._servers)
|
||||
|
||||
@property
|
||||
def listen_addrs(self) -> tuple[Address, ...]:
|
||||
if self._server is not None:
|
||||
return tuple(s.getsockname() for s in self._server.sockets)
|
||||
else:
|
||||
return tuple()
|
||||
return tuple(
|
||||
sock.getsockname() for serv in self._servers for sock in serv.sockets
|
||||
)
|
||||
|
||||
async def _start(self) -> None:
|
||||
assert self._server is None
|
||||
assert not self._servers
|
||||
host = self.mode.listen_host(ctx.options.listen_host)
|
||||
port = self.mode.listen_port(ctx.options.listen_port)
|
||||
try:
|
||||
self._server = await self.listen(host, port)
|
||||
self._listen_addrs = tuple(s.getsockname() for s in self._server.sockets)
|
||||
self._servers = await self.listen(host, port)
|
||||
except OSError as e:
|
||||
message = f"{self.mode.description} failed to listen on {host or '*'}:{port} with {e}"
|
||||
if e.errno == errno.EADDRINUSE and self.mode.custom_listen_port is None:
|
||||
|
@ -281,15 +283,18 @@ class AsyncioServerInstance(ServerInstance[M], metaclass=ABCMeta):
|
|||
raise OSError(e.errno, message, e.filename) from e
|
||||
|
||||
async def _stop(self) -> None:
|
||||
assert self._server is not None
|
||||
assert self._servers
|
||||
try:
|
||||
self._server.close()
|
||||
await self._server.wait_closed()
|
||||
for s in self._servers:
|
||||
s.close()
|
||||
await asyncio.gather(*[s.wait_closed() for s in self._servers])
|
||||
finally:
|
||||
# we always reset _server and ignore failures
|
||||
self._server = None
|
||||
self._servers = []
|
||||
|
||||
async def listen(self, host: str, port: int) -> asyncio.Server | udp.UdpServer:
|
||||
async def listen(
|
||||
self, host: str, port: int
|
||||
) -> list[asyncio.Server | udp.UdpServer]:
|
||||
if self.mode.transport_protocol == "tcp":
|
||||
# workaround for https://github.com/python/cpython/issues/89856:
|
||||
# We want both IPv4 and IPv6 sockets to bind to the same port.
|
||||
|
@ -301,23 +306,42 @@ class AsyncioServerInstance(ServerInstance[M], metaclass=ABCMeta):
|
|||
s.bind(("", 0))
|
||||
fixed_port = s.getsockname()[1]
|
||||
s.close()
|
||||
return await asyncio.start_server(
|
||||
self.handle_tcp_connection, host, fixed_port
|
||||
)
|
||||
return [
|
||||
await asyncio.start_server(
|
||||
self.handle_tcp_connection, host, fixed_port
|
||||
)
|
||||
]
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
f"Failed to listen on a single port ({e!r}), falling back to default behavior."
|
||||
)
|
||||
return await asyncio.start_server(self.handle_tcp_connection, host, port)
|
||||
return [await asyncio.start_server(self.handle_tcp_connection, host, port)]
|
||||
elif self.mode.transport_protocol == "udp":
|
||||
# create_datagram_endpoint only creates one socket, so the workaround above doesn't apply
|
||||
# NOTE once we do dual servers, we should consider creating sockets manually to ensure
|
||||
# both TCP and UDP listen to the same IPs and same ports
|
||||
return await udp.start_server(
|
||||
self.handle_udp_datagram,
|
||||
host,
|
||||
port,
|
||||
)
|
||||
# create_datagram_endpoint only creates one (non-dual-stack) socket, so we spawn two servers instead.
|
||||
if not host:
|
||||
ipv4 = await udp.start_server(
|
||||
self.handle_udp_datagram,
|
||||
"0.0.0.0",
|
||||
port,
|
||||
)
|
||||
try:
|
||||
ipv6 = await udp.start_server(
|
||||
self.handle_udp_datagram,
|
||||
"::",
|
||||
port or ipv4.sockets[0].getsockname()[1],
|
||||
)
|
||||
except Exception: # pragma: no cover
|
||||
logger.debug("Failed to listen on '::', listening on IPv4 only.")
|
||||
return [ipv4]
|
||||
else: # pragma: no cover
|
||||
return [ipv4, ipv6]
|
||||
return [
|
||||
await udp.start_server(
|
||||
self.handle_udp_datagram,
|
||||
host,
|
||||
port,
|
||||
)
|
||||
]
|
||||
else:
|
||||
raise AssertionError(self.mode.transport_protocol)
|
||||
|
||||
|
|
|
@ -72,10 +72,9 @@ async def test_client_server():
|
|||
|
||||
|
||||
async def test_bind_emptystr():
|
||||
server = await start_server(lambda *_: None, "", 0)
|
||||
assert server.sockets[0].getsockname()
|
||||
server.close()
|
||||
await server.wait_closed()
|
||||
# this should be handled by the caller, we just raise visibly here.
|
||||
with pytest.raises(AssertionError):
|
||||
await start_server(lambda *_: None, "", 0)
|
||||
|
||||
|
||||
async def test_reader(caplog_async):
|
||||
|
|
|
@ -57,7 +57,7 @@ async def test_last_exception_and_running(monkeypatch):
|
|||
await inst1.start()
|
||||
assert inst1.last_exception is None
|
||||
assert inst1.is_running
|
||||
monkeypatch.setattr(inst1._server, "wait_closed", _raise)
|
||||
monkeypatch.setattr(inst1._servers[0], "wait_closed", _raise)
|
||||
with pytest.raises(type(err), match=str(err)):
|
||||
await inst1.stop()
|
||||
assert inst1.last_exception is err
|
||||
|
@ -305,6 +305,33 @@ async def test_udp_connection_reuse(monkeypatch):
|
|||
assert len(inst.manager.connections) == 1
|
||||
|
||||
|
||||
async def test_udp_dual_stack(caplog_async):
|
||||
caplog_async.set_level("DEBUG")
|
||||
manager = MagicMock()
|
||||
manager.connections = {}
|
||||
|
||||
with taddons.context():
|
||||
inst = ServerInstance.make("dns@:0", manager)
|
||||
await inst.start()
|
||||
assert await caplog_async.await_log("server listening")
|
||||
|
||||
_, port, *_ = inst.listen_addrs[0]
|
||||
reader, writer = await udp.open_connection("127.0.0.1", port)
|
||||
writer.write(b"\x00\x00\x01")
|
||||
assert await caplog_async.await_log("sent an invalid message")
|
||||
writer.close()
|
||||
|
||||
if "listening on IPv4 only" not in caplog_async.caplog.text:
|
||||
caplog_async.clear()
|
||||
reader, writer = await udp.open_connection("::1", port)
|
||||
writer.write(b"\x00\x00\x01")
|
||||
assert await caplog_async.await_log("sent an invalid message")
|
||||
writer.close()
|
||||
|
||||
await inst.stop()
|
||||
assert await caplog_async.await_log("stopped")
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def patched_os_proxy(monkeypatch):
|
||||
start_os_proxy = AsyncMock()
|
||||
|
|
|
@ -183,7 +183,7 @@ class TestApp(tornado.testing.AsyncHTTPTestCase):
|
|||
sock2.getsockname.return_value = ("::1", 8080)
|
||||
server = Mock()
|
||||
server.sockets = [sock1, sock2]
|
||||
si1._server = server
|
||||
si1._servers = [server]
|
||||
si2 = ServerInstance.make("reverse:example.com", m.proxyserver)
|
||||
si2.last_exception = RuntimeError("I failed somehow.")
|
||||
si3 = ServerInstance.make("socks5", m.proxyserver)
|
||||
|
|
Loading…
Reference in New Issue