# uvloop/tests/test_unix.py
#
# 525 lines
# 15 KiB
# Python

import asyncio
import os
import socket
import tempfile
from uvloop import _testbase as tb
class _TestUnix:
    # UNIX domain socket tests shared by the concrete ``Test_*_Unix``
    # classes.  This is a mixin: the class it is combined with is
    # expected to provide ``self.loop`` and the unittest assertion
    # methods used below.
    #
    # NOTE(review): the explicit ``loop=`` arguments passed to the
    # asyncio helpers in this class were deprecated in Python 3.8 and
    # removed in 3.10.  They are kept because neither the supported
    # interpreter range nor the way the base class installs the loop is
    # visible from this file -- confirm before dropping them.

    def test_create_unix_server_1(self):
        # Start a stream server on a UNIX socket (first by path, then
        # from a pre-bound socket object) and run TOTAL_CNT raw-socket
        # clients against it.
        CNT = 0           # number of clients that were successful
        TOTAL_CNT = 100   # total number of clients that test will create
        TIMEOUT = 5.0     # timeout for this test

        async def handle_client(reader, writer):
            nonlocal CNT

            data = await reader.readexactly(4)
            self.assertEqual(data, b'AAAA')
            writer.write(b'OK')

            data = await reader.readexactly(4)
            self.assertEqual(data, b'BBBB')
            writer.write(b'SPAM')

            await writer.drain()
            writer.close()

            CNT += 1

        async def recv_bytes(sock, nbytes):
            # Collect ``nbytes`` bytes with one sock_recv() call per
            # byte.  An empty read means the peer closed the
            # connection; stop then instead of looping forever and let
            # the caller's assertion report the short buffer.
            buf = b''
            while len(buf) != nbytes:
                chunk = await self.loop.sock_recv(sock, 1)
                if not chunk:
                    break
                buf += chunk
            return buf

        async def test_client(addr):
            sock = socket.socket(socket.AF_UNIX)
            with sock:
                sock.setblocking(False)
                await self.loop.sock_connect(sock, addr)

                await self.loop.sock_sendall(sock, b'AAAA')
                self.assertEqual(await recv_bytes(sock, 2), b'OK')

                await self.loop.sock_sendall(sock, b'BBBB')
                self.assertEqual(await recv_bytes(sock, 4), b'SPAM')

        def server_from_path(sock_name):
            # Let asyncio create and bind the listening socket.
            return asyncio.start_unix_server(
                handle_client,
                sock_name,
                loop=self.loop)

        def server_from_sock(sock_name):
            # Hand asyncio a socket that is already bound to the path.
            sock = socket.socket(socket.AF_UNIX)
            sock.bind(sock_name)
            return asyncio.start_unix_server(
                handle_client,
                None,
                loop=self.loop,
                sock=sock)

        async def run_server(create_server):
            # Shared driver: ``create_server(sock_name)`` must return an
            # awaitable that resolves to the started server.
            nonlocal CNT
            CNT = 0

            with tempfile.TemporaryDirectory() as td:
                sock_name = os.path.join(td, 'sock')
                srv = await create_server(sock_name)

                try:
                    srv_socks = srv.sockets
                    self.assertTrue(srv_socks)

                    tasks = [test_client(sock_name)
                             for _ in range(TOTAL_CNT)]

                    await asyncio.wait_for(
                        asyncio.gather(*tasks, loop=self.loop),
                        TIMEOUT, loop=self.loop)

                finally:
                    self.loop.call_soon(srv.close)
                    await srv.wait_closed()

                    # Check that the server cleaned-up proxy-sockets
                    for srv_sock in srv_socks:
                        self.assertEqual(srv_sock.fileno(), -1)

                # asyncio doesn't cleanup the sock file
                self.assertTrue(os.path.exists(sock_name))

        self.loop.run_until_complete(run_server(server_from_path))
        self.assertEqual(CNT, TOTAL_CNT)

        self.loop.run_until_complete(run_server(server_from_sock))
        self.assertEqual(CNT, TOTAL_CNT)

    def test_create_unix_server_2(self):
        # A path that is occupied by a regular file must be reported as
        # "already in use" rather than being replaced.
        with tempfile.TemporaryDirectory() as td:
            sock_name = os.path.join(td, 'sock')
            with open(sock_name, 'wt') as f:
                f.write('x')

            with self.assertRaisesRegex(
                    OSError, "Address '{}' is already in use".format(
                        sock_name)):

                self.loop.run_until_complete(
                    self.loop.create_unix_server(object, sock_name))

    def test_create_unix_connection_1(self):
        # Connect TOTAL_CNT stream clients to a ``tb.tcp_server`` on a
        # UNIX socket, once by path and once via a connected socket.
        CNT = 0
        TOTAL_CNT = 100

        def server():
            data = yield tb.read(4)
            self.assertEqual(data, b'AAAA')
            yield tb.write(b'OK')

            data = yield tb.read(4)
            self.assertEqual(data, b'BBBB')
            yield tb.write(b'SPAM')

        async def exchange(reader, writer):
            # Client half of the protocol; the writer is closed even
            # when one of the assertions fails.
            nonlocal CNT
            try:
                writer.write(b'AAAA')
                self.assertEqual(await reader.readexactly(2), b'OK')
                writer.write(b'BBBB')
                self.assertEqual(await reader.readexactly(4), b'SPAM')
                CNT += 1
            finally:
                writer.close()

        async def client(addr):
            reader, writer = await asyncio.open_unix_connection(
                addr,
                loop=self.loop)
            await exchange(reader, writer)

        async def client_2(addr):
            sock = socket.socket(socket.AF_UNIX)
            sock.connect(addr)
            reader, writer = await asyncio.open_unix_connection(
                sock=sock,
                loop=self.loop)
            await exchange(reader, writer)

        def run(coro):
            nonlocal CNT
            CNT = 0

            srv = tb.tcp_server(server,
                                family=socket.AF_UNIX,
                                max_clients=TOTAL_CNT,
                                backlog=TOTAL_CNT)
            srv.start()

            tasks = [coro(srv.addr) for _ in range(TOTAL_CNT)]

            self.loop.run_until_complete(
                asyncio.gather(*tasks, loop=self.loop))
            srv.join()
            self.assertEqual(CNT, TOTAL_CNT)

        run(client)
        run(client_2)

    def test_create_unix_connection_2(self):
        # Only the name is needed: the temporary file is removed when
        # the ``with`` block exits, so ``path`` no longer exists by the
        # time the connection is attempted.
        with tempfile.NamedTemporaryFile() as tmp:
            path = tmp.name

        async def runner():
            with self.assertRaises(FileNotFoundError):
                await asyncio.open_unix_connection(
                    path,
                    loop=self.loop)

        self.loop.run_until_complete(runner())

    def test_create_unix_connection_3(self):
        # The server closes after the first message, so the client's
        # 10-byte read must end with IncompleteReadError.
        CNT = 0
        TOTAL_CNT = 100

        def server():
            data = yield tb.read(4)
            self.assertEqual(data, b'AAAA')
            yield tb.close()

        async def client(addr):
            nonlocal CNT

            reader, writer = await asyncio.open_unix_connection(
                addr,
                loop=self.loop)

            # Close the writer even when an assertion fails.
            try:
                sock = writer._transport.get_extra_info('socket')
                self.assertEqual(sock.family, socket.AF_UNIX)

                writer.write(b'AAAA')

                with self.assertRaises(asyncio.IncompleteReadError):
                    await reader.readexactly(10)
            finally:
                writer.close()

            CNT += 1

        def run(coro):
            nonlocal CNT
            CNT = 0

            srv = tb.tcp_server(server,
                                family=socket.AF_UNIX,
                                max_clients=TOTAL_CNT,
                                backlog=TOTAL_CNT)
            srv.start()

            tasks = [coro(srv.addr) for _ in range(TOTAL_CNT)]

            self.loop.run_until_complete(
                asyncio.gather(*tasks, loop=self.loop))
            srv.join()
            self.assertEqual(CNT, TOTAL_CNT)

        run(client)

    def test_create_unix_connection_4(self):
        # Passing an already-closed socket must fail with OSError.
        sock = socket.socket(socket.AF_UNIX)
        sock.close()

        async def runner():
            with self.assertRaisesRegex(OSError, 'Bad file'):
                await asyncio.open_unix_connection(
                    sock=sock,
                    loop=self.loop)

        self.loop.run_until_complete(runner())

    def test_create_unix_connection_5(self):
        # Writing after the peer end of a socketpair is closed must
        # produce exactly one connection_lost() call carrying a
        # BrokenPipeError or ConnectionResetError.
        s1, s2 = socket.socketpair(socket.AF_UNIX)

        excs = []

        class Proto(asyncio.Protocol):
            def connection_lost(self, exc):
                excs.append(exc)

        proto = Proto()

        async def client():
            t, _ = await self.loop.create_unix_connection(
                lambda: proto,
                None,
                sock=s2)

            t.write(b'AAAAA')
            s1.close()
            t.write(b'AAAAA')
            # Give the loop time to deliver connection_lost().
            await asyncio.sleep(0.1, loop=self.loop)

        self.loop.run_until_complete(client())

        self.assertEqual(len(excs), 1)
        self.assertIn(excs[0].__class__,
                      (BrokenPipeError, ConnectionResetError))

    def test_transport_unclosed_warning(self):
        # The transport returned by test() is deliberately dropped
        # without being closed; an 'unclosed' ResourceWarning is
        # expected somewhere inside the assertWarnsRegex block.
        async def test(sock):
            return await self.loop.create_unix_connection(
                asyncio.Protocol,
                None,
                sock=sock)

        with self.assertWarnsRegex(ResourceWarning, 'unclosed'):
            s1, s2 = socket.socketpair(socket.AF_UNIX)
            with s1, s2:
                self.loop.run_until_complete(test(s1))
            self.loop.close()
class Test_UV_Unix(_TestUnix, tb.UVTestCase):
    # Runs the shared _TestUnix tests on top of tb.UVTestCase and adds
    # one uvloop-specific check.

    def test_transport_fromsock_get_extra_info(self):
        # This test is only for uvloop.  asyncio should pass it
        # too in Python 3.6.

        async def test(sock):
            t, _ = await self.loop.create_unix_connection(
                asyncio.Protocol,
                None,
                sock=sock)

            # Close the transport even when one of the checks fails.
            try:
                tsock = t.get_extra_info('socket')
                # A second lookup must hand back the very same object.
                self.assertIs(t.get_extra_info('socket'), tsock)

                # The loop must refuse direct writer registration on a
                # file descriptor that a transport already owns.
                with self.assertRaisesRegex(RuntimeError,
                                            'is used by transport'):
                    self.loop.add_writer(tsock.fileno(), lambda: None)
                with self.assertRaisesRegex(RuntimeError,
                                            'is used by transport'):
                    self.loop.remove_writer(tsock.fileno())
            finally:
                t.close()

        s1, s2 = socket.socketpair(socket.AF_UNIX)
        with s1, s2:
            self.loop.run_until_complete(test(s1))
class Test_AIO_Unix(_TestUnix, tb.AIOTestCase):
    """Run the shared ``_TestUnix`` tests on top of ``tb.AIOTestCase``."""
class _TestSSL(tb.SSLTestCase):
    # TLS-over-UNIX-socket tests shared by the concrete
    # ``Test_*_UnixSSL`` classes.  The SSL context factories,
    # certificate paths and ``_silence_eof_received_warning`` used
    # below come from the base class.
    #
    # NOTE(review): the explicit ``loop=`` arguments passed to the
    # asyncio helpers were deprecated in Python 3.8 and removed in
    # 3.10 (``asyncio.Future`` still accepts ``loop``).  They are kept
    # because the supported interpreter range is not visible from this
    # file -- confirm before dropping them.

    def test_create_unix_server_ssl_1(self):
        # asyncio TLS server on a UNIX socket, exercised by threaded
        # ``tb.tcp_client`` clients that each send 2 x 1 MiB payloads.
        CNT = 0           # number of clients that were successful
        TOTAL_CNT = 25    # total number of clients that test will create
        TIMEOUT = 5.0     # timeout for this test

        A_DATA = b'A' * 1024 * 1024
        B_DATA = b'B' * 1024 * 1024

        sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY)
        client_sslctx = self._create_client_ssl_context()

        clients = []

        async def handle_client(reader, writer):
            nonlocal CNT

            data = await reader.readexactly(len(A_DATA))
            self.assertEqual(data, A_DATA)
            writer.write(b'OK')

            data = await reader.readexactly(len(B_DATA))
            self.assertEqual(data, B_DATA)
            # Mixed buffer types on purpose; together they spell b'SPAM'.
            writer.writelines([b'SP', bytearray(b'A'), memoryview(b'M')])

            await writer.drain()
            writer.close()

            CNT += 1

        async def test_client(addr):
            fut = asyncio.Future(loop=self.loop)

            def prog():
                # The outcome is reported through call_soon_threadsafe()
                # rather than by touching ``fut`` directly.
                try:
                    yield tb.starttls(client_sslctx)
                    yield tb.connect(addr)
                    yield tb.write(A_DATA)

                    data = yield tb.read(2)
                    self.assertEqual(data, b'OK')

                    yield tb.write(B_DATA)
                    data = yield tb.read(4)
                    self.assertEqual(data, b'SPAM')

                    yield tb.close()

                except Exception as ex:
                    self.loop.call_soon_threadsafe(fut.set_exception, ex)
                else:
                    self.loop.call_soon_threadsafe(fut.set_result, None)

            client = tb.tcp_client(prog, family=socket.AF_UNIX)
            client.start()
            clients.append(client)

            await fut

        async def start_server():
            with tempfile.TemporaryDirectory() as td:
                sock_name = os.path.join(td, 'sock')

                srv = await asyncio.start_unix_server(
                    handle_client,
                    sock_name,
                    ssl=sslctx,
                    loop=self.loop)

                try:
                    tasks = [test_client(sock_name)
                             for _ in range(TOTAL_CNT)]

                    await asyncio.wait_for(
                        asyncio.gather(*tasks, loop=self.loop),
                        TIMEOUT, loop=self.loop)

                finally:
                    self.loop.call_soon(srv.close)
                    await srv.wait_closed()

        with self._silence_eof_received_warning():
            self.loop.run_until_complete(start_server())

        self.assertEqual(CNT, TOTAL_CNT)

        for client in clients:
            client.stop()

    def test_create_unix_connection_ssl_1(self):
        # asyncio TLS clients against a ``tb.tcp_server`` on a UNIX
        # socket, each sending 2 x 1 MiB payloads.
        CNT = 0
        TOTAL_CNT = 25

        A_DATA = b'A' * 1024 * 1024
        B_DATA = b'B' * 1024 * 1024

        sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY)
        client_sslctx = self._create_client_ssl_context()

        def server():
            yield tb.starttls(
                sslctx,
                server_side=True)

            data = yield tb.read(len(A_DATA))
            self.assertEqual(data, A_DATA)
            yield tb.write(b'OK')

            data = yield tb.read(len(B_DATA))
            self.assertEqual(data, B_DATA)
            yield tb.write(b'SPAM')

            yield tb.close()

        async def client(addr):
            nonlocal CNT

            reader, writer = await asyncio.open_unix_connection(
                addr,
                ssl=client_sslctx,
                server_hostname='',
                loop=self.loop)

            # Close the writer even when an assertion fails.
            try:
                writer.write(A_DATA)
                self.assertEqual(await reader.readexactly(2), b'OK')

                writer.write(B_DATA)
                self.assertEqual(await reader.readexactly(4), b'SPAM')

                CNT += 1
            finally:
                writer.close()

        def run(coro):
            nonlocal CNT
            CNT = 0

            srv = tb.tcp_server(server,
                                family=socket.AF_UNIX,
                                max_clients=TOTAL_CNT,
                                backlog=TOTAL_CNT)
            srv.start()

            tasks = [coro(srv.addr) for _ in range(TOTAL_CNT)]

            self.loop.run_until_complete(
                asyncio.gather(*tasks, loop=self.loop))
            srv.join()
            self.assertEqual(CNT, TOTAL_CNT)

        with self._silence_eof_received_warning():
            run(client)
class Test_UV_UnixSSL(_TestSSL, tb.UVTestCase):
    """Run the shared ``_TestSSL`` tests on top of ``tb.UVTestCase``."""
class Test_AIO_UnixSSL(_TestSSL, tb.AIOTestCase):
    """Run the shared ``_TestSSL`` tests on top of ``tb.AIOTestCase``."""