uvloop/tests/test_udp.py

403 lines
14 KiB
Python

import asyncio
import os
import socket
import sys
import tempfile
import unittest
import uuid
from uvloop import _testbase as tb
class MyDatagramProto(asyncio.DatagramProtocol):
done = None
def __init__(self, loop=None):
self.state = 'INITIAL'
self.nbytes = 0
if loop is not None:
self.done = asyncio.Future(loop=loop)
def connection_made(self, transport):
self.transport = transport
assert self.state == 'INITIAL', self.state
self.state = 'INITIALIZED'
def datagram_received(self, data, addr):
assert self.state == 'INITIALIZED', self.state
self.nbytes += len(data)
def error_received(self, exc):
assert self.state == 'INITIALIZED', self.state
raise exc
def connection_lost(self, exc):
assert self.state == 'INITIALIZED', self.state
self.state = 'CLOSED'
if self.done:
self.done.set_result(None)
class _TestUDP:
def _test_create_datagram_endpoint_addrs(self, family, lc_addr):
class TestMyDatagramProto(MyDatagramProto):
def __init__(inner_self):
super().__init__(loop=self.loop)
def datagram_received(self, data, addr):
super().datagram_received(data, addr)
self.transport.sendto(b'resp:' + data, addr)
coro = self.loop.create_datagram_endpoint(
TestMyDatagramProto,
local_addr=lc_addr,
family=family)
s_transport, server = self.loop.run_until_complete(coro)
remote_addr = s_transport.get_extra_info('sockname')
host, port, *_ = remote_addr
self.assertIsInstance(server, TestMyDatagramProto)
self.assertEqual('INITIALIZED', server.state)
self.assertIs(server.transport, s_transport)
extra = {}
if hasattr(socket, 'SO_REUSEPORT'):
extra['reuse_port'] = True
coro = self.loop.create_datagram_endpoint(
lambda: MyDatagramProto(loop=self.loop),
family=family,
remote_addr=(host, port),
**extra)
transport, client = self.loop.run_until_complete(coro)
self.assertIsInstance(client, MyDatagramProto)
self.assertEqual('INITIALIZED', client.state)
self.assertIs(client.transport, transport)
transport.sendto(b'xxx')
tb.run_until(self.loop, lambda: server.nbytes)
self.assertEqual(3, server.nbytes)
tb.run_until(self.loop, lambda: client.nbytes)
# received
self.assertEqual(8, client.nbytes)
# https://github.com/MagicStack/uvloop/issues/319
# uvloop should behave the same as asyncio when given remote_addr
transport.sendto(b'xxx', remote_addr)
tb.run_until(
self.loop, lambda: server.nbytes > 3 or client.done.done())
self.assertEqual(6, server.nbytes)
tb.run_until(self.loop, lambda: client.nbytes > 8)
# received
self.assertEqual(16, client.nbytes)
# reject sendto with a different port
with self.assertRaisesRegex(
ValueError, "Invalid address.*" + repr(remote_addr)
):
bad_addr = list(remote_addr)
bad_addr[1] += 1
bad_addr = tuple(bad_addr)
transport.sendto(b"xxx", bad_addr)
# reject sento with unresolved hostname
if remote_addr[0] != lc_addr[0]:
with self.assertRaisesRegex(
ValueError, "Invalid address.*" + repr(remote_addr)
):
bad_addr = list(remote_addr)
bad_addr[0] = lc_addr[0]
bad_addr = tuple(bad_addr)
transport.sendto(b"xxx", bad_addr)
# extra info is available
self.assertIsNotNone(transport.get_extra_info('sockname'))
# close connection
transport.close()
self.loop.run_until_complete(client.done)
self.assertEqual('CLOSED', client.state)
server.transport.close()
self.loop.run_until_complete(server.done)
def test_create_datagram_endpoint_addrs_ipv4(self):
self._test_create_datagram_endpoint_addrs(
socket.AF_INET, ('127.0.0.1', 0))
def test_create_datagram_endpoint_addrs_ipv4_nameaddr(self):
self._test_create_datagram_endpoint_addrs(
socket.AF_INET, ('localhost', 0))
def _test_create_datagram_endpoint_addrs_ipv6(self):
self._test_create_datagram_endpoint_addrs(
socket.AF_INET6, ('::1', 0))
def test_create_datagram_endpoint_ipv6_family(self):
class TestMyDatagramProto(MyDatagramProto):
def __init__(inner_self):
super().__init__(loop=self.loop)
def datagram_received(self, data, addr):
super().datagram_received(data, addr)
self.transport.sendto(b'resp:' + data, addr)
coro = self.loop.create_datagram_endpoint(
TestMyDatagramProto, local_addr=None, family=socket.AF_INET6)
s_transport = None
try:
s_transport, server = self.loop.run_until_complete(coro)
finally:
if s_transport:
s_transport.close()
# let it close
self.loop.run_until_complete(
asyncio.sleep(0.1))
def test_create_datagram_endpoint_sock(self):
sock = None
local_address = ('127.0.0.1', 0)
infos = self.loop.run_until_complete(
self.loop.getaddrinfo(
*local_address, type=socket.SOCK_DGRAM))
for family, type, proto, cname, address in infos:
try:
sock = socket.socket(family=family, type=type, proto=proto)
sock.setblocking(False)
sock.bind(address)
except Exception:
pass
else:
break
else:
assert False, 'Can not create socket.'
with sock:
f = self.loop.create_datagram_endpoint(
lambda: MyDatagramProto(loop=self.loop), sock=sock)
tr, pr = self.loop.run_until_complete(f)
self.assertIsInstance(pr, MyDatagramProto)
tr.close()
self.loop.run_until_complete(pr.done)
def test_create_datagram_endpoint_sock_unix_domain(self):
class Proto(asyncio.DatagramProtocol):
done = None
def __init__(self, loop):
self.state = 'INITIAL'
self.addrs = set()
self.done = asyncio.Future(loop=loop)
self.data = b''
def connection_made(self, transport):
self.transport = transport
assert self.state == 'INITIAL', self.state
self.state = 'INITIALIZED'
def datagram_received(self, data, addr):
assert self.state == 'INITIALIZED', self.state
self.addrs.add(addr)
self.data += data
if self.data == b'STOP' and not self.done.done():
self.done.set_result(True)
def error_received(self, exc):
assert self.state == 'INITIALIZED', self.state
if not self.done.done():
self.done.set_exception(exc or RuntimeError())
def connection_lost(self, exc):
assert self.state == 'INITIALIZED', self.state
self.state = 'CLOSED'
if self.done and not self.done.done():
self.done.set_result(None)
tmp_file = os.path.join(tempfile.gettempdir(), str(uuid.uuid4()))
sock = socket.socket(socket.AF_UNIX, type=socket.SOCK_DGRAM)
sock.bind(tmp_file)
with sock:
pr = Proto(loop=self.loop)
f = self.loop.create_datagram_endpoint(
lambda: pr, sock=sock)
tr, pr_prime = self.loop.run_until_complete(f)
self.assertIs(pr, pr_prime)
tmp_file2 = os.path.join(tempfile.gettempdir(), str(uuid.uuid4()))
sock2 = socket.socket(socket.AF_UNIX, type=socket.SOCK_DGRAM)
sock2.bind(tmp_file2)
with sock2:
f2 = self.loop.create_datagram_endpoint(
asyncio.DatagramProtocol, sock=sock2)
tr2, pr2 = self.loop.run_until_complete(f2)
tr2.sendto(b'STOP', tmp_file)
self.loop.run_until_complete(pr.done)
tr.close()
tr2.close()
# Let transports close
self.loop.run_until_complete(asyncio.sleep(0.2))
self.assertIn(tmp_file2, pr.addrs)
def test_create_datagram_1(self):
server_addr = ('127.0.0.1', 8888)
client_addr = ('127.0.0.1', 0)
async def run():
server_transport, client_protocol = \
await self.loop.create_datagram_endpoint(
asyncio.DatagramProtocol,
local_addr=server_addr)
client_transport, client_conn = \
await self.loop.create_datagram_endpoint(
asyncio.DatagramProtocol,
remote_addr=server_addr,
local_addr=client_addr)
client_transport.close()
server_transport.close()
# let transports close
await asyncio.sleep(0.1)
self.loop.run_until_complete(run())
def test_socketpair(self):
peername = asyncio.Future(loop=self.loop)
class Proto(MyDatagramProto):
def datagram_received(self, data, addr):
super().datagram_received(data, addr)
peername.set_result(addr)
s1, s2 = socket.socketpair(socket.AF_UNIX, socket.SOCK_DGRAM, 0)
with s1, s2:
f = self.loop.create_datagram_endpoint(
lambda: Proto(loop=self.loop), sock=s1)
tr, pr = self.loop.run_until_complete(f)
self.assertIsInstance(pr, Proto)
s2.send(b'hello, socketpair')
addr = self.loop.run_until_complete(
asyncio.wait_for(peername, 1))
if sys.platform.startswith('linux'):
self.assertEqual(addr, None)
else:
self.assertEqual(addr, '')
self.assertEqual(pr.nbytes, 17)
if not self.is_asyncio_loop():
# asyncio doesn't support sendto(xx) on UDP sockets
# https://git.io/Jfqbw
data = b'from uvloop'
tr.sendto(data)
result = self.loop.run_until_complete(asyncio.wait_for(
self.loop.run_in_executor(None, s2.recv, 1024),
1))
self.assertEqual(data, result)
tr.close()
self.loop.run_until_complete(pr.done)
def _skip_create_datagram_endpoint_reuse_addr(self):
if self.implementation == 'asyncio':
if sys.version_info[:2] >= (3, 11):
raise unittest.SkipTest()
if (3, 8, 0) <= sys.version_info < (3, 8, 1):
raise unittest.SkipTest()
def test_create_datagram_endpoint_reuse_address_error(self):
# bpo-37228: Ensure that explicit passing of `reuse_address=True`
# raises an error, as it is not safe to use SO_REUSEADDR when using UDP
self._skip_create_datagram_endpoint_reuse_addr()
coro = self.loop.create_datagram_endpoint(
lambda: MyDatagramProto(loop=self.loop),
local_addr=('127.0.0.1', 0),
reuse_address=True)
with self.assertRaises(ValueError):
self.loop.run_until_complete(coro)
def test_create_datagram_endpoint_reuse_address_warning(self):
# bpo-37228: Deprecate *reuse_address* parameter
self._skip_create_datagram_endpoint_reuse_addr()
coro = self.loop.create_datagram_endpoint(
lambda: MyDatagramProto(loop=self.loop),
local_addr=('127.0.0.1', 0),
reuse_address=False)
with self.assertWarns(DeprecationWarning):
tr, pr = self.loop.run_until_complete(coro)
tr.close()
self.loop.run_until_complete(pr.done)
class Test_UV_UDP(_TestUDP, tb.UVTestCase):
def test_create_datagram_endpoint_wrong_sock(self):
sock = socket.socket(socket.AF_INET)
with sock:
coro = self.loop.create_datagram_endpoint(lambda: None, sock=sock)
with self.assertRaisesRegex(ValueError,
'A UDP Socket was expected'):
self.loop.run_until_complete(coro)
def test_udp_sendto_dns(self):
coro = self.loop.create_datagram_endpoint(
asyncio.DatagramProtocol,
local_addr=('127.0.0.1', 0),
family=socket.AF_INET)
s_transport, server = self.loop.run_until_complete(coro)
with self.assertRaisesRegex(ValueError, 'DNS lookup'):
s_transport.sendto(b'aaaa', ('example.com', 80))
with self.assertRaisesRegex(ValueError, 'socket family mismatch'):
s_transport.sendto(b'aaaa', ('::1', 80))
s_transport.close()
self.loop.run_until_complete(asyncio.sleep(0.01))
def test_send_after_close(self):
coro = self.loop.create_datagram_endpoint(
asyncio.DatagramProtocol,
local_addr=('127.0.0.1', 0),
family=socket.AF_INET)
s_transport, _ = self.loop.run_until_complete(coro)
s_transport.close()
s_transport.sendto(b'aaaa', ('127.0.0.1', 80))
self.loop.run_until_complete(asyncio.sleep(0.01))
s_transport.sendto(b'aaaa', ('127.0.0.1', 80))
@unittest.skipUnless(tb.has_IPv6, 'no IPv6')
def test_create_datagram_endpoint_addrs_ipv6(self):
self._test_create_datagram_endpoint_addrs_ipv6()
class Test_AIO_UDP(_TestUDP, tb.AIOTestCase):
@unittest.skipUnless(tb.has_IPv6, 'no IPv6')
def test_create_datagram_endpoint_addrs_ipv6(self):
self._test_create_datagram_endpoint_addrs_ipv6()