BufferedSocket is now threadsafe, no longer has a fileno, and there are ever more socketutils docs

This commit is contained in:
Mahmoud Hashemi 2016-04-27 22:09:42 -07:00
parent 7f1f1c7424
commit 6cea934392
2 changed files with 212 additions and 144 deletions

View File

@ -33,6 +33,17 @@ contributions on this module.
import time import time
import socket import socket
try:
    from threading import RLock
except Exception:
    class RLock(object):
        """No-op stand-in for ``threading.RLock``.

        Defined only when importing the real lock fails (e.g. builds
        without threads). It supports the ``with`` statement and does
        nothing on entry or exit.
        """

        def __enter__(self):
            # Nothing to acquire.
            return None

        def __exit__(self, exctype, excinst, exctb):
            # Nothing to release; returning None lets exceptions propagate.
            return None
try: try:
from typeutils import make_sentinel from typeutils import make_sentinel
@ -53,26 +64,42 @@ class BufferedSocket(object):
This type has been tested against both the built-in socket type as This type has been tested against both the built-in socket type as
well as those from gevent and eventlet. It also features support well as those from gevent and eventlet. It also features support
for sockets with timeouts set to 0 (aka nonblocking), provided the for sockets with timeouts set to 0 (aka nonblocking), provided the
caller is prepared to handle the EWOULDBLOCK exceptions. Much like caller is prepared to handle the EWOULDBLOCK exceptions.
the built-in socket, the BufferedSocket is not intrinsically
threadsafe for higher-level protocols.
Args: Args:
sock (socket): The connected socket to be wrapped. sock (socket): The connected socket to be wrapped.
timeout (float): The default timeout for sends and recvs, in seconds. timeout (float): The default timeout for sends and recvs, in
seconds. Defaults to 10 seconds.
maxsize (int): The default maximum number of bytes to be received maxsize (int): The default maximum number of bytes to be received
into the buffer before it is considered full and raises an into the buffer before it is considered full and raises an
exception. exception. Defaults to 32 kilobytes.
*timeout* and *maxsize* can both be overridden on individual socket *timeout* and *maxsize* can both be overridden on individual socket
operations. operations.
All ``recv`` methods return bytestrings (:type:`bytes`) and can raise All ``recv`` methods return bytestrings (:type:`bytes`) and can
:exc:`socket.error`. :exc:`Timeout`, :exc:`ConnectionClosed`, and raise :exc:`socket.error`. :exc:`Timeout`,
:exc:`NotFound` all inherit from :exc:`socket.error` and exist to :exc:`ConnectionClosed`, and :exc:`MessageTooLong` all inherit
provide better error messages. from :exc:`socket.error` and exist to provide better error
messages. Received bytes are always buffered, even if an exception
is raised. Use :meth:`BufferedSocket.getrecvbuffer` to retrieve
partial recvs.
BufferedSocket does not replace the built-in socket by any
means. While the overlapping parts of the API are kept parallel to
the built-in :type:`socket.socket`, BufferedSocket does not
inherit from socket, and most socket functionality is only
available on the underlying socket. :meth:`socket.getpeername`,
:meth:`socket.getsockname`, :meth:`socket.fileno`, and others are
only available on the underlying socket that is wrapped. Use the
``BufferedSocket.sock`` attribute to access it. See the examples
for more information on how to use BufferedSockets with built-in
sockets.
The BufferedSocket is threadsafe. Still, consider your protocol
before accessing a single socket from multiple threads.
""" """
# TODO: recv_close() # receive until socket closed
def __init__(self, sock, def __init__(self, sock,
timeout=DEFAULT_TIMEOUT, maxsize=DEFAULT_MAXSIZE): timeout=DEFAULT_TIMEOUT, maxsize=DEFAULT_MAXSIZE):
self.sock = sock self.sock = sock
@ -82,11 +109,8 @@ class BufferedSocket(object):
self.timeout = float(timeout) self.timeout = float(timeout)
self.maxsize = int(maxsize) self.maxsize = int(maxsize)
def fileno(self): self.send_lock = RLock()
"""Returns the file descriptor of the underlying socket. Raises an self.recv_lock = RLock()
exception if the underlying socket is closed.
"""
return self.sock.fileno()
def settimeout(self, timeout): def settimeout(self, timeout):
"Set the default timeout for future operations, in float seconds." "Set the default timeout for future operations, in float seconds."
@ -98,6 +122,16 @@ class BufferedSocket(object):
""" """
self.maxsize = maxsize self.maxsize = maxsize
def getrecvbuffer(self):
    """Return the receive buffer bytestring (``rbuf``).

    The attribute is read while holding ``recv_lock``.
    """
    with self.recv_lock:
        buffered = self.rbuf
    return buffered
def getsendbuffer(self):
    """Return a new list holding the items of the send buffer (``sbuf``).

    The copy is taken while holding ``send_lock``; appending to or
    removing from the returned list does not affect ``sbuf`` itself.
    """
    with self.send_lock:
        snapshot = list(self.sbuf)
    return snapshot
def recv(self, size, flags=0, timeout=_UNSET): def recv(self, size, flags=0, timeout=_UNSET):
"""Returns **up to** *size* bytes, using the internal buffer before """Returns **up to** *size* bytes, using the internal buffer before
performing a single :meth:`socket.recv` operation. performing a single :meth:`socket.recv` operation.
@ -113,24 +147,25 @@ class BufferedSocket(object):
If the operation does not complete in *timeout* seconds, a If the operation does not complete in *timeout* seconds, a
:exc:`Timeout` is raised. :exc:`Timeout` is raised.
""" """
if timeout is _UNSET: with self.recv_lock:
timeout = self.timeout if timeout is _UNSET:
if flags: timeout = self.timeout
raise ValueError("non-zero flags not supported: %r" % flags) if flags:
if len(self.rbuf) >= size: raise ValueError("non-zero flags not supported: %r" % flags)
data, self.rbuf = self.rbuf[:size], self.rbuf[size:] if len(self.rbuf) >= size:
return data data, self.rbuf = self.rbuf[:size], self.rbuf[size:]
size -= len(self.rbuf) return data
self.sock.settimeout(timeout) size -= len(self.rbuf)
try: self.sock.settimeout(timeout)
sock_data = self.sock.recv(size) try:
except socket.timeout: sock_data = self.sock.recv(size)
raise Timeout(timeout) # check the rbuf attr for more except socket.timeout:
data = self.rbuf + sock_data raise Timeout(timeout) # check the rbuf attr for more
# don't empty buffer till after network communication is complete, data = self.rbuf + sock_data
# to avoid data loss on transient / retry-able errors (e.g. read # don't empty buffer till after network communication is complete,
# timeout) # to avoid data loss on transient / retry-able errors (e.g. read
self.rbuf = b'' # timeout)
self.rbuf = b''
return data return data
def peek(self, size, timeout=_UNSET): def peek(self, size, timeout=_UNSET):
@ -144,12 +179,30 @@ class BufferedSocket(object):
set in the constructor of BufferedSocket. set in the constructor of BufferedSocket.
""" """
if len(self.rbuf) >= size: with self.recv_lock:
return self.rbuf[:size] if len(self.rbuf) >= size:
data = self.recv_size(size, timeout=timeout) return self.rbuf[:size]
self.rbuf = data + self.rbuf data = self.recv_size(size, timeout=timeout)
self.rbuf = data + self.rbuf
return data return data
def recv_close(self, maxsize=_UNSET, timeout=_UNSET):
    """Receive until the connection is closed, up to *maxsize* bytes.

    Args:
        maxsize (int): The largest message accepted. Defaults to the
            instance's ``maxsize`` attribute when left unset.
        timeout (float): Passed through to :meth:`recv_size`.

    Returns:
        The contents of the receive buffer at the time the connection
        closed. The receive buffer is emptied.

    Raises:
        MessageTooLong: If more than *maxsize* bytes are received
            before the connection closes. The received bytes are left
            in the receive buffer.

    Any other exception raised by :meth:`recv_size` propagates unchanged.
    """
    with self.recv_lock:
        if maxsize is _UNSET:
            maxsize = self.maxsize
        try:
            # Request one byte more than the limit. recv_size() returns
            # as soon as it has the requested amount, so asking for
            # exactly *maxsize* would report a message of exactly
            # *maxsize* bytes as too long without ever seeing the close.
            recvd = self.recv_size(maxsize + 1, timeout=timeout)
        except ConnectionClosed:
            # recv_size() buffers what it read before raising.
            ret, self.rbuf = self.rbuf, b''
        else:
            # put extra received bytes (now in rbuf) after recvd
            self.rbuf = recvd + self.rbuf
            raise MessageTooLong(len(self.rbuf))  # check receive buffer
    return ret
def recv_until(self, marker, timeout=_UNSET, maxsize=_UNSET): def recv_until(self, marker, timeout=_UNSET, maxsize=_UNSET):
"""Receive until *marker* is found, *timeout* is exceeded, or internal """Receive until *marker* is found, *timeout* is exceeded, or internal
buffer reaches *maxsize*. buffer reaches *maxsize*.
@ -163,44 +216,45 @@ class BufferedSocket(object):
maxsize (int): The maximum size for the internal buffer. maxsize (int): The maximum size for the internal buffer.
Defaults to the value set in the constructor. Defaults to the value set in the constructor.
""" """
'read off of socket until the marker is found' with self.recv_lock:
if maxsize is _UNSET: if maxsize is _UNSET:
maxsize = self.maxsize maxsize = self.maxsize
if timeout is _UNSET: if timeout is _UNSET:
timeout = self.timeout timeout = self.timeout
recvd = bytearray(self.rbuf) recvd = bytearray(self.rbuf)
start = time.time() start = time.time()
sock = self.sock sock = self.sock
if not timeout: # covers None (no timeout) and 0 (nonblocking) if not timeout: # covers None (no timeout) and 0 (nonblocking)
sock.settimeout(timeout) sock.settimeout(timeout)
try: try:
while 1: while 1:
if maxsize is not None and len(recvd) >= maxsize: if maxsize is not None and len(recvd) >= maxsize:
raise NotFound(marker, len(recvd)) # check rbuf attr raise MessageTooLong(len(recvd), marker) # check rbuf
if timeout: if timeout:
cur_timeout = timeout - (time.time() - start) cur_timeout = timeout - (time.time() - start)
if cur_timeout <= 0.0: if cur_timeout <= 0.0:
raise socket.timeout() raise socket.timeout()
sock.settimeout(cur_timeout) sock.settimeout(cur_timeout)
nxt = sock.recv(maxsize) nxt = sock.recv(maxsize)
if not nxt: if not nxt:
msg = ('connection closed after reading %s bytes without' args = (len(recvd), marker)
' finding symbol: %r' % (len(recvd), marker)) msg = ('connection closed after reading %s bytes'
raise ConnectionClosed(msg) # check the rbuf attr for more ' without finding symbol: %r' % args)
recvd.extend(nxt) raise ConnectionClosed(msg) # check the recv buffer
offset = recvd.find(marker, -len(nxt) - len(marker)) recvd.extend(nxt)
if offset >= 0: offset = recvd.find(marker, -len(nxt) - len(marker))
offset += len(marker) # include marker in the return if offset >= 0:
break offset += len(marker) # include marker in the return
except socket.timeout: break
self.rbuf = bytes(recvd) except socket.timeout:
msg = ('read %s bytes without finding marker: %r' self.rbuf = bytes(recvd)
% (len(recvd), marker)) msg = ('read %s bytes without finding marker: %r'
raise Timeout(timeout, msg) # check the rbuf attr for more % (len(recvd), marker))
except Exception: raise Timeout(timeout, msg) # check the recv buffer
self.rbuf = bytes(recvd) except Exception:
raise self.rbuf = bytes(recvd)
val, self.rbuf = bytes(recvd[:offset]), bytes(recvd[offset:]) raise
val, self.rbuf = bytes(recvd[:offset]), bytes(recvd[offset:])
return val return val
def recv_size(self, size, timeout=_UNSET): def recv_size(self, size, timeout=_UNSET):
@ -218,43 +272,44 @@ class BufferedSocket(object):
:exc:`Timeout` will be raised. If the connection is closed, a :exc:`Timeout` will be raised. If the connection is closed, a
:exc:`ConnectionClosed` will be raised. :exc:`ConnectionClosed` will be raised.
""" """
if timeout is _UNSET: with self.recv_lock:
timeout = self.timeout if timeout is _UNSET:
chunks = [] timeout = self.timeout
total_bytes = 0 chunks = []
try: total_bytes = 0
start = time.time() try:
self.sock.settimeout(timeout) start = time.time()
nxt = self.rbuf or self.sock.recv(size) self.sock.settimeout(timeout)
while nxt: nxt = self.rbuf or self.sock.recv(size)
total_bytes += len(nxt) while nxt:
if total_bytes >= size: total_bytes += len(nxt)
break if total_bytes >= size:
chunks.append(nxt) break
if timeout: chunks.append(nxt)
cur_timeout = timeout - (time.time() - start) if timeout:
if cur_timeout <= 0.0: cur_timeout = timeout - (time.time() - start)
raise socket.timeout() if cur_timeout <= 0.0:
self.sock.settimeout(cur_timeout) raise socket.timeout()
nxt = self.sock.recv(size - total_bytes) self.sock.settimeout(cur_timeout)
nxt = self.sock.recv(size - total_bytes)
else:
msg = ('connection closed after reading %s of %s requested'
' bytes' % (total_bytes, size))
raise ConnectionClosed(msg) # check recv buffer
except socket.timeout:
self.rbuf = b''.join(chunks)
msg = 'read %s of %s bytes' % (total_bytes, size)
raise Timeout(timeout, msg) # check recv buffer
except Exception:
# received data is still buffered in the case of errors
self.rbuf = b''.join(chunks)
raise
extra_bytes = total_bytes - size
if extra_bytes:
last, self.rbuf = nxt[:-extra_bytes], nxt[-extra_bytes:]
else: else:
msg = ('connection closed after reading %s of %s requested' last, self.rbuf = nxt, b''
' bytes' % (total_bytes, size)) chunks.append(last)
raise ConnectionClosed(msg) # check rbuf attribute for more
except socket.timeout:
self.rbuf = b''.join(chunks)
msg = 'read %s of %s bytes' % (total_bytes, size)
raise Timeout(timeout, msg) # check rbuf attribute for more
except Exception:
# received data is still buffered in the case of errors
self.rbuf = b''.join(chunks)
raise
extra_bytes = total_bytes - size
if extra_bytes:
last, self.rbuf = nxt[:-extra_bytes], nxt[-extra_bytes:]
else:
last, self.rbuf = nxt, b''
chunks.append(last)
return b''.join(chunks) return b''.join(chunks)
def send(self, data, flags=0, timeout=_UNSET): def send(self, data, flags=0, timeout=_UNSET):
@ -273,37 +328,43 @@ class BufferedSocket(object):
complete before *timeout*. complete before *timeout*.
""" """
if timeout is _UNSET: with self.send_lock:
timeout = self.timeout if timeout is _UNSET:
if flags: timeout = self.timeout
raise ValueError("non-zero flags not supported") if flags:
sbuf = self.sbuf raise ValueError("non-zero flags not supported")
sbuf.append(data) sbuf = self.sbuf
if len(sbuf) > 1: sbuf.append(data)
sbuf[:] = [b''.join(sbuf)] if len(sbuf) > 1:
self.sock.settimeout(timeout) sbuf[:] = [b''.join(sbuf)]
start = time.time() self.sock.settimeout(timeout)
try: start = time.time()
while sbuf[0]: try:
sent = self.sock.send(sbuf[0]) while sbuf[0]:
sbuf[0] = sbuf[0][sent:] sent = self.sock.send(sbuf[0])
if timeout: sbuf[0] = sbuf[0][sent:]
cur_timeout = timeout - (time.time() - start) if timeout:
if cur_timeout <= 0.0: cur_timeout = timeout - (time.time() - start)
raise socket.timeout() if cur_timeout <= 0.0:
self.sock.settimeout(cur_timeout) raise socket.timeout()
except socket.timeout: self.sock.settimeout(cur_timeout)
raise Timeout(timeout, '%s bytes unsent' % len(sbuf[0])) except socket.timeout:
raise Timeout(timeout, '%s bytes unsent' % len(sbuf[0]))
return
sendall = send sendall = send
def flush(self): def flush(self):
"Send the contents of the internal send buffer." "Send the contents of the internal send buffer."
self.send(b'') with self.send_lock:
self.send(b'')
return
def buffer(self, data): def buffer(self, data):
"Buffer *data* bytes for the next send operation." "Buffer *data* bytes for the next send operation."
self.sbuf.append(data) with self.send_lock:
self.sbuf.append(data)
return
class Error(socket.error, Exception): class Error(socket.error, Exception):
@ -314,6 +375,19 @@ class ConnectionClosed(Error):
pass pass
class MessageTooLong(Error):
    """Raised by :meth:`BufferedSocket.recv_until` and
    :meth:`BufferedSocket.recv_close` when the maximum size is reached
    before the marker is found or the connection closes.

    Args:
        bytes_read (int): Number of bytes read so far; appended to the
            message when not None.
        marker (bytes): The marker that was being searched for;
            appended to the message when not None.
    """
    def __init__(self, bytes_read=None, marker=None):
        msg = 'message exceeded maximum size'
        if bytes_read is not None:
            # Leading '. ' keeps this clause from running into the
            # previous sentence ("...maximum size4096 bytes read").
            msg += '. %s bytes read' % (bytes_read,)
        if marker is not None:
            msg += '. Marker not found: %r' % (marker,)
        super(MessageTooLong, self).__init__(msg)
class Timeout(socket.timeout, Error): class Timeout(socket.timeout, Error):
def __init__(self, timeout, extra=""): def __init__(self, timeout, extra=""):
msg = 'socket operation timed out' msg = 'socket operation timed out'
@ -324,12 +398,6 @@ class Timeout(socket.timeout, Error):
super(Timeout, self).__init__(msg) super(Timeout, self).__init__(msg)
class NotFound(Error):
def __init__(self, marker, bytes_read):
msg = 'read %s bytes without finding marker: %r' % (marker, bytes_read)
super(NotFound, self).__init__(msg)
class NetstringSocket(object): class NetstringSocket(object):
""" """
Reads and writes using the netstring protocol. Reads and writes using the netstring protocol.

View File

@ -6,7 +6,7 @@ import threading
from boltons.socketutils import (NetstringSocket, from boltons.socketutils import (NetstringSocket,
ConnectionClosed, ConnectionClosed,
NetstringMessageTooLong, NetstringMessageTooLong,
NotFound, MessageTooLong,
Timeout) Timeout)
@ -107,9 +107,9 @@ def test_socketutils_netstring():
print("raised MessageTooLong correctly") print("raised MessageTooLong correctly")
try: try:
client.bsock.recv_until(b'b', maxsize=4096) client.bsock.recv_until(b'b', maxsize=4096)
raise Exception('recv_until did not raise NotFound') raise Exception('recv_until did not raise MessageTooLong')
except NotFound: except MessageTooLong:
print("raised NotFound correctly") print("raised MessageTooLong correctly")
assert client.bsock.recv_size(4097) == b'a' * 4096 + b',' assert client.bsock.recv_size(4097) == b'a' * 4096 + b','
print('correctly maintained buffer after exception raised') print('correctly maintained buffer after exception raised')