Implement pseudo-socket objects; don't dup() sockets

* `transport.get_extra_info('socket')` from now on will return a
  socket-like object.  That object will allow socket calls like
  `getsockopt()` and `setsockopt()` but will deny `close()`,
  `send()`, `recv()` and other invasive operations that might
  interfere with libuv.

* We no longer dup sockets when they are passed to APIs like
  `loop.create_server(sock=sock)`.  We now use `socket._io_refs`
  private API and new pseudo-sockets to guarantee that transports
  will work correctly even when users try to close the original
  socket object.
This commit is contained in:
Yury Selivanov 2017-11-24 16:46:19 -05:00
parent e19a233fde
commit 318e593e3f
11 changed files with 391 additions and 115 deletions

View File

@ -1,4 +1,5 @@
import asyncio
import pickle
import select
import socket
import sys
@ -228,6 +229,79 @@ class TestUVSockets(_TestSockets, tb.UVTestCase):
rsock.close()
wsock.close()
def test_pseudosocket(self):
def assert_raises():
return self.assertRaisesRegex(
RuntimeError,
r'File descriptor .* is used by transport')
def test_pseudo(real_sock, pseudo_sock, *, is_dup=False):
self.assertIn('AF_UNIX', repr(pseudo_sock))
self.assertEqual(pseudo_sock.family, real_sock.family)
self.assertEqual(pseudo_sock.type, real_sock.type)
self.assertEqual(pseudo_sock.proto, real_sock.proto)
with self.assertRaises(TypeError):
pickle.dumps(pseudo_sock)
na_meths = {
'accept', 'connect', 'connect_ex', 'bind', 'listen',
'makefile', 'sendfile', 'close', 'detach', 'shutdown',
'sendmsg_afalg', 'sendmsg', 'sendto', 'send', 'sendall',
'recv_into', 'recvfrom_into', 'recvmsg_into', 'recvmsg',
'recvfrom', 'recv'
}
for methname in na_meths:
meth = getattr(pseudo_sock, methname)
with self.assertRaisesRegex(
TypeError,
r'.*not support ' + methname + r'\(\) method'):
meth()
eq_meths = {
'getsockname', 'getpeername', 'get_inheritable', 'gettimeout'
}
for methname in eq_meths:
pmeth = getattr(pseudo_sock, methname)
rmeth = getattr(real_sock, methname)
# Call 2x to check caching paths
self.assertEqual(pmeth(), rmeth())
self.assertEqual(pmeth(), rmeth())
self.assertEqual(
pseudo_sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR),
0)
if not is_dup:
self.assertEqual(pseudo_sock.fileno(), real_sock.fileno())
duped = pseudo_sock.dup()
with duped:
test_pseudo(duped, pseudo_sock, is_dup=True)
with self.assertRaises(TypeError):
with pseudo_sock:
pass
async def runner():
tr, pr = await self.loop.create_connection(
lambda: asyncio.Protocol(), sock=rsock)
try:
sock = tr.get_extra_info('socket')
test_pseudo(rsock, sock)
finally:
tr.close()
rsock, wsock = socket.socketpair()
try:
self.loop.run_until_complete(runner())
finally:
rsock.close()
wsock.close()
def test_socket_connect_and_close(self):
def srv_gen(sock):
sock.send(b'helo')

View File

@ -619,7 +619,6 @@ class Test_UV_TCP(_TestTCP, tb.UVTestCase):
with self.assertRaisesRegex(RuntimeError, 'is used by transport'):
self.loop.remove_reader(sock.fileno())
self.assertTrue(isinstance(sock, socket.socket))
self.assertEqual(t.get_extra_info('sockname'),
sockname)
self.assertEqual(t.get_extra_info('peername'),
@ -764,6 +763,7 @@ class Test_UV_TCP(_TestTCP, tb.UVTestCase):
with self.assertWarnsRegex(ResourceWarning, rx):
self.loop.create_task(run())
self.loop.run_until_complete(srv.wait_closed())
self.loop.run_until_complete(asyncio.sleep(0.1, loop=self.loop))
srv = None
gc.collect()

View File

@ -22,6 +22,7 @@ cdef class UVHandle:
cdef inline _free(self)
cdef _close(self)
cdef _after_close(self)
cdef class UVSocketHandle(UVHandle):

View File

@ -209,6 +209,11 @@ cdef class UVHandle:
Py_INCREF(self)
uv.uv_close(self._handle, __uv_close_handle_cb) # void; no errors
cdef _after_close(self):
# Can only be called when '._close()' was called by hand
# (i.e. won't be called on UVHandle.__dealloc__).
pass
def __repr__(self):
return '<{} closed={} {:#x}>'.format(
self.__class__.__name__,
@ -255,28 +260,38 @@ cdef class UVSocketHandle(UVHandle):
# When we create a TCP/PIPE/etc connection/server based on
# a Python file object, we need to close the file object when
# the uv handle is closed.
socket_inc_io_ref(file)
self._fileobj = file
cdef _close(self):
try:
if self.__cached_socket is not None:
self.__cached_socket.detach()
self.__cached_socket = None
if self.__cached_socket is not None:
(<PseudoSocket>self.__cached_socket)._fd = -1
UVHandle._close(self)
cdef _after_close(self):
try:
# This code will only run for transports created from
# Python sockets, i.e. with `loop.create_server(sock=sock)` etc.
if self._fileobj is not None:
try:
socket_dec_io_ref(self._fileobj)
# `socket.close()` will raise an EBADF because libuv
# has already closed the underlying FDself.
self._fileobj.close()
except Exception as exc:
self._loop.call_exception_handler({
'exception': exc,
'transport': self,
'message': 'could not close attached file object {!r}'.
format(self._fileobj)
})
finally:
self._fileobj = None
except OSError as ex:
if ex.errno != errno_EBADF:
raise
except Exception as ex:
self._loop.call_exception_handler({
'exception': ex,
'transport': self,
'message': 'could not close attached file object {!r}'.
format(self._fileobj)
})
finally:
UVHandle._close(self)
self._fileobj = None
UVHandle._after_close(self)
cdef _open(self, int sockfd):
raise NotImplementedError
@ -332,11 +347,14 @@ cdef void __uv_close_handle_cb(uv.uv_handle_t* handle) with gil:
PyMem_RawFree(handle)
else:
h = <UVHandle>handle.data
if UVLOOP_DEBUG:
h._loop._debug_handles_closed.update([
h.__class__.__name__])
h._free()
Py_DECREF(h) # Was INCREFed in UVHandle._close
try:
if UVLOOP_DEBUG:
h._loop._debug_handles_closed.update([
h.__class__.__name__])
h._free()
h._after_close()
finally:
Py_DECREF(h) # Was INCREFed in UVHandle._close
cdef void __close_all_handles(Loop loop):

View File

@ -31,7 +31,7 @@ cdef __pipe_open(UVStream handle, int fd):
cdef __pipe_get_socket(UVSocketHandle handle):
fileno = handle._fileno()
return socket_socket(uv.AF_UNIX, uv.SOCK_STREAM, 0, fileno)
return PseudoSocket(uv.AF_UNIX, uv.SOCK_STREAM, 0, fileno)
@cython.no_gc_clear

View File

@ -50,7 +50,7 @@ cdef __tcp_get_socket(UVSocketHandle handle):
if err < 0:
raise convert_error(err)
return socket_socket(buf.ss_family, uv.SOCK_STREAM, 0, fileno)
return PseudoSocket(buf.ss_family, uv.SOCK_STREAM, 0, fileno)
@cython.no_gc_clear

View File

@ -188,7 +188,7 @@ cdef class UDPTransport(UVBaseTransport):
'UDPTransport.family is undefined; cannot create python socket')
fileno = self._fileno()
return socket_socket(self._family, uv.SOCK_STREAM, 0, fileno)
return PseudoSocket(self._family, uv.SOCK_STREAM, 0, fileno)
cdef _send(self, object data, object addr):
cdef:

View File

@ -50,6 +50,7 @@ cdef col_Counter = collections.Counter
cdef cc_ThreadPoolExecutor = concurrent.futures.ThreadPoolExecutor
cdef cc_Future = concurrent.futures.Future
cdef errno_EBADF = errno.EBADF
cdef errno_EINVAL = errno.EINVAL
cdef ft_partial = functools.partial
@ -68,6 +69,8 @@ cdef socket_timeout = socket.timeout
cdef socket_socket = socket.socket
cdef socket_socketpair = socket.socketpair
cdef socket_getservbyname = socket.getservbyname
cdef socket_AddressFamily = socket.AddressFamily
cdef socket_SocketKind = socket.SocketKind
cdef int socket_EAI_ADDRFAMILY = getattr(socket, 'EAI_ADDRFAMILY', -1)
cdef int socket_EAI_AGAIN = getattr(socket, 'EAI_AGAIN', -1)

View File

@ -166,9 +166,6 @@ cdef class Loop:
cdef _fileobj_to_fd(self, fileobj)
cdef _ensure_fd_no_transport(self, fd)
cdef _inc_io_ref(self, sock)
cdef _dec_io_ref(self, sock)
cdef _add_reader(self, fd, Handle handle)
cdef _remove_reader(self, fd)

View File

@ -59,6 +59,16 @@ cdef isfuture(obj):
return aio_isfuture(obj)
cdef inline socket_inc_io_ref(sock):
if isinstance(sock, socket_socket):
sock._io_refs += 1
cdef inline socket_dec_io_ref(sock):
if isinstance(sock, socket_socket):
sock._decref_socketios()
@cython.no_gc_clear
cdef class Loop:
def __cinit__(self):
@ -452,12 +462,12 @@ cdef class Loop:
if self._fd_to_writer_fileobj:
for fileobj in self._fd_to_writer_fileobj.values():
self._dec_io_ref(fileobj)
socket_dec_io_ref(fileobj)
self._fd_to_writer_fileobj.clear()
if self._fd_to_reader_fileobj:
for fileobj in self._fd_to_reader_fileobj.values():
self._dec_io_ref(fileobj)
socket_dec_io_ref(fileobj)
self._fd_to_reader_fileobj.clear()
if self._timers:
@ -615,21 +625,6 @@ cdef class Loop:
'File descriptor {!r} is used by transport {!r}'.format(
fd, tr))
cdef inline _inc_io_ref(self, sock):
try:
sock._io_refs += 1
except AttributeError:
pass
cdef inline _dec_io_ref(self, sock):
try:
sock._io_refs
sock._decref_socketios
except AttributeError:
pass
else:
sock._decref_socketios()
cdef _add_reader(self, fileobj, Handle handle):
cdef:
UVPoll poll
@ -646,7 +641,7 @@ cdef class Loop:
self._fd_to_reader_fileobj[fd] = fileobj
poll.start_reading(handle)
self._inc_io_ref(fileobj)
socket_inc_io_ref(fileobj)
cdef _remove_reader(self, fileobj):
cdef:
@ -655,7 +650,7 @@ cdef class Loop:
fd = self._fileobj_to_fd(fileobj)
self._ensure_fd_no_transport(fd)
self._fd_to_reader_fileobj.pop(fd, None)
self._dec_io_ref(fileobj)
socket_dec_io_ref(fileobj)
if self._closed == 1:
return False
@ -688,7 +683,7 @@ cdef class Loop:
self._fd_to_writer_fileobj[fd] = fileobj
poll.start_writing(handle)
self._inc_io_ref(fileobj)
socket_inc_io_ref(fileobj)
cdef _remove_writer(self, fileobj):
cdef:
@ -697,7 +692,7 @@ cdef class Loop:
fd = self._fileobj_to_fd(fileobj)
self._ensure_fd_no_transport(fd)
self._fd_to_writer_fileobj.pop(fd, None)
self._dec_io_ref(fileobj)
socket_dec_io_ref(fileobj)
if self._closed == 1:
return False
@ -1497,19 +1492,21 @@ cdef class Loop:
raise ValueError(
'A Stream Socket was expected, got {!r}'.format(sock))
# libuv will set the socket to non-blocking mode, but
# we want Python socket object to notice that.
sock.setblocking(False)
tcp = TCPServer.new(self, protocol_factory, server, ssl,
uv.AF_UNSPEC)
# See a comment on os_dup in create_connection
fileno = os_dup(sock.fileno())
try:
tcp._open(fileno)
tcp._attach_fileobj(sock)
tcp._open(sock.fileno())
tcp.listen(backlog)
except:
tcp._close()
raise
tcp._attach_fileobj(sock)
server._add_server(tcp)
server._ref()
@ -1695,46 +1692,23 @@ cdef class Loop:
raise ValueError(
'A Stream Socket was expected, got {!r}'.format(sock))
# libuv will set the socket to non-blocking mode, but
# we want Python socket object to notice that.
sock.setblocking(False)
waiter = self._new_future()
tr = TCPTransport.new(self, protocol, None, waiter)
try:
# Why we use os.dup here and other places
# ---------------------------------------
#
# Prerequisite: in Python 3.6, Python Socket Objects (PSO)
# were fixed to raise an OSError if the `socket.close()` call
# failed. So if the underlying FD is already closed,
# `socket.close()` call will fail with OSError(EBADF).
#
# Problem:
#
# - Vanilla asyncio uses the passed PSO directly. When the
# transport eventually closes the PSO, the PSO is marked as
# 'closed'. If the user decides to close the PSO after
# closing the loop, everything works normal in Python 3.5
# and 3.6.
#
# - asyncio+uvloop unwraps the FD from the passed PSO.
# Eventually the transport is closed and so the FD.
# If the user decides to close the PSO after closing the
# loop, an OSError(EBADF) will be raised in Python 3.6.
#
# All in all, duping FDs makes sense, because uvloop
# (and libuv) manage the FD once the user passes a PSO to
# `loop.create_connection`. We don't want the user to have
# any control of the FD once it is passed to uvloop.
# See also: https://github.com/python/asyncio/pull/449
fileno = os_dup(sock.fileno())
# libuv will make socket non-blocking
tr._open(fileno)
tr._attach_fileobj(sock)
tr._open(sock.fileno())
tr._init_protocol()
await waiter
except:
tr._close()
raise
tr._attach_fileobj(sock)
if ssl:
await ssl_waiter
return protocol._app_transport, app_protocol
@ -1829,26 +1803,26 @@ cdef class Loop:
'A UNIX Domain Stream Socket was expected, got {!r}'
.format(sock))
# libuv will set the socket to non-blocking mode, but
# we want Python socket object to notice that.
sock.setblocking(False)
pipe = UnixServer.new(self, protocol_factory, server, ssl)
try:
# See a comment on os_dup in create_connection
fileno = os_dup(sock.fileno())
pipe._open(fileno)
pipe._open(sock.fileno())
except:
pipe._close()
sock.close()
raise
pipe._attach_fileobj(sock)
try:
pipe.listen(backlog)
except:
pipe._close()
raise
pipe._attach_fileobj(sock)
server._add_server(pipe)
return server
@ -1913,21 +1887,22 @@ cdef class Loop:
'A UNIX Domain Stream Socket was expected, got {!r}'
.format(sock))
# libuv will set the socket to non-blocking mode, but
# we want Python socket object to notice that.
sock.setblocking(False)
waiter = self._new_future()
tr = UnixTransport.new(self, protocol, None, waiter)
try:
# See a comment on os_dup in create_connection
fileno = os_dup(sock.fileno())
# libuv will make socket non-blocking
tr._open(fileno)
tr._attach_fileobj(sock)
tr._open(sock.fileno())
tr._init_protocol()
await waiter
except:
tr._close()
raise
tr._attach_fileobj(sock)
if ssl:
await ssl_waiter
return protocol._app_transport, app_protocol
@ -2138,7 +2113,7 @@ cdef class Loop:
if not data:
return
self._inc_io_ref(sock)
socket_inc_io_ref(sock)
try:
try:
n = sock.send(data)
@ -2167,7 +2142,7 @@ cdef class Loop:
self._add_writer(sock, handle)
return await fut
finally:
self._dec_io_ref(sock)
socket_dec_io_ref(sock)
def sock_accept(self, sock):
"""Accept a connection.
@ -2204,7 +2179,7 @@ cdef class Loop:
if self._debug and sock.gettimeout() != 0:
raise ValueError("the socket must be non-blocking")
self._inc_io_ref(sock)
socket_inc_io_ref(sock)
try:
fut = self._new_future()
if sock.family == uv.AF_UNIX:
@ -2216,7 +2191,7 @@ cdef class Loop:
self._sock_connect(fut, sock, address)
await fut
finally:
self._dec_io_ref(sock)
socket_dec_io_ref(sock)
async def connect_accepted_socket(self, protocol_factory, sock, *,
ssl=None):
@ -2238,9 +2213,6 @@ cdef class Loop:
raise ValueError(
'A Stream Socket was expected, got {!r}'.format(sock))
# See a comment on os_dup in create_connection
fileno = os_dup(sock.fileno())
app_protocol = protocol_factory()
waiter = self._new_future()
transport_waiter = None
@ -2266,8 +2238,9 @@ cdef class Loop:
raise ValueError(
'invalid socket family, expected AF_UNIX, AF_INET or AF_INET6')
transport._open(fileno)
transport._open(sock.fileno())
transport._init_protocol()
transport._attach_fileobj(sock)
await waiter
@ -2388,21 +2361,19 @@ cdef class Loop:
ReadTransport interface."""
cdef:
ReadUnixTransport transp
# See a comment on os_dup in create_connection
int fileno = os_dup(pipe.fileno())
waiter = self._new_future()
proto = proto_factory()
transp = ReadUnixTransport.new(self, proto, None, waiter)
transp._add_extra_info('pipe', pipe)
transp._attach_fileobj(pipe)
try:
transp._open(fileno)
transp._open(pipe.fileno())
transp._init_protocol()
await waiter
except:
transp.close()
raise
transp._attach_fileobj(pipe)
return transp, proto
async def connect_write_pipe(self, proto_factory, pipe):
@ -2414,21 +2385,19 @@ cdef class Loop:
WriteTransport interface."""
cdef:
WriteUnixTransport transp
# See a comment on os_dup in create_connection
int fileno = os_dup(pipe.fileno())
waiter = self._new_future()
proto = proto_factory()
transp = WriteUnixTransport.new(self, proto, None, waiter)
transp._add_extra_info('pipe', pipe)
transp._attach_fileobj(pipe)
try:
transp._open(fileno)
transp._open(pipe.fileno())
transp._init_protocol()
await waiter
except:
transp.close()
raise
transp._attach_fileobj(pipe)
return transp, proto
def add_signal_handler(self, sig, callback, *args):
@ -2562,7 +2531,7 @@ cdef class Loop:
sock.setblocking(False)
udp = UDPTransport.__new__(UDPTransport)
udp._init(self, uv.AF_UNSPEC)
udp.open(sock.family, os_dup(sock.fileno()))
udp.open(sock.family, sock.fileno())
udp._attach_fileobj(sock)
else:
reuse_address = bool(reuse_address)
@ -2609,11 +2578,15 @@ cdef class Loop:
if reuse_port:
self._sock_set_reuseport(udp._fileno())
socket = udp._get_socket()
if family == uv.AF_INET6:
socket.bind(('::', 0))
else:
socket.bind(('0.0.0.0', 0))
fd = udp._fileno()
sock = socket_socket(family, uv.SOCK_DGRAM, 0, fd)
try:
if family == uv.AF_INET6:
sock.bind(('::', 0))
else:
sock.bind(('0.0.0.0', 0))
finally:
sock.detach()
else:
lai = (<AddrInfo>lads).data
while lai is not NULL:
@ -2735,6 +2708,7 @@ cdef inline void __loop_free_buffer(Loop loop):
include "cbhandles.pyx"
include "pseudosock.pyx"
include "handles/handle.pyx"
include "handles/async_.pyx"

209
uvloop/pseudosock.pyx Normal file
View File

@ -0,0 +1,209 @@
cdef class PseudoSocket:
cdef:
int _family
int _type
int _proto
int _fd
object _peername
object _sockname
def __init__(self, int family, int type, int proto, int fd):
self._family = family
self._type = type
self._proto = proto
self._fd = fd
self._peername = None
self._sockname = None
cdef _na(self, what):
raise TypeError('transport sockets do not support {}'.format(what))
cdef _make_sock(self):
return socket_socket(self._family, self._type, self._proto, self._fd)
property family:
def __get__(self):
try:
return socket_AddressFamily(self._family)
except ValueError:
return self._family
property type:
def __get__(self):
try:
return socket_SocketKind(self._type)
except ValueError:
return self._type
property proto:
def __get__(self):
return self._proto
def __repr__(self):
s = ("<uvloop.PseudoSocket fd={}, family={!s}, "
"type={!s}, proto={}").format(self.fileno(), self.family,
self.type, self.proto)
if self._fd != -1:
try:
laddr = self.getsockname()
if laddr:
s += ", laddr=%s" % str(laddr)
except socket_error:
pass
try:
raddr = self.getpeername()
if raddr:
s += ", raddr=%s" % str(raddr)
except socket_error:
pass
s += '>'
return s
def __getstate__(self):
raise TypeError("Cannot serialize socket object")
def fileno(self):
return self._fd
def dup(self):
fd = os_dup(self._fd)
sock = socket_socket(self._family, self._type, self._proto, fileno=fd)
sock.settimeout(0)
return sock
def get_inheritable(self):
return os_get_inheritable(self._fd)
def set_inheritable(self):
os_set_inheritable(self._fd)
def ioctl(self, *args, **kwargs):
pass
def getsockopt(self, *args, **kwargs):
sock = self._make_sock()
try:
return sock.getsockopt(*args, **kwargs)
finally:
sock.detach()
def setsockopt(self, *args, **kwargs):
sock = self._make_sock()
try:
return sock.setsockopt(*args, **kwargs)
finally:
sock.detach()
def getpeername(self):
if self._peername is not None:
return self._peername
sock = self._make_sock()
try:
self._peername = sock.getpeername()
return self._peername
finally:
sock.detach()
def getsockname(self):
if self._sockname is not None:
return self._sockname
sock = self._make_sock()
try:
self._sockname = sock.getsockname()
return self._sockname
finally:
sock.detach()
def share(self, process_id):
sock = self._make_sock()
try:
return sock.share(process_id)
finally:
sock.detach()
def accept(self):
self._na('accept() method')
def connect(self, *args):
self._na('connect() method')
def connect_ex(self, *args):
self._na('connect_ex() method')
def bind(self, *args):
self._na('bind() method')
def listen(self, *args, **kwargs):
self._na('listen() method')
def makefile(self):
self._na('makefile() method')
def sendfile(self, *args, **kwargs):
self._na('sendfile() method')
def close(self):
self._na('close() method')
def detach(self):
self._na('detach() method')
def shutdown(self, *args):
self._na('shutdown() method')
def sendmsg_afalg(self, *args, **kwargs):
self._na('sendmsg_afalg() method')
def sendmsg(self):
self._na('sendmsg() method')
def sendto(self, *args, **kwargs):
self._na('sendto() method')
def send(self, *args, **kwargs):
self._na('send() method')
def sendall(self, *args, **kwargs):
self._na('sendall() method')
def recv_into(self, *args, **kwargs):
self._na('recv_into() method')
def recvfrom_into(self, *args, **kwargs):
self._na('recvfrom_into() method')
def recvmsg_into(self, *args, **kwargs):
self._na('recvmsg_into() method')
def recvmsg(self, *args, **kwargs):
self._na('recvmsg() method')
def recvfrom(self, *args, **kwargs):
self._na('recvfrom() method')
def recv(self, *args, **kwargs):
self._na('recv() method')
def settimeout(self, value):
if value == 0:
return
raise ValueError(
'settimeout(): only 0 timeout is allowed on transport sockets')
def gettimeout(self):
return 0
def setblocking(self, flag):
if not flag:
return
raise ValueError(
'setblocking(): transport sockets cannot be blocking')
def __enter__(self):
self._na('context manager protocol')
def __exit__(self, *err):
self._na('context manager protocol')