From 135d0606900dc2d0ad871563002de6712e6e6b78 Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Tue, 21 Nov 2017 11:02:15 -0500 Subject: [PATCH] create_server() now makes a strong ref to the Server object. Fixes #81. Also makes Server objects weak-refable. --- tests/test_tcp.py | 45 +++++++++++++++++++++++++++++++++++++++++++++ uvloop/loop.pxd | 2 ++ uvloop/loop.pyx | 3 +++ uvloop/server.pxd | 4 ++++ uvloop/server.pyx | 34 ++++++++++++++++++++++++---------- 5 files changed, 78 insertions(+), 10 deletions(-) diff --git a/tests/test_tcp.py b/tests/test_tcp.py index 4462575..ac0c9dc 100644 --- a/tests/test_tcp.py +++ b/tests/test_tcp.py @@ -6,6 +6,7 @@ import uvloop import ssl import sys import threading +import weakref from uvloop import _testbase as tb @@ -262,6 +263,50 @@ class _TestTCP: self.loop.run_until_complete(runner()) + def test_create_server_7(self): + # Test that create_server() stores a hard ref to the server object + # somewhere in the loop. In asyncio it so happens that + # loop.sock_accept() has a reference to the server object so it + # never gets GCed. + + class Proto(asyncio.Protocol): + def connection_made(self, tr): + self.tr = tr + self.tr.write(b'hello') + + async def test(): + port = tb.find_free_port() + srv = await self.loop.create_server(Proto, '127.0.0.1', port) + wsrv = weakref.ref(srv) + del srv + + gc.collect() + gc.collect() + gc.collect() + + s = socket.socket(socket.AF_INET) + with s: + s.setblocking(False) + await self.loop.sock_connect(s, ('127.0.0.1', port)) + d = await self.loop.sock_recv(s, 100) + self.assertEqual(d, b'hello') + + srv = wsrv() + srv.close() + await srv.wait_closed() + del srv + + # Let all transports shutdown. + await asyncio.sleep(0.1, loop=self.loop) + + gc.collect() + gc.collect() + gc.collect() + + self.assertIsNone(wsrv()) + + self.loop.run_until_complete(test()) + def test_create_connection_1(self): CNT = 0 TOTAL_CNT = 100 diff --git a/uvloop/loop.pxd b/uvloop/loop.pxd index 9d05d87..e22451c 100644 --- a/uvloop/loop.pxd +++ b/uvloop/loop.pxd @@ -52,6 +52,8 @@ cdef class Loop: set _queued_streams Py_ssize_t _ready_len + set _servers + object _transports dict _fd_to_reader_fileobj dict _fd_to_writer_fileobj diff --git a/uvloop/loop.pyx b/uvloop/loop.pyx index 5e53add..5d66b79 100644 --- a/uvloop/loop.pyx +++ b/uvloop/loop.pyx @@ -148,6 +148,8 @@ cdef class Loop: # Set to True when `loop.shutdown_asyncgens` is called. self._asyncgens_shutdown_called = False + self._servers = set() + def __init__(self): self.set_debug((not sys_ignore_environment and bool(os_environ.get('PYTHONASYNCIODEBUG')))) @@ -1509,6 +1511,7 @@ cdef class Loop: server._add_server(tcp) + server._ref() return server async def create_connection(self, protocol_factory, host=None, port=None, *, diff --git a/uvloop/server.pxd b/uvloop/server.pxd index a0808a3..1555202 100644 --- a/uvloop/server.pxd +++ b/uvloop/server.pxd @@ -4,9 +4,13 @@ cdef class Server: list _waiters int _active_count Loop _loop + object __weakref__ cdef _add_server(self, UVStreamServer srv) cdef _wakeup(self) cdef _attach(self) cdef _detach(self) + + cdef _ref(self) + cdef _unref(self) diff --git a/uvloop/server.pyx b/uvloop/server.pyx index 6286400..b79e342 100644 --- a/uvloop/server.pyx +++ b/uvloop/server.pyx @@ -27,6 +27,13 @@ cdef class Server: if self._active_count == 0 and self._servers is None: self._wakeup() + cdef _ref(self): + # Keep the server object alive while it's not explicitly closed. + self._loop._servers.add(self) + + cdef _unref(self): + self._loop._servers.discard(self) + # Public API def __repr__(self): @@ -40,25 +47,32 @@ cdef class Server: await waiter def close(self): + cdef list servers + if self._servers is None: return - cdef list servers = self._servers - self._servers = None + try: + servers = self._servers + self._servers = None - for server in servers: - (server)._close() + for server in servers: + (server)._close() - if self._active_count == 0: - self._wakeup() + if self._active_count == 0: + self._wakeup() + finally: + self._unref() property sockets: def __get__(self): cdef list sockets = [] - for server in self._servers: - sockets.append( - (server)._get_socket() - ) + # Guard against `self._servers is None` + if self._servers: + for server in self._servers: + sockets.append( + (server)._get_socket() + ) return sockets