import asyncio
import pickle
import select
import socket
import sys
import time
import unittest

from uvloop import _testbase as tb


# Size unit (1 MiB) for the bulk send/receive payloads below.
_SIZE = 1024 * 1024


class _TestSockets:
    """Socket tests shared by the uvloop and the asyncio test cases.

    This is a mixin: the concrete test classes at the bottom of the module
    combine it with a ``tb`` test case, which supplies ``self.loop``,
    ``self.tcp_server`` and ``self.is_asyncio_loop``.
    """

    async def recv_all(self, sock, nbytes):
        """Receive exactly *nbytes* bytes from *sock* via ``loop.sock_recv``.

        Raises ConnectionError if the peer closes the connection before
        *nbytes* bytes arrive.  Without that check an empty read (EOF) would
        never grow the buffer and the loop would spin forever.
        """
        buf = b''
        while len(buf) < nbytes:
            chunk = await self.loop.sock_recv(sock, nbytes - len(buf))
            if not chunk:
                raise ConnectionError(
                    f'connection closed after {len(buf)} of {nbytes} bytes')
            buf += chunk
        return buf

    def test_socket_connect_recv_send(self):
        # Round-trip data through sock_connect / sock_recv / sock_sendall
        # against a threaded TCP server, driven by a generator-based coroutine.
        if sys.version_info[:3] >= (3, 8, 0):
            # @asyncio.coroutine is deprecated in 3.8 (and removed in 3.11).
            # The decorator below is only evaluated after this check, so the
            # test never touches it on newer versions.
            raise unittest.SkipTest()

        def srv_gen(sock):
            sock.send(b'helo')
            data = sock.recv_all(4 * _SIZE)
            self.assertEqual(data, b'ehlo' * _SIZE)
            sock.send(b'O')
            sock.send(b'K')

        # We use @asyncio.coroutine & `yield from` to test
        # the compatibility of Cython's 'async def' coroutines.
        @asyncio.coroutine
        def client(sock, addr):
            yield from self.loop.sock_connect(sock, addr)
            data = yield from self.recv_all(sock, 4)
            self.assertEqual(data, b'helo')
            yield from self.loop.sock_sendall(sock, b'ehlo' * _SIZE)
            data = yield from self.recv_all(sock, 2)
            self.assertEqual(data, b'OK')

        with self.tcp_server(srv_gen) as srv:

            sock = socket.socket()
            with sock:
                sock.setblocking(False)
                self.loop.run_until_complete(client(sock, srv.addr))

    def test_socket_accept_recv_send(self):
        # Accept a connection with sock_accept and read a 1 MiB payload sent
        # by a blocking client running in the loop's default executor.
        async def server():
            sock = socket.socket()
            sock.setblocking(False)

            with sock:
                sock.bind(('127.0.0.1', 0))
                sock.listen()

                fut = self.loop.run_in_executor(None, client,
                                                sock.getsockname())

                client_sock, _ = await self.loop.sock_accept(sock)

                with client_sock:
                    data = await self.recv_all(client_sock, _SIZE)
                    self.assertEqual(data, b'a' * _SIZE)

                await fut

        def client(addr):
            sock = socket.socket()
            with sock:
                sock.connect(addr)
                sock.sendall(b'a' * _SIZE)

        self.loop.run_until_complete(server())

    def test_socket_failed_connect(self):
        # Bind a socket only to obtain a port, then close it so that nothing
        # is listening there; sock_connect must then be refused.
        sock = socket.socket()
        with sock:
            sock.bind(('127.0.0.1', 0))
            addr = sock.getsockname()

        async def run():
            sock = socket.socket()
            with sock:
                sock.setblocking(False)
                with self.assertRaises(ConnectionRefusedError):
                    await self.loop.sock_connect(sock, addr)

        self.loop.run_until_complete(run())

    @unittest.skipUnless(tb.has_IPv6, 'no IPv6')
    def test_socket_ipv6_addr(self):
        server_sock = socket.socket(socket.AF_INET6)
        with server_sock:
            server_sock.bind(('::1', 0))

            addr = server_sock.getsockname()  # tuple of 4 elements for IPv6

            async def run():
                sock = socket.socket(socket.AF_INET6)
                with sock:
                    sock.setblocking(False)
                    # Check that sock_connect accepts 4-element address tuple
                    # for IPv6 sockets.
                    f = self.loop.sock_connect(sock, addr)
                    try:
                        await asyncio.wait_for(f, timeout=0.1)
                    except (asyncio.TimeoutError, ConnectionRefusedError):
                        # Either outcome is fine: `server_sock` never calls
                        # listen(), so the connection cannot succeed.  What
                        # matters is that no other error (e.g. one about the
                        # address format) escapes.
                        pass

            self.loop.run_until_complete(run())

    def test_socket_ipv4_nameaddr(self):
        async def run():
            sock = socket.socket(socket.AF_INET)
            with sock:
                sock.setblocking(False)
                await self.loop.sock_connect(sock, ('localhost', 0))

        with self.assertRaises(OSError):
            # Regression test: sock_connect(sock) wasn't calling
            # getaddrinfo() with `family=sock.family`, which resulted
            # in `socket.connect()` being called with an IPv6 address
            # for IPv4 sockets, which used to cause a TypeError.
            # Here we expect that that is fixed so we should get an
            # OSError instead.
            self.loop.run_until_complete(run())

    def test_socket_blocking_error(self):
        # With debug mode enabled, the sock_* methods must reject a
        # blocking socket with ValueError.
        self.loop.set_debug(True)
        sock = socket.socket()

        with sock:
            with self.assertRaisesRegex(ValueError, 'must be non-blocking'):
                self.loop.run_until_complete(
                    self.loop.sock_recv(sock, 0))

            with self.assertRaisesRegex(ValueError, 'must be non-blocking'):
                self.loop.run_until_complete(
                    self.loop.sock_sendall(sock, b''))

            with self.assertRaisesRegex(ValueError, 'must be non-blocking'):
                self.loop.run_until_complete(
                    self.loop.sock_accept(sock))

            with self.assertRaisesRegex(ValueError, 'must be non-blocking'):
                self.loop.run_until_complete(
                    self.loop.sock_connect(sock, (b'', 0)))

    def test_socket_fileno(self):
        # add_reader/add_writer/remove_* accept socket objects directly
        # (not only integer file descriptors).
        rsock, wsock = socket.socketpair()
        f = self.loop.create_future()

        def reader():
            rsock.recv(100)
            # We are done: unregister the file descriptor
            self.loop.remove_reader(rsock)
            f.set_result(None)

        def writer():
            wsock.send(b'abc')
            self.loop.remove_writer(wsock)

        with rsock, wsock:
            self.loop.add_reader(rsock, reader)
            self.loop.add_writer(wsock, writer)
            self.loop.run_until_complete(f)

    def test_socket_sync_remove_and_immediately_close(self):
        # Test that it's OK to close the socket right after calling
        # `remove_reader`.
        sock = socket.socket()
        with sock:
            cb = lambda: None

            sock.bind(('127.0.0.1', 0))
            sock.listen(0)
            fd = sock.fileno()
            self.loop.add_reader(fd, cb)
            self.loop.run_until_complete(asyncio.sleep(0.01))
            self.loop.remove_reader(fd)
            sock.close()
            self.assertEqual(sock.fileno(), -1)
            self.loop.run_until_complete(asyncio.sleep(0.01))

    def test_sock_cancel_add_reader_race(self):
        # Cancelling a pending sock_recv and immediately issuing a new one
        # must not lose data because of a remove-/add-reader race.
        if self.is_asyncio_loop() and sys.version_info[:2] == (3, 8):
            # asyncio 3.8.x has a regression; fixed in 3.9.0
            # tracked in https://bugs.python.org/issue30064
            raise unittest.SkipTest()

        srv_sock_conn = None

        async def server():
            nonlocal srv_sock_conn
            sock_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock_server.setblocking(False)
            with sock_server:
                sock_server.bind(('127.0.0.1', 0))
                sock_server.listen()
                fut = asyncio.ensure_future(
                    client(sock_server.getsockname()))
                srv_sock_conn, _ = await self.loop.sock_accept(sock_server)
                srv_sock_conn.setsockopt(
                    socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
                with srv_sock_conn:
                    await fut

        async def client(addr):
            sock_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock_client.setblocking(False)
            with sock_client:
                await self.loop.sock_connect(sock_client, addr)
                _, pending_read_futs = await asyncio.wait(
                    [
                        asyncio.ensure_future(
                            self.loop.sock_recv(sock_client, 1)
                        )
                    ],
                    timeout=1,
                )

                async def send_server_data():
                    # Wait a little bit to let reader future cancel and
                    # schedule the removal of the reader callback. Right after
                    # "rfut.cancel()" we will call "loop.sock_recv()", which
                    # will add a reader. This will make a race between
                    # remove- and add-reader.
                    await asyncio.sleep(0.1)
                    await self.loop.sock_sendall(srv_sock_conn, b'1')
                self.loop.create_task(send_server_data())

                for rfut in pending_read_futs:
                    rfut.cancel()

                data = await self.loop.sock_recv(sock_client, 1)

                self.assertEqual(data, b'1')

        self.loop.run_until_complete(server())

    def test_sock_send_before_cancel(self):
        # Data sent before the pending sock_recv is cancelled must still be
        # delivered to the next sock_recv call.
        if self.is_asyncio_loop() and sys.version_info[:2] == (3, 8):
            # asyncio 3.8.x has a regression; fixed in 3.9.0
            # tracked in https://bugs.python.org/issue30064
            raise unittest.SkipTest()

        srv_sock_conn = None

        async def server():
            nonlocal srv_sock_conn
            sock_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock_server.setblocking(False)
            with sock_server:
                sock_server.bind(('127.0.0.1', 0))
                sock_server.listen()
                fut = asyncio.ensure_future(
                    client(sock_server.getsockname()))
                srv_sock_conn, _ = await self.loop.sock_accept(sock_server)
                with srv_sock_conn:
                    await fut

        async def client(addr):
            await asyncio.sleep(0.01)
            sock_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock_client.setblocking(False)
            with sock_client:
                await self.loop.sock_connect(sock_client, addr)
                _, pending_read_futs = await asyncio.wait(
                    [
                        asyncio.ensure_future(
                            self.loop.sock_recv(sock_client, 1)
                        )
                    ],
                    timeout=1,
                )

                # server can send the data in a random time, even before
                # the previous result future has cancelled.
                await self.loop.sock_sendall(srv_sock_conn, b'1')

                for rfut in pending_read_futs:
                    rfut.cancel()

                data = await self.loop.sock_recv(sock_client, 1)

                self.assertEqual(data, b'1')

        self.loop.run_until_complete(server())


class TestUVSockets(_TestSockets, tb.UVTestCase):
    """Shared socket tests plus uvloop-specific ones.

    The extra tests use uvloop internals: the loop's backend fd, transports
    that own their socket, and the pseudo-socket returned by
    ``get_extra_info('socket')``.
    """

    @unittest.skipUnless(hasattr(select, 'epoll'), 'Linux only test')
    def test_socket_sync_remove(self):
        # See https://github.com/MagicStack/uvloop/issues/61 for details

        sock = socket.socket()
        epoll = select.epoll.fromfd(self.loop._get_backend_id())

        try:
            cb = lambda: None

            sock.bind(('127.0.0.1', 0))
            sock.listen(0)
            fd = sock.fileno()
            self.loop.add_reader(fd, cb)
            self.loop.run_until_complete(asyncio.sleep(0.01))
            self.loop.remove_reader(fd)
            # remove_reader must unregister the fd from the backend
            # synchronously, so epoll no longer knows about it.
            with self.assertRaises(FileNotFoundError):
                epoll.modify(fd, 0)

        finally:
            sock.close()
            # NOTE(review): `epoll` wraps the loop's own backend fd, so the
            # loop and this wrapper presumably close the same descriptor.
            # Keep this order (loop first) unless that is confirmed safe
            # to change.
            self.loop.close()
            epoll.close()

    def test_add_reader_or_writer_transport_fd(self):
        # A socket owned by a transport must not be usable with
        # add_/remove_reader or add_/remove_writer, whether it is passed
        # as a socket object or as a raw fd.
        def assert_raises():
            return self.assertRaisesRegex(
                RuntimeError,
                r'File descriptor .* is used by transport')

        async def runner():
            tr, pr = await self.loop.create_connection(
                lambda: asyncio.Protocol(), sock=rsock)

            try:
                cb = lambda: None
                sock = tr.get_extra_info('socket')

                with assert_raises():
                    self.loop.add_reader(sock, cb)
                with assert_raises():
                    self.loop.add_reader(sock.fileno(), cb)

                with assert_raises():
                    self.loop.remove_reader(sock)
                with assert_raises():
                    self.loop.remove_reader(sock.fileno())

                with assert_raises():
                    self.loop.add_writer(sock, cb)
                with assert_raises():
                    self.loop.add_writer(sock.fileno(), cb)

                with assert_raises():
                    self.loop.remove_writer(sock)
                with assert_raises():
                    self.loop.remove_writer(sock.fileno())

            finally:
                tr.close()

        rsock, wsock = socket.socketpair()
        try:
            self.loop.run_until_complete(runner())
        finally:
            rsock.close()
            wsock.close()

    def test_pseudosocket(self):
        # The transport's `socket` extra info is a restricted pseudo-socket.
        # Metadata and query methods must mirror the real socket, while
        # I/O and lifecycle methods, pickling and `with` must raise
        # TypeError.
        def test_pseudo(real_sock, pseudo_sock, *, is_dup=False):
            self.assertIn('AF_UNIX', repr(pseudo_sock))

            self.assertEqual(pseudo_sock.family, real_sock.family)
            self.assertEqual(pseudo_sock.proto, real_sock.proto)

            # Guard against SOCK_NONBLOCK bit in socket.type on Linux.
            self.assertEqual(pseudo_sock.type & 0xf, real_sock.type & 0xf)

            with self.assertRaises(TypeError):
                pickle.dumps(pseudo_sock)

            na_meths = {
                'accept', 'connect', 'connect_ex', 'bind', 'listen',
                'makefile', 'sendfile', 'close', 'detach', 'shutdown',
                'sendmsg_afalg', 'sendmsg', 'sendto', 'send', 'sendall',
                'recv_into', 'recvfrom_into', 'recvmsg_into', 'recvmsg',
                'recvfrom', 'recv'
            }
            for methname in na_meths:
                meth = getattr(pseudo_sock, methname)
                with self.assertRaisesRegex(
                        TypeError,
                        r'.*not support ' + methname + r'\(\) method'):
                    meth()

            eq_meths = {
                'getsockname', 'getpeername', 'get_inheritable', 'gettimeout'
            }
            for methname in eq_meths:
                pmeth = getattr(pseudo_sock, methname)
                rmeth = getattr(real_sock, methname)

                # Call 2x to check caching paths
                self.assertEqual(pmeth(), rmeth())
                self.assertEqual(pmeth(), rmeth())

            self.assertEqual(
                pseudo_sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR),
                0)

            if not is_dup:
                self.assertEqual(pseudo_sock.fileno(), real_sock.fileno())

                # A dup() of the pseudo-socket must describe the same
                # endpoint; recurse once with is_dup=True to compare them.
                duped = pseudo_sock.dup()
                with duped:
                    test_pseudo(duped, pseudo_sock, is_dup=True)

            with self.assertRaises(TypeError):
                with pseudo_sock:
                    pass

        async def runner():
            tr, pr = await self.loop.create_connection(
                lambda: asyncio.Protocol(), sock=rsock)

            try:
                sock = tr.get_extra_info('socket')
                test_pseudo(rsock, sock)
            finally:
                tr.close()

        rsock, wsock = socket.socketpair()
        try:
            self.loop.run_until_complete(runner())
        finally:
            rsock.close()
            wsock.close()

    def test_socket_connect_and_close(self):
        # Closing the socket right after starting sock_connect must not
        # break the pending connect.
        def srv_gen(sock):
            sock.send(b'helo')

        async def client(sock, addr):
            f = asyncio.ensure_future(self.loop.sock_connect(sock, addr),
                                      loop=self.loop)
            self.loop.call_soon(sock.close)
            await f
            return 'ok'

        with self.tcp_server(srv_gen) as srv:

            sock = socket.socket()
            with sock:
                sock.setblocking(False)
                r = self.loop.run_until_complete(client(sock, srv.addr))
                self.assertEqual(r, 'ok')

    def test_socket_recv_and_close(self):
        # Closing the socket while sock_recv is pending: the recv still
        # completes with the server's data, and the socket ends up closed.
        def srv_gen(sock):
            time.sleep(1.2)
            sock.send(b'helo')

        async def kill(sock):
            await asyncio.sleep(0.2)
            sock.close()

        async def client(sock, addr):
            await self.loop.sock_connect(sock, addr)

            f = asyncio.ensure_future(self.loop.sock_recv(sock, 10),
                                      loop=self.loop)
            self.loop.create_task(kill(sock))
            res = await f
            self.assertEqual(sock.fileno(), -1)
            return res

        with self.tcp_server(srv_gen) as srv:

            sock = socket.socket()
            with sock:
                sock.setblocking(False)
                c = client(sock, srv.addr)
                w = asyncio.wait_for(c, timeout=5.0)
                r = self.loop.run_until_complete(w)
                self.assertEqual(r, b'helo')

    def test_socket_recv_into_and_close(self):
        # Same as test_socket_recv_and_close, but using sock_recv_into.
        def srv_gen(sock):
            time.sleep(1.2)
            sock.send(b'helo')

        async def kill(sock):
            await asyncio.sleep(0.2)
            sock.close()

        async def client(sock, addr):
            await self.loop.sock_connect(sock, addr)

            data = bytearray(10)
            with memoryview(data) as buf:
                f = asyncio.ensure_future(self.loop.sock_recv_into(sock, buf),
                                          loop=self.loop)
                self.loop.create_task(kill(sock))
                rcvd = await f
                data = data[:rcvd]
            self.assertEqual(sock.fileno(), -1)
            return bytes(data)

        with self.tcp_server(srv_gen) as srv:

            sock = socket.socket()
            with sock:
                sock.setblocking(False)
                c = client(sock, srv.addr)
                w = asyncio.wait_for(c, timeout=5.0)
                r = self.loop.run_until_complete(w)
                self.assertEqual(r, b'helo')

    def test_socket_send_and_close(self):
        # Closing the socket right after starting sock_sendall must still
        # deliver the data.
        ok = False

        def srv_gen(sock):
            nonlocal ok
            b = sock.recv_all(2)
            if b == b'hi':
                ok = True
            sock.send(b'ii')

        async def client(sock, addr):
            await self.loop.sock_connect(sock, addr)

            s2 = sock.dup()  # Don't let it drop connection until `f` is done
            with s2:
                f = asyncio.ensure_future(self.loop.sock_sendall(sock, b'hi'),
                                          loop=self.loop)
                self.loop.call_soon(sock.close)
                await f

                return await self.loop.sock_recv(s2, 2)

        with self.tcp_server(srv_gen) as srv:

            sock = socket.socket()
            with sock:
                sock.setblocking(False)
                r = self.loop.run_until_complete(client(sock, srv.addr))
                self.assertEqual(r, b'ii')

        self.assertTrue(ok)

    def test_socket_close_loop_and_close(self):
        # While the loop still references the socket (a pending sock_recv),
        # sock.close() must not close the fd; loop.close() must.
        class Abort(Exception):
            pass

        def srv_gen(sock):
            time.sleep(1.2)

        async def client(sock, addr):
            await self.loop.sock_connect(sock, addr)

            asyncio.ensure_future(self.loop.sock_recv(sock, 10),
                                  loop=self.loop)
            await asyncio.sleep(0.2)
            raise Abort

        with self.tcp_server(srv_gen) as srv:

            sock = socket.socket()
            with sock:
                sock.setblocking(False)

                c = client(sock, srv.addr)
                w = asyncio.wait_for(c, timeout=5.0)
                # `client` always ends by raising Abort.  Assert it
                # explicitly instead of swallowing it, and don't rebind
                # `sock` to the coroutine's result.
                with self.assertRaises(Abort):
                    self.loop.run_until_complete(w)

                # `loop` still owns `sock`, so closing `sock` shouldn't
                # do anything.
                sock.close()
                self.assertNotEqual(sock.fileno(), -1)

                # `loop.close()` should io-decref all sockets that the
                # loop owns, including our `sock`.
                self.loop.close()
                self.assertEqual(sock.fileno(), -1)

    def test_socket_close_remove_reader(self):
        # After remove_reader (by fd or by socket object), the socket is no
        # longer pinned by the loop and close() really closes it.
        s = socket.socket()
        with s:
            s.setblocking(False)
            self.loop.add_reader(s, lambda: None)
            self.loop.remove_reader(s.fileno())
            s.close()
            self.assertEqual(s.fileno(), -1)

        s = socket.socket()
        with s:
            s.setblocking(False)
            self.loop.add_reader(s.fileno(), lambda: None)
            self.loop.remove_reader(s)
            self.assertNotEqual(s.fileno(), -1)
            s.close()
            self.assertEqual(s.fileno(), -1)

    def test_socket_close_remove_writer(self):
        # Writer counterpart of test_socket_close_remove_reader.
        s = socket.socket()
        with s:
            s.setblocking(False)
            self.loop.add_writer(s, lambda: None)
            self.loop.remove_writer(s.fileno())
            s.close()
            self.assertEqual(s.fileno(), -1)

        s = socket.socket()
        with s:
            s.setblocking(False)
            self.loop.add_writer(s.fileno(), lambda: None)
            self.loop.remove_writer(s)
            self.assertNotEqual(s.fileno(), -1)
            s.close()
            self.assertEqual(s.fileno(), -1)

    def test_socket_cancel_sock_recv_1(self):
        # Cancelling a pending sock_recv raises CancelledError in the awaiter
        # and leaves the socket closable.
        def srv_gen(sock):
            time.sleep(1.2)
            sock.send(b'helo')

        async def kill(fut):
            await asyncio.sleep(0.2)
            fut.cancel()

        async def client(sock, addr):
            await self.loop.sock_connect(sock, addr)

            f = asyncio.ensure_future(self.loop.sock_recv(sock, 10),
                                      loop=self.loop)
            self.loop.create_task(kill(f))
            with self.assertRaises(asyncio.CancelledError):
                await f
            sock.close()
            self.assertEqual(sock.fileno(), -1)

        with self.tcp_server(srv_gen) as srv:

            sock = socket.socket()
            with sock:
                sock.setblocking(False)
                c = client(sock, srv.addr)
                w = asyncio.wait_for(c, timeout=5.0)
                self.loop.run_until_complete(w)

    def test_socket_cancel_sock_recv_2(self):
        # Calling remove_reader() and close() under a pending sock_recv, then
        # cancelling the outer task, must surface CancelledError cleanly.
        def srv_gen(sock):
            time.sleep(1.2)
            sock.send(b'helo')

        async def kill(fut):
            await asyncio.sleep(0.5)
            fut.cancel()

        async def recv(sock):
            fut = self.loop.create_task(self.loop.sock_recv(sock, 10))
            await asyncio.sleep(0.1)
            self.loop.remove_reader(sock)
            sock.close()
            try:
                await fut
            finally:
                sock.close()

        async def client(sock, addr):
            await self.loop.sock_connect(sock, addr)

            f = asyncio.ensure_future(recv(sock))
            self.loop.create_task(kill(f))
            with self.assertRaises(asyncio.CancelledError):
                await f
            sock.close()
            self.assertEqual(sock.fileno(), -1)

        with self.tcp_server(srv_gen) as srv:

            sock = socket.socket()
            with sock:
                sock.setblocking(False)
                c = client(sock, srv.addr)
                w = asyncio.wait_for(c, timeout=5.0)
                self.loop.run_until_complete(w)

    def test_socket_cancel_sock_sendall(self):
        # Cancelling a large sock_sendall that cannot complete (the server
        # reads only 4 bytes) raises CancelledError and leaves the socket
        # closable.
        def srv_gen(sock):
            time.sleep(1.2)
            sock.recv_all(4)

        async def kill(fut):
            await asyncio.sleep(0.2)
            fut.cancel()

        async def client(sock, addr):
            await self.loop.sock_connect(sock, addr)

            f = asyncio.ensure_future(
                self.loop.sock_sendall(sock, b'helo' * (1024 * 1024 * 50)),
                loop=self.loop)
            self.loop.create_task(kill(f))
            with self.assertRaises(asyncio.CancelledError):
                await f
            sock.close()
            self.assertEqual(sock.fileno(), -1)

        # disable slow callback reporting for this test
        self.loop.slow_callback_duration = 1000.0

        with self.tcp_server(srv_gen) as srv:

            sock = socket.socket()
            with sock:
                sock.setblocking(False)
                c = client(sock, srv.addr)
                w = asyncio.wait_for(c, timeout=5.0)
                self.loop.run_until_complete(w)

    def test_socket_close_many_add_readers(self):
        # Repeated add_reader calls on one socket are undone by a single
        # remove_reader (by fd or by socket object).
        s = socket.socket()
        with s:
            s.setblocking(False)
            self.loop.add_reader(s, lambda: None)
            self.loop.add_reader(s, lambda: None)
            self.loop.add_reader(s, lambda: None)
            self.loop.remove_reader(s.fileno())
            s.close()
            self.assertEqual(s.fileno(), -1)

        s = socket.socket()
        with s:
            s.setblocking(False)
            self.loop.add_reader(s, lambda: None)
            self.loop.add_reader(s, lambda: None)
            self.loop.add_reader(s, lambda: None)
            self.loop.remove_reader(s)
            s.close()
            self.assertEqual(s.fileno(), -1)

    def test_socket_close_many_remove_writers(self):
        # Writer counterpart of test_socket_close_many_add_readers.
        s = socket.socket()
        with s:
            s.setblocking(False)
            self.loop.add_writer(s, lambda: None)
            self.loop.add_writer(s, lambda: None)
            self.loop.add_writer(s, lambda: None)
            self.loop.remove_writer(s.fileno())
            s.close()
            self.assertEqual(s.fileno(), -1)

        s = socket.socket()
        with s:
            s.setblocking(False)
            self.loop.add_writer(s, lambda: None)
            self.loop.add_writer(s, lambda: None)
            self.loop.add_writer(s, lambda: None)
            self.loop.remove_writer(s)
            s.close()
            self.assertEqual(s.fileno(), -1)


class TestAIOSockets(_TestSockets, tb.AIOTestCase):
    """Shared socket tests run on tb.AIOTestCase's loop.

    That loop is presumably the stock asyncio event loop (see the
    ``is_asyncio_loop()`` checks above).
    """