udp: listen on both ipv4 and ipv6 by default (#6206)

* udp: listen on both ipv4 and ipv6 by default

* [autofix.ci] apply automated fixes

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Maximilian Hils 2023-06-27 10:02:23 +02:00 committed by GitHub
parent ff0155b1f7
commit 06ba039e4f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 83 additions and 36 deletions

View File

@ -252,10 +252,7 @@ async def start_server(
) -> UdpServer:
"""UDP variant of asyncio.start_server."""
if host == "":
# binding to an empty string does not work on Windows or Ubuntu.
host = "0.0.0.0"
assert host, "Cannot bind to an empty host for UDP sockets on Windows or Ubuntu."
loop = asyncio.get_running_loop()
_, protocol = await loop.create_datagram_endpoint(
lambda: UdpServer(datagram_received_cb, loop),

View File

@ -251,26 +251,28 @@ class ServerInstance(Generic[M], metaclass=ABCMeta):
class AsyncioServerInstance(ServerInstance[M], metaclass=ABCMeta):
_server: asyncio.Server | udp.UdpServer | None = None
_servers: list[asyncio.Server | udp.UdpServer]
def __init__(self, *args, **kwargs) -> None:
self._servers = []
super().__init__(*args, **kwargs)
@property
def is_running(self) -> bool:
return self._server is not None
return bool(self._servers)
@property
def listen_addrs(self) -> tuple[Address, ...]:
if self._server is not None:
return tuple(s.getsockname() for s in self._server.sockets)
else:
return tuple()
return tuple(
sock.getsockname() for serv in self._servers for sock in serv.sockets
)
async def _start(self) -> None:
assert self._server is None
assert not self._servers
host = self.mode.listen_host(ctx.options.listen_host)
port = self.mode.listen_port(ctx.options.listen_port)
try:
self._server = await self.listen(host, port)
self._listen_addrs = tuple(s.getsockname() for s in self._server.sockets)
self._servers = await self.listen(host, port)
except OSError as e:
message = f"{self.mode.description} failed to listen on {host or '*'}:{port} with {e}"
if e.errno == errno.EADDRINUSE and self.mode.custom_listen_port is None:
@ -281,15 +283,18 @@ class AsyncioServerInstance(ServerInstance[M], metaclass=ABCMeta):
raise OSError(e.errno, message, e.filename) from e
async def _stop(self) -> None:
assert self._server is not None
assert self._servers
try:
self._server.close()
await self._server.wait_closed()
for s in self._servers:
s.close()
await asyncio.gather(*[s.wait_closed() for s in self._servers])
finally:
# we always reset _server and ignore failures
self._server = None
self._servers = []
async def listen(self, host: str, port: int) -> asyncio.Server | udp.UdpServer:
async def listen(
self, host: str, port: int
) -> list[asyncio.Server | udp.UdpServer]:
if self.mode.transport_protocol == "tcp":
# workaround for https://github.com/python/cpython/issues/89856:
# We want both IPv4 and IPv6 sockets to bind to the same port.
@ -301,23 +306,42 @@ class AsyncioServerInstance(ServerInstance[M], metaclass=ABCMeta):
s.bind(("", 0))
fixed_port = s.getsockname()[1]
s.close()
return await asyncio.start_server(
self.handle_tcp_connection, host, fixed_port
)
return [
await asyncio.start_server(
self.handle_tcp_connection, host, fixed_port
)
]
except Exception as e:
logger.debug(
f"Failed to listen on a single port ({e!r}), falling back to default behavior."
)
return await asyncio.start_server(self.handle_tcp_connection, host, port)
return [await asyncio.start_server(self.handle_tcp_connection, host, port)]
elif self.mode.transport_protocol == "udp":
# create_datagram_endpoint only creates one socket, so the workaround above doesn't apply
# NOTE once we do dual servers, we should consider creating sockets manually to ensure
# both TCP and UDP listen to the same IPs and same ports
return await udp.start_server(
self.handle_udp_datagram,
host,
port,
)
# create_datagram_endpoint only creates one (non-dual-stack) socket, so we spawn two servers instead.
if not host:
ipv4 = await udp.start_server(
self.handle_udp_datagram,
"0.0.0.0",
port,
)
try:
ipv6 = await udp.start_server(
self.handle_udp_datagram,
"::",
port or ipv4.sockets[0].getsockname()[1],
)
except Exception: # pragma: no cover
logger.debug("Failed to listen on '::', listening on IPv4 only.")
return [ipv4]
else: # pragma: no cover
return [ipv4, ipv6]
return [
await udp.start_server(
self.handle_udp_datagram,
host,
port,
)
]
else:
raise AssertionError(self.mode.transport_protocol)

View File

@ -72,10 +72,9 @@ async def test_client_server():
async def test_bind_emptystr():
server = await start_server(lambda *_: None, "", 0)
assert server.sockets[0].getsockname()
server.close()
await server.wait_closed()
# this should be handled by the caller, we just raise visibly here.
with pytest.raises(AssertionError):
await start_server(lambda *_: None, "", 0)
async def test_reader(caplog_async):

View File

@ -57,7 +57,7 @@ async def test_last_exception_and_running(monkeypatch):
await inst1.start()
assert inst1.last_exception is None
assert inst1.is_running
monkeypatch.setattr(inst1._server, "wait_closed", _raise)
monkeypatch.setattr(inst1._servers[0], "wait_closed", _raise)
with pytest.raises(type(err), match=str(err)):
await inst1.stop()
assert inst1.last_exception is err
@ -305,6 +305,33 @@ async def test_udp_connection_reuse(monkeypatch):
assert len(inst.manager.connections) == 1
async def test_udp_dual_stack(caplog_async):
caplog_async.set_level("DEBUG")
manager = MagicMock()
manager.connections = {}
with taddons.context():
inst = ServerInstance.make("dns@:0", manager)
await inst.start()
assert await caplog_async.await_log("server listening")
_, port, *_ = inst.listen_addrs[0]
reader, writer = await udp.open_connection("127.0.0.1", port)
writer.write(b"\x00\x00\x01")
assert await caplog_async.await_log("sent an invalid message")
writer.close()
if "listening on IPv4 only" not in caplog_async.caplog.text:
caplog_async.clear()
reader, writer = await udp.open_connection("::1", port)
writer.write(b"\x00\x00\x01")
assert await caplog_async.await_log("sent an invalid message")
writer.close()
await inst.stop()
assert await caplog_async.await_log("stopped")
@pytest.fixture()
def patched_os_proxy(monkeypatch):
start_os_proxy = AsyncMock()

View File

@ -183,7 +183,7 @@ class TestApp(tornado.testing.AsyncHTTPTestCase):
sock2.getsockname.return_value = ("::1", 8080)
server = Mock()
server.sockets = [sock1, sock2]
si1._server = server
si1._servers = [server]
si2 = ServerInstance.make("reverse:example.com", m.proxyserver)
si2.last_exception = RuntimeError("I failed somehow.")
si3 = ServerInstance.make("socks5", m.proxyserver)