# Source: uvloop/tests/test_sockets.py (776 lines, 26 KiB, Python)

import asyncio
import pickle
import select
import socket
import sys
import time
import unittest
from uvloop import _testbase as tb
# Base payload size (1024 * 1024) used by the bulk-transfer tests, both as a
# byte count for recv_all() and as a repeat count for short byte tokens.
_SIZE = 1024 * 1024
class _TestSockets:
    """Low-level socket tests shared by the concrete loop test cases.

    The class is a mixin: it is combined with ``tb.UVTestCase`` and
    ``tb.AIOTestCase`` elsewhere in this file.  ``self.loop``,
    ``self.tcp_server()`` and ``self.is_asyncio_loop()`` are not defined
    here and are expected to come from those base classes.
    """

    async def recv_all(self, sock, nbytes):
        """Receive exactly *nbytes* bytes from *sock* via ``loop.sock_recv``.

        Returns the received data as ``bytes``.  Raises
        ``asyncio.IncompleteReadError`` if ``sock_recv`` yields ``b''``
        (the peer closed the connection) before *nbytes* bytes arrived,
        instead of looping forever on empty reads.
        """
        chunks = []
        remaining = nbytes
        while remaining > 0:
            chunk = await self.loop.sock_recv(sock, remaining)
            if not chunk:
                # EOF: no more data will ever arrive on this socket.
                raise asyncio.IncompleteReadError(b''.join(chunks), nbytes)
            chunks.append(chunk)
            remaining -= len(chunk)
        # Join once at the end rather than growing a bytes object per chunk.
        return b''.join(chunks)

    def test_socket_connect_recv_send(self):
        # Full round trip: connect, read a greeting, send a large payload
        # and read the two-byte acknowledgement.
        if sys.version_info[:3] >= (3, 8, 0):
            # @asyncio.coroutine is deprecated in 3.8
            raise unittest.SkipTest()

        def srv_gen(sock):
            # Server side of the exchange; `sock` is whatever object
            # `self.tcp_server()` passes in (it must offer `recv_all`).
            sock.send(b'helo')
            data = sock.recv_all(4 * _SIZE)
            self.assertEqual(data, b'ehlo' * _SIZE)
            # The acknowledgement is deliberately sent in two pieces.
            sock.send(b'O')
            sock.send(b'K')

        # We use @asyncio.coroutine & `yield from` to test
        # the compatibility of Cython's 'async def' coroutines.
        @asyncio.coroutine
        def client(sock, addr):
            yield from self.loop.sock_connect(sock, addr)
            data = yield from self.recv_all(sock, 4)
            self.assertEqual(data, b'helo')
            yield from self.loop.sock_sendall(sock, b'ehlo' * _SIZE)
            data = yield from self.recv_all(sock, 2)
            self.assertEqual(data, b'OK')

        with self.tcp_server(srv_gen) as srv:
            sock = socket.socket()
            with sock:
                sock.setblocking(False)
                self.loop.run_until_complete(client(sock, srv.addr))

    def test_socket_accept_recv_send(self):
        # The loop accepts a connection made by a blocking client running
        # in the default executor and reads `_SIZE` bytes from it.
        async def server():
            sock = socket.socket()
            sock.setblocking(False)

            with sock:
                sock.bind(('127.0.0.1', 0))
                sock.listen()

                fut = self.loop.run_in_executor(None, client,
                                                sock.getsockname())

                client_sock, _ = await self.loop.sock_accept(sock)

                with client_sock:
                    data = await self.recv_all(client_sock, _SIZE)
                    self.assertEqual(data, b'a' * _SIZE)

                # Surface any exception raised by the blocking client.
                await fut

        def client(addr):
            sock = socket.socket()
            with sock:
                sock.connect(addr)
                sock.sendall(b'a' * _SIZE)

        self.loop.run_until_complete(server())

    def test_socket_failed_connect(self):
        # Bind a socket only to learn a port number, then close it so the
        # subsequent connection attempt to that address is refused.
        sock = socket.socket()
        with sock:
            sock.bind(('127.0.0.1', 0))
            addr = sock.getsockname()

        async def run():
            sock = socket.socket()
            with sock:
                sock.setblocking(False)
                with self.assertRaises(ConnectionRefusedError):
                    await self.loop.sock_connect(sock, addr)

        self.loop.run_until_complete(run())

    @unittest.skipUnless(tb.has_IPv6, 'no IPv6')
    def test_socket_ipv6_addr(self):
        server_sock = socket.socket(socket.AF_INET6)
        with server_sock:
            server_sock.bind(('::1', 0))

            addr = server_sock.getsockname()  # tuple of 4 elements for IPv6

            async def run():
                sock = socket.socket(socket.AF_INET6)
                with sock:
                    sock.setblocking(False)
                    # Check that sock_connect accepts 4-element address tuple
                    # for IPv6 sockets.
                    f = self.loop.sock_connect(sock, addr)
                    try:
                        await asyncio.wait_for(f, timeout=0.1)
                    except (asyncio.TimeoutError, ConnectionRefusedError):
                        # TimeoutError is expected.
                        pass

            self.loop.run_until_complete(run())

    def test_socket_ipv4_nameaddr(self):
        async def run():
            sock = socket.socket(socket.AF_INET)
            with sock:
                sock.setblocking(False)
                await self.loop.sock_connect(sock, ('localhost', 0))

        with self.assertRaises(OSError):
            # Regression test: sock_connect(sock) wasn't calling
            # getaddrinfo() with `family=sock.family`, which resulted
            # in `socket.connect()` being called with an IPv6 address
            # for IPv4 sockets, which used to cause a TypeError.
            # Here we expect that that is fixed so we should get an
            # OSError instead.
            self.loop.run_until_complete(run())

    def test_socket_blocking_error(self):
        # With debug mode enabled, every sock_* call made with a socket
        # that was left in blocking mode must be rejected with ValueError.
        self.loop.set_debug(True)
        sock = socket.socket()

        with sock:
            with self.assertRaisesRegex(ValueError, 'must be non-blocking'):
                self.loop.run_until_complete(
                    self.loop.sock_recv(sock, 0))

            with self.assertRaisesRegex(ValueError, 'must be non-blocking'):
                self.loop.run_until_complete(
                    self.loop.sock_sendall(sock, b''))

            with self.assertRaisesRegex(ValueError, 'must be non-blocking'):
                self.loop.run_until_complete(
                    self.loop.sock_accept(sock))

            with self.assertRaisesRegex(ValueError, 'must be non-blocking'):
                self.loop.run_until_complete(
                    self.loop.sock_connect(sock, (b'', 0)))

    def test_socket_fileno(self):
        # add_reader()/add_writer() are given socket objects rather than
        # raw file descriptors.
        rsock, wsock = socket.socketpair()
        f = asyncio.Future(loop=self.loop)

        def reader():
            rsock.recv(100)
            # We are done: unregister the file descriptor
            self.loop.remove_reader(rsock)
            f.set_result(None)

        def writer():
            wsock.send(b'abc')
            self.loop.remove_writer(wsock)

        with rsock, wsock:
            self.loop.add_reader(rsock, reader)
            self.loop.add_writer(wsock, writer)
            self.loop.run_until_complete(f)

    def test_socket_sync_remove_and_immediately_close(self):
        # Test that it's OK to close the socket right after calling
        # `remove_reader`.
        sock = socket.socket()
        with sock:
            cb = lambda: None

            sock.bind(('127.0.0.1', 0))
            sock.listen(0)
            fd = sock.fileno()
            self.loop.add_reader(fd, cb)
            self.loop.run_until_complete(asyncio.sleep(0.01))
            self.loop.remove_reader(fd)
            sock.close()
            self.assertEqual(sock.fileno(), -1)
            # Give the loop one more iteration after the close.
            self.loop.run_until_complete(asyncio.sleep(0.01))

    def test_sock_cancel_add_reader_race(self):
        if self.is_asyncio_loop() and sys.version_info[:2] == (3, 8):
            # asyncio 3.8.x has a regression; fixed in 3.9.0
            # tracked in https://bugs.python.org/issue30064
            raise unittest.SkipTest()

        # Server-side connection socket; set by server(), used by the
        # client to make the server send data.
        srv_sock_conn = None

        async def server():
            nonlocal srv_sock_conn
            sock_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock_server.setblocking(False)
            with sock_server:
                sock_server.bind(('127.0.0.1', 0))
                sock_server.listen()
                fut = asyncio.ensure_future(
                    client(sock_server.getsockname()))
                srv_sock_conn, _ = await self.loop.sock_accept(sock_server)
                srv_sock_conn.setsockopt(
                    socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
                with srv_sock_conn:
                    await fut

        async def client(addr):
            sock_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock_client.setblocking(False)
            with sock_client:
                await self.loop.sock_connect(sock_client, addr)
                # Nothing has been sent yet, so after the timeout the
                # sock_recv() future is expected to still be pending.
                _, pending_read_futs = await asyncio.wait(
                    [
                        asyncio.ensure_future(
                            self.loop.sock_recv(sock_client, 1)
                        )
                    ],
                    timeout=1,
                )

                async def send_server_data():
                    # Wait a little bit to let reader future cancel and
                    # schedule the removal of the reader callback. Right after
                    # "rfut.cancel()" we will call "loop.sock_recv()", which
                    # will add a reader. This will make a race between
                    # remove- and add-reader.
                    await asyncio.sleep(0.1)
                    await self.loop.sock_sendall(srv_sock_conn, b'1')

                self.loop.create_task(send_server_data())

                for rfut in pending_read_futs:
                    rfut.cancel()

                data = await self.loop.sock_recv(sock_client, 1)

                self.assertEqual(data, b'1')

        self.loop.run_until_complete(server())

    def test_sock_send_before_cancel(self):
        if self.is_asyncio_loop() and sys.version_info[:2] == (3, 8):
            # asyncio 3.8.x has a regression; fixed in 3.9.0
            # tracked in https://bugs.python.org/issue30064
            raise unittest.SkipTest()

        # Server-side connection socket; set by server(), used by the
        # client to make the server send data.
        srv_sock_conn = None

        async def server():
            nonlocal srv_sock_conn
            sock_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock_server.setblocking(False)
            with sock_server:
                sock_server.bind(('127.0.0.1', 0))
                sock_server.listen()
                fut = asyncio.ensure_future(
                    client(sock_server.getsockname()))
                srv_sock_conn, _ = await self.loop.sock_accept(sock_server)
                with srv_sock_conn:
                    await fut

        async def client(addr):
            # Let server() reach sock_accept() before connecting.
            await asyncio.sleep(0.01)
            sock_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock_client.setblocking(False)
            with sock_client:
                await self.loop.sock_connect(sock_client, addr)
                _, pending_read_futs = await asyncio.wait(
                    [
                        asyncio.ensure_future(
                            self.loop.sock_recv(sock_client, 1)
                        )
                    ],
                    timeout=1,
                )

                # server can send the data in a random time, even before
                # the previous result future has cancelled.
                await self.loop.sock_sendall(srv_sock_conn, b'1')

                for rfut in pending_read_futs:
                    rfut.cancel()

                data = await self.loop.sock_recv(sock_client, 1)

                self.assertEqual(data, b'1')

        self.loop.run_until_complete(server())
class TestUVSockets(_TestSockets, tb.UVTestCase):
    """Shared socket tests plus extra ones, run on ``tb.UVTestCase``.

    The extra tests check how the loop tracks socket ownership: closing
    sockets while loop operations are pending, cancelling ``sock_*``
    calls, and touching file descriptors that belong to a transport.
    """

    @unittest.skipUnless(hasattr(select, 'epoll'), 'Linux only test')
    def test_socket_sync_remove(self):
        # See https://github.com/MagicStack/uvloop/issues/61 for details

        sock = socket.socket()
        epoll = select.epoll.fromfd(self.loop._get_backend_id())

        try:
            cb = lambda: None

            sock.bind(('127.0.0.1', 0))
            sock.listen(0)
            fd = sock.fileno()
            self.loop.add_reader(fd, cb)
            self.loop.run_until_complete(asyncio.sleep(0.01))
            self.loop.remove_reader(fd)
            # Once remove_reader() has returned, the fd must no longer be
            # registered with the epoll instance.
            with self.assertRaises(FileNotFoundError):
                epoll.modify(fd, 0)

        finally:
            sock.close()
            self.loop.close()
            epoll.close()

    def test_add_reader_or_writer_transport_fd(self):
        # add/remove reader/writer must be refused, for both the socket
        # object and its raw fd, while a transport uses the socket.
        def assert_raises():
            return self.assertRaisesRegex(
                RuntimeError,
                r'File descriptor .* is used by transport')

        async def runner():
            tr, pr = await self.loop.create_connection(
                lambda: asyncio.Protocol(), sock=rsock)

            try:
                cb = lambda: None
                sock = tr.get_extra_info('socket')

                with assert_raises():
                    self.loop.add_reader(sock, cb)
                with assert_raises():
                    self.loop.add_reader(sock.fileno(), cb)

                with assert_raises():
                    self.loop.remove_reader(sock)
                with assert_raises():
                    self.loop.remove_reader(sock.fileno())

                with assert_raises():
                    self.loop.add_writer(sock, cb)
                with assert_raises():
                    self.loop.add_writer(sock.fileno(), cb)

                with assert_raises():
                    self.loop.remove_writer(sock)
                with assert_raises():
                    self.loop.remove_writer(sock.fileno())

            finally:
                tr.close()

        rsock, wsock = socket.socketpair()
        try:
            self.loop.run_until_complete(runner())
        finally:
            rsock.close()
            wsock.close()

    def test_pseudosocket(self):
        # Checks the object returned by `get_extra_info('socket')`: it
        # mirrors the real socket's read-only attributes but refuses I/O
        # and lifecycle methods with TypeError.
        def test_pseudo(real_sock, pseudo_sock, *, is_dup=False):
            self.assertIn('AF_UNIX', repr(pseudo_sock))

            self.assertEqual(pseudo_sock.family, real_sock.family)
            self.assertEqual(pseudo_sock.proto, real_sock.proto)

            # Guard against SOCK_NONBLOCK bit in socket.type on Linux.
            self.assertEqual(pseudo_sock.type & 0xf, real_sock.type & 0xf)

            with self.assertRaises(TypeError):
                pickle.dumps(pseudo_sock)

            # Methods that must be rejected with a TypeError.
            na_meths = {
                'accept', 'connect', 'connect_ex', 'bind', 'listen',
                'makefile', 'sendfile', 'close', 'detach', 'shutdown',
                'sendmsg_afalg', 'sendmsg', 'sendto', 'send', 'sendall',
                'recv_into', 'recvfrom_into', 'recvmsg_into', 'recvmsg',
                'recvfrom', 'recv'
            }
            for methname in na_meths:
                meth = getattr(pseudo_sock, methname)
                with self.assertRaisesRegex(
                        TypeError,
                        r'.*not support ' + methname + r'\(\) method'):
                    meth()

            # Methods that must return the same value as the real socket.
            eq_meths = {
                'getsockname', 'getpeername', 'get_inheritable', 'gettimeout'
            }
            for methname in eq_meths:
                pmeth = getattr(pseudo_sock, methname)
                rmeth = getattr(real_sock, methname)

                # Call 2x to check caching paths
                self.assertEqual(pmeth(), rmeth())
                self.assertEqual(pmeth(), rmeth())

            self.assertEqual(
                pseudo_sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR),
                0)

            if not is_dup:
                self.assertEqual(pseudo_sock.fileno(), real_sock.fileno())

                # Re-run the checks against a dup() of the pseudo socket;
                # `is_dup` stops the recursion after one level.
                duped = pseudo_sock.dup()
                with duped:
                    test_pseudo(duped, pseudo_sock, is_dup=True)

            with self.assertRaises(TypeError):
                with pseudo_sock:
                    pass

        async def runner():
            tr, pr = await self.loop.create_connection(
                lambda: asyncio.Protocol(), sock=rsock)

            try:
                sock = tr.get_extra_info('socket')
                test_pseudo(rsock, sock)
            finally:
                tr.close()

        rsock, wsock = socket.socketpair()
        try:
            self.loop.run_until_complete(runner())
        finally:
            rsock.close()
            wsock.close()

    def test_socket_connect_and_close(self):
        # sock.close() is scheduled right after sock_connect() starts;
        # the connect must still complete.
        def srv_gen(sock):
            sock.send(b'helo')

        async def client(sock, addr):
            f = asyncio.ensure_future(self.loop.sock_connect(sock, addr),
                                      loop=self.loop)
            self.loop.call_soon(sock.close)
            await f
            return 'ok'

        with self.tcp_server(srv_gen) as srv:
            sock = socket.socket()
            with sock:
                sock.setblocking(False)
                r = self.loop.run_until_complete(client(sock, srv.addr))
                self.assertEqual(r, 'ok')

    def test_socket_recv_and_close(self):
        # The socket is closed 0.2s into a pending sock_recv(); the data
        # the server sends after 1.2s must still be delivered.
        def srv_gen(sock):
            time.sleep(1.2)
            sock.send(b'helo')

        async def kill(sock):
            await asyncio.sleep(0.2)
            sock.close()

        async def client(sock, addr):
            await self.loop.sock_connect(sock, addr)

            f = asyncio.ensure_future(self.loop.sock_recv(sock, 10),
                                      loop=self.loop)
            self.loop.create_task(kill(sock))
            res = await f
            self.assertEqual(sock.fileno(), -1)
            return res

        with self.tcp_server(srv_gen) as srv:
            sock = socket.socket()
            with sock:
                sock.setblocking(False)
                c = client(sock, srv.addr)
                w = asyncio.wait_for(c, timeout=5.0)
                r = self.loop.run_until_complete(w)
                self.assertEqual(r, b'helo')

    def test_socket_recv_into_and_close(self):
        # Same scenario as test_socket_recv_and_close, but receiving into
        # a caller-supplied buffer with sock_recv_into().
        def srv_gen(sock):
            time.sleep(1.2)
            sock.send(b'helo')

        async def kill(sock):
            await asyncio.sleep(0.2)
            sock.close()

        async def client(sock, addr):
            await self.loop.sock_connect(sock, addr)

            data = bytearray(10)
            with memoryview(data) as buf:
                f = asyncio.ensure_future(self.loop.sock_recv_into(sock, buf),
                                          loop=self.loop)
                self.loop.create_task(kill(sock))
                rcvd = await f
                # Keep only the bytes that were actually received.
                data = data[:rcvd]
            self.assertEqual(sock.fileno(), -1)
            return bytes(data)

        with self.tcp_server(srv_gen) as srv:
            sock = socket.socket()
            with sock:
                sock.setblocking(False)
                c = client(sock, srv.addr)
                w = asyncio.wait_for(c, timeout=5.0)
                r = self.loop.run_until_complete(w)
                self.assertEqual(r, b'helo')

    def test_socket_send_and_close(self):
        # sock.close() is scheduled right after sock_sendall() starts;
        # the server must still receive the payload.
        ok = False

        def srv_gen(sock):
            nonlocal ok
            b = sock.recv_all(2)
            if b == b'hi':
                ok = True
            sock.send(b'ii')

        async def client(sock, addr):
            await self.loop.sock_connect(sock, addr)

            s2 = sock.dup()  # Don't let it drop connection until `f` is done
            with s2:
                f = asyncio.ensure_future(self.loop.sock_sendall(sock, b'hi'),
                                          loop=self.loop)
                self.loop.call_soon(sock.close)
                await f

                return await self.loop.sock_recv(s2, 2)

        with self.tcp_server(srv_gen) as srv:
            sock = socket.socket()
            with sock:
                sock.setblocking(False)
                r = self.loop.run_until_complete(client(sock, srv.addr))
                self.assertEqual(r, b'ii')

        self.assertTrue(ok)

    def test_socket_close_loop_and_close(self):
        class Abort(Exception):
            pass

        def srv_gen(sock):
            time.sleep(1.2)

        async def client(sock, addr):
            await self.loop.sock_connect(sock, addr)

            # Leave a sock_recv() pending on purpose, then bail out.
            asyncio.ensure_future(self.loop.sock_recv(sock, 10),
                                  loop=self.loop)
            await asyncio.sleep(0.2)
            raise Abort

        with self.tcp_server(srv_gen) as srv:
            sock = socket.socket()
            with sock:
                sock.setblocking(False)

                c = client(sock, srv.addr)
                w = asyncio.wait_for(c, timeout=5.0)
                # client() always ends by raising Abort; anything else
                # is a test failure rather than something to ignore.
                with self.assertRaises(Abort):
                    self.loop.run_until_complete(w)

                # `loop` still owns `sock`, so closing `sock` shouldn't
                # do anything.
                sock.close()
                self.assertNotEqual(sock.fileno(), -1)

                # `loop.close()` should io-decref all sockets that the
                # loop owns, including our `sock`.
                self.loop.close()
                self.assertEqual(sock.fileno(), -1)

    def test_socket_close_remove_reader(self):
        # A reader added by socket object and removed by fd (and the
        # reverse) must leave the socket closable.
        s = socket.socket()
        with s:
            s.setblocking(False)
            self.loop.add_reader(s, lambda: None)
            self.loop.remove_reader(s.fileno())
            s.close()
            self.assertEqual(s.fileno(), -1)

        s = socket.socket()
        with s:
            s.setblocking(False)
            self.loop.add_reader(s.fileno(), lambda: None)
            self.loop.remove_reader(s)
            self.assertNotEqual(s.fileno(), -1)
            s.close()
            self.assertEqual(s.fileno(), -1)

    def test_socket_close_remove_writer(self):
        # Writer counterpart of test_socket_close_remove_reader.
        s = socket.socket()
        with s:
            s.setblocking(False)
            self.loop.add_writer(s, lambda: None)
            self.loop.remove_writer(s.fileno())
            s.close()
            self.assertEqual(s.fileno(), -1)

        s = socket.socket()
        with s:
            s.setblocking(False)
            self.loop.add_writer(s.fileno(), lambda: None)
            self.loop.remove_writer(s)
            self.assertNotEqual(s.fileno(), -1)
            s.close()
            self.assertEqual(s.fileno(), -1)

    def test_socket_cancel_sock_recv_1(self):
        # Cancelling a pending sock_recv() must leave the socket
        # closable (fileno() == -1 after close()).
        def srv_gen(sock):
            time.sleep(1.2)
            sock.send(b'helo')

        async def kill(fut):
            await asyncio.sleep(0.2)
            fut.cancel()

        async def client(sock, addr):
            await self.loop.sock_connect(sock, addr)

            f = asyncio.ensure_future(self.loop.sock_recv(sock, 10),
                                      loop=self.loop)
            self.loop.create_task(kill(f))
            with self.assertRaises(asyncio.CancelledError):
                await f
            sock.close()
            self.assertEqual(sock.fileno(), -1)

        with self.tcp_server(srv_gen) as srv:
            sock = socket.socket()
            with sock:
                sock.setblocking(False)
                c = client(sock, srv.addr)
                w = asyncio.wait_for(c, timeout=5.0)
                self.loop.run_until_complete(w)

    def test_socket_cancel_sock_recv_2(self):
        # Like test_socket_cancel_sock_recv_1, but the reader is removed
        # and the socket closed before the outer task is cancelled.
        def srv_gen(sock):
            time.sleep(1.2)
            sock.send(b'helo')

        async def kill(fut):
            await asyncio.sleep(0.5)
            fut.cancel()

        async def recv(sock):
            fut = self.loop.create_task(self.loop.sock_recv(sock, 10))
            await asyncio.sleep(0.1)
            self.loop.remove_reader(sock)
            sock.close()
            # CancelledError propagates to the caller; the socket is
            # closed again on the way out regardless of the outcome.
            try:
                await fut
            finally:
                sock.close()

        async def client(sock, addr):
            await self.loop.sock_connect(sock, addr)

            f = asyncio.ensure_future(recv(sock))
            self.loop.create_task(kill(f))
            with self.assertRaises(asyncio.CancelledError):
                await f
            sock.close()
            self.assertEqual(sock.fileno(), -1)

        with self.tcp_server(srv_gen) as srv:
            sock = socket.socket()
            with sock:
                sock.setblocking(False)
                c = client(sock, srv.addr)
                w = asyncio.wait_for(c, timeout=5.0)
                self.loop.run_until_complete(w)

    def test_socket_cancel_sock_sendall(self):
        # Cancelling a sock_sendall() of a very large payload must leave
        # the socket closable (fileno() == -1 after close()).
        def srv_gen(sock):
            time.sleep(1.2)
            sock.recv_all(4)

        async def kill(fut):
            await asyncio.sleep(0.2)
            fut.cancel()

        async def client(sock, addr):
            await self.loop.sock_connect(sock, addr)

            f = asyncio.ensure_future(
                self.loop.sock_sendall(sock, b'helo' * (1024 * 1024 * 50)),
                loop=self.loop)
            self.loop.create_task(kill(f))
            with self.assertRaises(asyncio.CancelledError):
                await f
            sock.close()
            self.assertEqual(sock.fileno(), -1)

        # disable slow callback reporting for this test
        self.loop.slow_callback_duration = 1000.0

        with self.tcp_server(srv_gen) as srv:
            sock = socket.socket()
            with sock:
                sock.setblocking(False)
                c = client(sock, srv.addr)
                w = asyncio.wait_for(c, timeout=5.0)
                self.loop.run_until_complete(w)

    def test_socket_close_many_add_readers(self):
        # Registering a reader several times for the same socket and
        # removing it once must leave the socket closable.
        s = socket.socket()
        with s:
            s.setblocking(False)
            self.loop.add_reader(s, lambda: None)
            self.loop.add_reader(s, lambda: None)
            self.loop.add_reader(s, lambda: None)
            self.loop.remove_reader(s.fileno())
            s.close()
            self.assertEqual(s.fileno(), -1)

        s = socket.socket()
        with s:
            s.setblocking(False)
            self.loop.add_reader(s, lambda: None)
            self.loop.add_reader(s, lambda: None)
            self.loop.add_reader(s, lambda: None)
            self.loop.remove_reader(s)
            s.close()
            self.assertEqual(s.fileno(), -1)

    def test_socket_close_many_remove_writers(self):
        # Writer counterpart of test_socket_close_many_add_readers.
        s = socket.socket()
        with s:
            s.setblocking(False)
            self.loop.add_writer(s, lambda: None)
            self.loop.add_writer(s, lambda: None)
            self.loop.add_writer(s, lambda: None)
            self.loop.remove_writer(s.fileno())
            s.close()
            self.assertEqual(s.fileno(), -1)

        s = socket.socket()
        with s:
            s.setblocking(False)
            self.loop.add_writer(s, lambda: None)
            self.loop.add_writer(s, lambda: None)
            self.loop.add_writer(s, lambda: None)
            self.loop.remove_writer(s)
            s.close()
            self.assertEqual(s.fileno(), -1)
class TestAIOSockets(_TestSockets, tb.AIOTestCase):
    """Run the shared ``_TestSockets`` tests on ``tb.AIOTestCase``.

    Adds no tests of its own.
    """

    pass