uvloop/tests/test_unix.py

676 lines
21 KiB
Python

import asyncio
import os
import pathlib
import socket
import tempfile
import time
import unittest
import sys
from uvloop import _testbase as tb
class _TestUnix:
def test_create_unix_server_1(self):
CNT = 0 # number of clients that were successful
TOTAL_CNT = 100 # total number of clients that test will create
TIMEOUT = 5.0 # timeout for this test
async def handle_client(reader, writer):
nonlocal CNT
data = await reader.readexactly(4)
self.assertEqual(data, b'AAAA')
writer.write(b'OK')
data = await reader.readexactly(4)
self.assertEqual(data, b'BBBB')
writer.write(b'SPAM')
await writer.drain()
writer.close()
await self.wait_closed(writer)
CNT += 1
async def test_client(addr):
sock = socket.socket(socket.AF_UNIX)
with sock:
sock.setblocking(False)
await self.loop.sock_connect(sock, addr)
await self.loop.sock_sendall(sock, b'AAAA')
buf = b''
while len(buf) != 2:
buf += await self.loop.sock_recv(sock, 1)
self.assertEqual(buf, b'OK')
await self.loop.sock_sendall(sock, b'BBBB')
buf = b''
while len(buf) != 4:
buf += await self.loop.sock_recv(sock, 1)
self.assertEqual(buf, b'SPAM')
async def start_server():
nonlocal CNT
CNT = 0
with tempfile.TemporaryDirectory() as td:
sock_name = os.path.join(td, 'sock')
srv = await asyncio.start_unix_server(
handle_client,
sock_name)
try:
srv_socks = srv.sockets
self.assertTrue(srv_socks)
self.assertTrue(srv.is_serving())
tasks = []
for _ in range(TOTAL_CNT):
tasks.append(test_client(sock_name))
await asyncio.wait_for(asyncio.gather(*tasks), TIMEOUT)
finally:
self.loop.call_soon(srv.close)
await srv.wait_closed()
# Check that the server cleaned-up proxy-sockets
for srv_sock in srv_socks:
self.assertEqual(srv_sock.fileno(), -1)
self.assertFalse(srv.is_serving())
# asyncio doesn't cleanup the sock file
self.assertTrue(os.path.exists(sock_name))
async def start_server_sock(start_server):
nonlocal CNT
CNT = 0
with tempfile.TemporaryDirectory() as td:
sock_name = os.path.join(td, 'sock')
sock = socket.socket(socket.AF_UNIX)
sock.bind(sock_name)
srv = await start_server(sock)
try:
srv_socks = srv.sockets
self.assertTrue(srv_socks)
self.assertTrue(srv.is_serving())
tasks = []
for _ in range(TOTAL_CNT):
tasks.append(test_client(sock_name))
await asyncio.wait_for(asyncio.gather(*tasks), TIMEOUT)
finally:
self.loop.call_soon(srv.close)
await srv.wait_closed()
# Check that the server cleaned-up proxy-sockets
for srv_sock in srv_socks:
self.assertEqual(srv_sock.fileno(), -1)
self.assertFalse(srv.is_serving())
# asyncio doesn't cleanup the sock file
self.assertTrue(os.path.exists(sock_name))
with self.subTest(func='start_unix_server(host, port)'):
self.loop.run_until_complete(start_server())
self.assertEqual(CNT, TOTAL_CNT)
with self.subTest(func='start_unix_server(sock)'):
self.loop.run_until_complete(start_server_sock(
lambda sock: asyncio.start_unix_server(
handle_client,
None,
sock=sock)))
self.assertEqual(CNT, TOTAL_CNT)
with self.subTest(func='start_server(sock)'):
self.loop.run_until_complete(start_server_sock(
lambda sock: asyncio.start_server(
handle_client,
None, None,
sock=sock)))
self.assertEqual(CNT, TOTAL_CNT)
def test_create_unix_server_2(self):
with tempfile.TemporaryDirectory() as td:
sock_name = os.path.join(td, 'sock')
with open(sock_name, 'wt') as f:
f.write('x')
with self.assertRaisesRegex(
OSError, "Address '{}' is already in use".format(
sock_name)):
self.loop.run_until_complete(
self.loop.create_unix_server(object, sock_name))
def test_create_unix_server_3(self):
with self.assertRaisesRegex(
ValueError, 'ssl_handshake_timeout is only meaningful'):
self.loop.run_until_complete(
self.loop.create_unix_server(
lambda: None, path='/tmp/a', ssl_handshake_timeout=10))
def test_create_unix_server_existing_path_sock(self):
with self.unix_sock_name() as path:
sock = socket.socket(socket.AF_UNIX)
with sock:
sock.bind(path)
sock.listen(1)
# Check that no error is raised -- `path` is removed.
coro = self.loop.create_unix_server(lambda: None, path)
srv = self.loop.run_until_complete(coro)
srv.close()
self.loop.run_until_complete(srv.wait_closed())
def test_create_unix_connection_open_unix_con_addr(self):
async def client(addr):
reader, writer = await asyncio.open_unix_connection(addr)
writer.write(b'AAAA')
self.assertEqual(await reader.readexactly(2), b'OK')
writer.write(b'BBBB')
self.assertEqual(await reader.readexactly(4), b'SPAM')
writer.close()
await self.wait_closed(writer)
self._test_create_unix_connection_1(client)
def test_create_unix_connection_open_unix_con_sock(self):
async def client(addr):
sock = socket.socket(socket.AF_UNIX)
sock.connect(addr)
reader, writer = await asyncio.open_unix_connection(sock=sock)
writer.write(b'AAAA')
self.assertEqual(await reader.readexactly(2), b'OK')
writer.write(b'BBBB')
self.assertEqual(await reader.readexactly(4), b'SPAM')
writer.close()
await self.wait_closed(writer)
self._test_create_unix_connection_1(client)
def test_create_unix_connection_open_con_sock(self):
async def client(addr):
sock = socket.socket(socket.AF_UNIX)
sock.connect(addr)
reader, writer = await asyncio.open_connection(sock=sock)
writer.write(b'AAAA')
self.assertEqual(await reader.readexactly(2), b'OK')
writer.write(b'BBBB')
self.assertEqual(await reader.readexactly(4), b'SPAM')
writer.close()
await self.wait_closed(writer)
self._test_create_unix_connection_1(client)
def _test_create_unix_connection_1(self, client):
CNT = 0
TOTAL_CNT = 100
def server(sock):
data = sock.recv_all(4)
self.assertEqual(data, b'AAAA')
sock.send(b'OK')
data = sock.recv_all(4)
self.assertEqual(data, b'BBBB')
sock.send(b'SPAM')
async def client_wrapper(addr):
await client(addr)
nonlocal CNT
CNT += 1
def run(coro):
nonlocal CNT
CNT = 0
with self.unix_server(server,
max_clients=TOTAL_CNT,
backlog=TOTAL_CNT) as srv:
tasks = []
for _ in range(TOTAL_CNT):
tasks.append(coro(srv.addr))
self.loop.run_until_complete(asyncio.gather(*tasks))
# Give time for all transports to close.
self.loop.run_until_complete(asyncio.sleep(0.1))
self.assertEqual(CNT, TOTAL_CNT)
run(client_wrapper)
def test_create_unix_connection_2(self):
with tempfile.NamedTemporaryFile() as tmp:
path = tmp.name
async def client():
reader, writer = await asyncio.open_unix_connection(path)
writer.close()
await self.wait_closed(writer)
async def runner():
with self.assertRaises(FileNotFoundError):
await client()
self.loop.run_until_complete(runner())
def test_create_unix_connection_3(self):
CNT = 0
TOTAL_CNT = 100
def server(sock):
data = sock.recv_all(4)
self.assertEqual(data, b'AAAA')
sock.close()
async def client(addr):
reader, writer = await asyncio.open_unix_connection(addr)
sock = writer._transport.get_extra_info('socket')
self.assertEqual(sock.family, socket.AF_UNIX)
writer.write(b'AAAA')
with self.assertRaises(asyncio.IncompleteReadError):
await reader.readexactly(10)
writer.close()
await self.wait_closed(writer)
nonlocal CNT
CNT += 1
def run(coro):
nonlocal CNT
CNT = 0
with self.unix_server(server,
max_clients=TOTAL_CNT,
backlog=TOTAL_CNT) as srv:
tasks = []
for _ in range(TOTAL_CNT):
tasks.append(coro(srv.addr))
self.loop.run_until_complete(asyncio.gather(*tasks))
self.assertEqual(CNT, TOTAL_CNT)
run(client)
def test_create_unix_connection_4(self):
sock = socket.socket(socket.AF_UNIX)
sock.close()
async def client():
reader, writer = await asyncio.open_unix_connection(sock=sock)
writer.close()
await self.wait_closed(writer)
async def runner():
with self.assertRaisesRegex(OSError, 'Bad file'):
await client()
self.loop.run_until_complete(runner())
def test_create_unix_connection_5(self):
s1, s2 = socket.socketpair(socket.AF_UNIX)
excs = []
class Proto(asyncio.Protocol):
def connection_lost(self, exc):
excs.append(exc)
proto = Proto()
async def client():
t, _ = await self.loop.create_unix_connection(
lambda: proto,
None,
sock=s2)
t.write(b'AAAAA')
s1.close()
t.write(b'AAAAA')
await asyncio.sleep(0.1)
self.loop.run_until_complete(client())
self.assertEqual(len(excs), 1)
self.assertIn(excs[0].__class__,
(BrokenPipeError, ConnectionResetError))
def test_create_unix_connection_6(self):
with self.assertRaisesRegex(
ValueError, 'ssl_handshake_timeout is only meaningful'):
self.loop.run_until_complete(
self.loop.create_unix_connection(
lambda: None, path='/tmp/a', ssl_handshake_timeout=10))
class Test_UV_Unix(_TestUnix, tb.UVTestCase):
@unittest.skipUnless(hasattr(os, 'fspath'), 'no os.fspath()')
def test_create_unix_connection_pathlib(self):
async def run(addr):
t, _ = await self.loop.create_unix_connection(
asyncio.Protocol, addr)
t.close()
with self.unix_server(lambda sock: time.sleep(0.01)) as srv:
addr = pathlib.Path(srv.addr)
self.loop.run_until_complete(run(addr))
@unittest.skipUnless(hasattr(os, 'fspath'), 'no os.fspath()')
def test_create_unix_server_pathlib(self):
with self.unix_sock_name() as srv_path:
srv_path = pathlib.Path(srv_path)
srv = self.loop.run_until_complete(
self.loop.create_unix_server(asyncio.Protocol, srv_path))
srv.close()
self.loop.run_until_complete(srv.wait_closed())
def test_transport_fromsock_get_extra_info(self):
# This tests is only for uvloop. asyncio should pass it
# too in Python 3.6.
async def test(sock):
t, _ = await self.loop.create_unix_connection(
asyncio.Protocol,
sock=sock)
sock = t.get_extra_info('socket')
self.assertIs(t.get_extra_info('socket'), sock)
with self.assertRaisesRegex(RuntimeError, 'is used by transport'):
self.loop.add_writer(sock.fileno(), lambda: None)
with self.assertRaisesRegex(RuntimeError, 'is used by transport'):
self.loop.remove_writer(sock.fileno())
t.close()
s1, s2 = socket.socketpair(socket.AF_UNIX)
with s1, s2:
self.loop.run_until_complete(test(s1))
def test_create_unix_server_path_dgram(self):
sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
with sock:
coro = self.loop.create_unix_server(lambda: None,
sock=sock)
with self.assertRaisesRegex(ValueError,
'A UNIX Domain Stream.*was expected'):
self.loop.run_until_complete(coro)
@unittest.skipUnless(hasattr(socket, 'SOCK_NONBLOCK'),
'no socket.SOCK_NONBLOCK (linux only)')
def test_create_unix_server_path_stream_bittype(self):
sock = socket.socket(
socket.AF_UNIX, socket.SOCK_STREAM | socket.SOCK_NONBLOCK)
with tempfile.NamedTemporaryFile() as file:
fn = file.name
try:
with sock:
sock.bind(fn)
coro = self.loop.create_unix_server(lambda: None, path=None,
sock=sock)
srv = self.loop.run_until_complete(coro)
srv.close()
self.loop.run_until_complete(srv.wait_closed())
finally:
os.unlink(fn)
@unittest.skipUnless(sys.platform.startswith('linux'), 'requires epoll')
def test_epollhup(self):
SIZE = 50
eof = False
done = False
recvd = b''
class Proto(asyncio.BaseProtocol):
def connection_made(self, tr):
tr.write(b'hello')
self.data = bytearray(SIZE)
self.buf = memoryview(self.data)
def get_buffer(self, sizehint):
return self.buf
def buffer_updated(self, nbytes):
nonlocal recvd
recvd += self.buf[:nbytes]
def eof_received(self):
nonlocal eof
eof = True
def connection_lost(self, exc):
nonlocal done
done = exc
async def test():
with tempfile.TemporaryDirectory() as td:
sock_name = os.path.join(td, 'sock')
srv = await self.loop.create_unix_server(Proto, sock_name)
s = socket.socket(socket.AF_UNIX)
with s:
s.setblocking(False)
await self.loop.sock_connect(s, sock_name)
d = await self.loop.sock_recv(s, 100)
self.assertEqual(d, b'hello')
# IMPORTANT: overflow recv buffer and close immediately
await self.loop.sock_sendall(s, b'a' * (SIZE + 1))
srv.close()
await srv.wait_closed()
self.loop.run_until_complete(test())
self.assertTrue(eof)
self.assertIsNone(done)
self.assertEqual(recvd, b'a' * (SIZE + 1))
class Test_AIO_Unix(_TestUnix, tb.AIOTestCase):
pass
class _TestSSL(tb.SSLTestCase):
ONLYCERT = tb._cert_fullname(__file__, 'ssl_cert.pem')
ONLYKEY = tb._cert_fullname(__file__, 'ssl_key.pem')
def test_create_unix_server_ssl_1(self):
CNT = 0 # number of clients that were successful
TOTAL_CNT = 25 # total number of clients that test will create
TIMEOUT = 10.0 # timeout for this test
A_DATA = b'A' * 1024 * 1024
B_DATA = b'B' * 1024 * 1024
sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY)
client_sslctx = self._create_client_ssl_context()
clients = []
async def handle_client(reader, writer):
nonlocal CNT
data = await reader.readexactly(len(A_DATA))
self.assertEqual(data, A_DATA)
writer.write(b'OK')
data = await reader.readexactly(len(B_DATA))
self.assertEqual(data, B_DATA)
writer.writelines([b'SP', bytearray(b'A'), memoryview(b'M')])
await writer.drain()
writer.close()
CNT += 1
async def test_client(addr):
fut = asyncio.Future(loop=self.loop)
def prog(sock):
try:
sock.starttls(client_sslctx)
sock.connect(addr)
sock.send(A_DATA)
data = sock.recv_all(2)
self.assertEqual(data, b'OK')
sock.send(B_DATA)
data = sock.recv_all(4)
self.assertEqual(data, b'SPAM')
sock.close()
except Exception as ex:
self.loop.call_soon_threadsafe(
lambda ex=ex:
(fut.cancelled() or fut.set_exception(ex)))
else:
self.loop.call_soon_threadsafe(
lambda: (fut.cancelled() or fut.set_result(None)))
client = self.unix_client(prog)
client.start()
clients.append(client)
await fut
async def start_server():
extras = dict(ssl_handshake_timeout=10.0)
with tempfile.TemporaryDirectory() as td:
sock_name = os.path.join(td, 'sock')
srv = await asyncio.start_unix_server(
handle_client,
sock_name,
ssl=sslctx,
**extras)
try:
tasks = []
for _ in range(TOTAL_CNT):
tasks.append(test_client(sock_name))
await asyncio.wait_for(asyncio.gather(*tasks), TIMEOUT)
finally:
self.loop.call_soon(srv.close)
await srv.wait_closed()
try:
with self._silence_eof_received_warning():
self.loop.run_until_complete(start_server())
except asyncio.TimeoutError:
if os.environ.get('TRAVIS_OS_NAME') == 'osx':
# XXX: figure out why this fails on macOS on Travis
raise unittest.SkipTest('unexplained error on Travis macOS')
else:
raise
self.assertEqual(CNT, TOTAL_CNT)
for client in clients:
client.stop()
def test_create_unix_connection_ssl_1(self):
CNT = 0
TOTAL_CNT = 25
A_DATA = b'A' * 1024 * 1024
B_DATA = b'B' * 1024 * 1024
sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY)
client_sslctx = self._create_client_ssl_context()
def server(sock):
sock.starttls(sslctx, server_side=True)
data = sock.recv_all(len(A_DATA))
self.assertEqual(data, A_DATA)
sock.send(b'OK')
data = sock.recv_all(len(B_DATA))
self.assertEqual(data, B_DATA)
sock.send(b'SPAM')
sock.close()
async def client(addr):
extras = dict(ssl_handshake_timeout=10.0)
reader, writer = await asyncio.open_unix_connection(
addr,
ssl=client_sslctx,
server_hostname='',
**extras)
writer.write(A_DATA)
self.assertEqual(await reader.readexactly(2), b'OK')
writer.write(B_DATA)
self.assertEqual(await reader.readexactly(4), b'SPAM')
nonlocal CNT
CNT += 1
writer.close()
await self.wait_closed(writer)
def run(coro):
nonlocal CNT
CNT = 0
with self.unix_server(server,
max_clients=TOTAL_CNT,
backlog=TOTAL_CNT) as srv:
tasks = []
for _ in range(TOTAL_CNT):
tasks.append(coro(srv.addr))
self.loop.run_until_complete(asyncio.gather(*tasks))
self.assertEqual(CNT, TOTAL_CNT)
with self._silence_eof_received_warning():
run(client)
class Test_UV_UnixSSL(_TestSSL, tb.UVTestCase):
pass
class Test_AIO_UnixSSL(_TestSSL, tb.AIOTestCase):
pass