From 6cea9343927b551f6a3267bf1cd03c2923f50480 Mon Sep 17 00:00:00 2001 From: Mahmoud Hashemi Date: Wed, 27 Apr 2016 22:09:42 -0700 Subject: [PATCH] BufferedSocket is now threadsafe, no longer has a fileno, and there are ever more socketutils docs --- boltons/socketutils.py | 348 +++++++++++++++++++++++--------------- tests/test_socketutils.py | 8 +- 2 files changed, 212 insertions(+), 144 deletions(-) diff --git a/boltons/socketutils.py b/boltons/socketutils.py index c8d7881..c759ed7 100644 --- a/boltons/socketutils.py +++ b/boltons/socketutils.py @@ -33,6 +33,17 @@ contributions on this module. import time import socket +try: + from threading import RLock +except Exception: + class RLock(object): + 'Dummy reentrant lock for builds without threads' + def __enter__(self): + pass + + def __exit__(self, exctype, excinst, exctb): + pass + try: from typeutils import make_sentinel @@ -53,26 +64,42 @@ class BufferedSocket(object): This type has been tested against both the built-in socket type as well as those from gevent and eventlet. It also features support for sockets with timeouts set to 0 (aka nonblocking), provided the - caller is prepared to handle the EWOULDBLOCK exceptions. Much like - the built-in socket, the BufferedSocket is not intrinsically - threadsafe for higher-level protocols. + caller is prepared to handle the EWOULDBLOCK exceptions. Args: sock (socket): The connected socket to be wrapped. - timeout (float): The default timeout for sends and recvs, in seconds. + timeout (float): The default timeout for sends and recvs, in + seconds. Defaults to 10 seconds. maxsize (int): The default maximum number of bytes to be received into the buffer before it is considered full and raises an - exception. + exception. Defaults to 32 kilobytes. *timeout* and *maxsize* can both be overridden on individual socket operations. - All ``recv`` methods return bytestrings (:type:`bytes`) and can raise - :exc:`socket.error`. :exc:`Timeout`, :exc:`ConnectionClosed`, and - :exc:`NotFound` all inherit from :exc:`socket.error` and exist to - provide better error messages. + All ``recv`` methods return bytestrings (:type:`bytes`) and can + raise :exc:`socket.error`. :exc:`Timeout`, + :exc:`ConnectionClosed`, and :exc:`MessageTooLong` all inherit + from :exc:`socket.error` and exist to provide better error + messages. Received bytes are always buffered, even if an exception + is raised. Use :meth:`BufferedSocket.getrecvbuffer` to retrieve + partial recvs. + + BufferedSocket does not replace the built-in socket by any + means. While the overlapping parts of the API are kept parallel to + the built-in :type:`socket.socket`, BufferedSocket does not + inherit from socket, and most socket functionality is only + available on the underlying socket. :meth:`socket.getpeername`, + :meth:`socket.getsockname`, :meth:`socket.fileno`, and others are + only available on the underlying socket that is wrapped. Use the + ``BufferedSocket.sock`` attribute to access it. See the examples + for more information on how to use BufferedSockets with built-in + sockets. + + The BufferedSocket is threadsafe. Still, consider your protocol + before accessing a single socket from multiple threads. + """ - # TODO: recv_close() # receive until socket closed def __init__(self, sock, timeout=DEFAULT_TIMEOUT, maxsize=DEFAULT_MAXSIZE): self.sock = sock @@ -82,11 +109,8 @@ class BufferedSocket(object): self.timeout = float(timeout) self.maxsize = int(maxsize) - def fileno(self): - """Returns the file descriptor of the underlying socket. Raises an - exception if the underlying socket is closed. - """ - return self.sock.fileno() + self.send_lock = RLock() + self.recv_lock = RLock() def settimeout(self, timeout): "Set the default timeout for future operations, in float seconds." @@ -98,6 +122,16 @@ class BufferedSocket(object): """ self.maxsize = maxsize + def getrecvbuffer(self): + "Returns the receive buffer bytestring (rbuf)." + with self.recv_lock: + return self.rbuf + + def getsendbuffer(self): + "Returns a copy of the send buffer list." + with self.send_lock: + return list(self.sbuf) + def recv(self, size, flags=0, timeout=_UNSET): """Returns **up to** *size* bytes, using the internal buffer before performing a single :meth:`socket.recv` operation. @@ -113,24 +147,25 @@ class BufferedSocket(object): If the operation does not complete in *timeout* seconds, a :exc:`Timeout` is raised. """ - if timeout is _UNSET: - timeout = self.timeout - if flags: - raise ValueError("non-zero flags not supported: %r" % flags) - if len(self.rbuf) >= size: - data, self.rbuf = self.rbuf[:size], self.rbuf[size:] - return data - size -= len(self.rbuf) - self.sock.settimeout(timeout) - try: - sock_data = self.sock.recv(size) - except socket.timeout: - raise Timeout(timeout) # check the rbuf attr for more - data = self.rbuf + sock_data - # don't empty buffer till after network communication is complete, - # to avoid data loss on transient / retry-able errors (e.g. read - # timeout) - self.rbuf = b'' + with self.recv_lock: + if timeout is _UNSET: + timeout = self.timeout + if flags: + raise ValueError("non-zero flags not supported: %r" % flags) + if len(self.rbuf) >= size: + data, self.rbuf = self.rbuf[:size], self.rbuf[size:] + return data + size -= len(self.rbuf) + self.sock.settimeout(timeout) + try: + sock_data = self.sock.recv(size) + except socket.timeout: + raise Timeout(timeout) # check the rbuf attr for more + data = self.rbuf + sock_data + # don't empty buffer till after network communication is complete, + # to avoid data loss on transient / retry-able errors (e.g. read + # timeout) + self.rbuf = b'' return data def peek(self, size, timeout=_UNSET): @@ -144,12 +179,30 @@ class BufferedSocket(object): set in the constructor of BufferedSocket. """ - if len(self.rbuf) >= size: - return self.rbuf[:size] - data = self.recv_size(size, timeout=timeout) - self.rbuf = data + self.rbuf + with self.recv_lock: + if len(self.rbuf) >= size: + return self.rbuf[:size] + data = self.recv_size(size, timeout=timeout) + self.rbuf = data + self.rbuf return data + def recv_close(self, maxsize=_UNSET, timeout=_UNSET): + """Receive until the connection is closed, up to *maxsize* bytes. If + more than *maxsize* bytes are received, raises :exc:`MessageTooLong`. + """ + with self.recv_lock: + if maxsize is _UNSET: + maxsize = self.maxsize + try: + recvd = self.recv_size(maxsize, timeout) + except ConnectionClosed: + ret, self.rbuf = self.rbuf, b'' + else: + # put extra received bytes (now in rbuf) after recvd + self.rbuf = recvd + self.rbuf + raise MessageTooLong(len(self.rbuf)) # check receive buffer + return ret + def recv_until(self, marker, timeout=_UNSET, maxsize=_UNSET): """Receive until *marker* is found, *timeout* is exceeded, or internal buffer reaches *maxsize*. @@ -163,44 +216,45 @@ class BufferedSocket(object): maxsize (int): The maximum size for the internal buffer. Defaults to the value set in the constructor. """ - 'read off of socket until the marker is found' - if maxsize is _UNSET: - maxsize = self.maxsize - if timeout is _UNSET: - timeout = self.timeout - recvd = bytearray(self.rbuf) - start = time.time() - sock = self.sock - if not timeout: # covers None (no timeout) and 0 (nonblocking) - sock.settimeout(timeout) - try: - while 1: - if maxsize is not None and len(recvd) >= maxsize: - raise NotFound(marker, len(recvd)) # check rbuf attr - if timeout: - cur_timeout = timeout - (time.time() - start) - if cur_timeout <= 0.0: - raise socket.timeout() - sock.settimeout(cur_timeout) - nxt = sock.recv(maxsize) - if not nxt: - msg = ('connection closed after reading %s bytes without' - ' finding symbol: %r' % (len(recvd), marker)) - raise ConnectionClosed(msg) # check the rbuf attr for more - recvd.extend(nxt) - offset = recvd.find(marker, -len(nxt) - len(marker)) - if offset >= 0: - offset += len(marker) # include marker in the return - break - except socket.timeout: - self.rbuf = bytes(recvd) - msg = ('read %s bytes without finding marker: %r' - % (len(recvd), marker)) - raise Timeout(timeout, msg) # check the rbuf attr for more - except Exception: - self.rbuf = bytes(recvd) - raise - val, self.rbuf = bytes(recvd[:offset]), bytes(recvd[offset:]) + with self.recv_lock: + if maxsize is _UNSET: + maxsize = self.maxsize + if timeout is _UNSET: + timeout = self.timeout + recvd = bytearray(self.rbuf) + start = time.time() + sock = self.sock + if not timeout: # covers None (no timeout) and 0 (nonblocking) + sock.settimeout(timeout) + try: + while 1: + if maxsize is not None and len(recvd) >= maxsize: + raise MessageTooLong(len(recvd), marker) # check rbuf + if timeout: + cur_timeout = timeout - (time.time() - start) + if cur_timeout <= 0.0: + raise socket.timeout() + sock.settimeout(cur_timeout) + nxt = sock.recv(maxsize) + if not nxt: + args = (len(recvd), marker) + msg = ('connection closed after reading %s bytes' + ' without finding symbol: %r' % args) + raise ConnectionClosed(msg) # check the recv buffer + recvd.extend(nxt) + offset = recvd.find(marker, -len(nxt) - len(marker)) + if offset >= 0: + offset += len(marker) # include marker in the return + break + except socket.timeout: + self.rbuf = bytes(recvd) + msg = ('read %s bytes without finding marker: %r' + % (len(recvd), marker)) + raise Timeout(timeout, msg) # check the recv buffer + except Exception: + self.rbuf = bytes(recvd) + raise + val, self.rbuf = bytes(recvd[:offset]), bytes(recvd[offset:]) return val def recv_size(self, size, timeout=_UNSET): @@ -218,43 +272,44 @@ class BufferedSocket(object): :exc:`Timeout` will be raised. If the connection is closed, a :exc:`ConnectionClosed` will be raised. """ - if timeout is _UNSET: - timeout = self.timeout - chunks = [] - total_bytes = 0 - try: - start = time.time() - self.sock.settimeout(timeout) - nxt = self.rbuf or self.sock.recv(size) - while nxt: - total_bytes += len(nxt) - if total_bytes >= size: - break - chunks.append(nxt) - if timeout: - cur_timeout = timeout - (time.time() - start) - if cur_timeout <= 0.0: - raise socket.timeout() - self.sock.settimeout(cur_timeout) - nxt = self.sock.recv(size - total_bytes) + with self.recv_lock: + if timeout is _UNSET: + timeout = self.timeout + chunks = [] + total_bytes = 0 + try: + start = time.time() + self.sock.settimeout(timeout) + nxt = self.rbuf or self.sock.recv(size) + while nxt: + total_bytes += len(nxt) + if total_bytes >= size: + break + chunks.append(nxt) + if timeout: + cur_timeout = timeout - (time.time() - start) + if cur_timeout <= 0.0: + raise socket.timeout() + self.sock.settimeout(cur_timeout) + nxt = self.sock.recv(size - total_bytes) + else: + msg = ('connection closed after reading %s of %s requested' + ' bytes' % (total_bytes, size)) + raise ConnectionClosed(msg) # check recv buffer + except socket.timeout: + self.rbuf = b''.join(chunks) + msg = 'read %s of %s bytes' % (total_bytes, size) + raise Timeout(timeout, msg) # check recv buffer + except Exception: + # received data is still buffered in the case of errors + self.rbuf = b''.join(chunks) + raise + extra_bytes = total_bytes - size + if extra_bytes: + last, self.rbuf = nxt[:-extra_bytes], nxt[-extra_bytes:] else: - msg = ('connection closed after reading %s of %s requested' - ' bytes' % (total_bytes, size)) - raise ConnectionClosed(msg) # check rbuf attribute for more - except socket.timeout: - self.rbuf = b''.join(chunks) - msg = 'read %s of %s bytes' % (total_bytes, size) - raise Timeout(timeout, msg) # check rbuf attribute for more - except Exception: - # received data is still buffered in the case of errors - self.rbuf = b''.join(chunks) - raise - extra_bytes = total_bytes - size - if extra_bytes: - last, self.rbuf = nxt[:-extra_bytes], nxt[-extra_bytes:] - else: - last, self.rbuf = nxt, b'' - chunks.append(last) + last, self.rbuf = nxt, b'' + chunks.append(last) return b''.join(chunks) def send(self, data, flags=0, timeout=_UNSET): @@ -273,37 +328,43 @@ class BufferedSocket(object): complete before *timeout*. """ - if timeout is _UNSET: - timeout = self.timeout - if flags: - raise ValueError("non-zero flags not supported") - sbuf = self.sbuf - sbuf.append(data) - if len(sbuf) > 1: - sbuf[:] = [b''.join(sbuf)] - self.sock.settimeout(timeout) - start = time.time() - try: - while sbuf[0]: - sent = self.sock.send(sbuf[0]) - sbuf[0] = sbuf[0][sent:] - if timeout: - cur_timeout = timeout - (time.time() - start) - if cur_timeout <= 0.0: - raise socket.timeout() - self.sock.settimeout(cur_timeout) - except socket.timeout: - raise Timeout(timeout, '%s bytes unsent' % len(sbuf[0])) + with self.send_lock: + if timeout is _UNSET: + timeout = self.timeout + if flags: + raise ValueError("non-zero flags not supported") + sbuf = self.sbuf + sbuf.append(data) + if len(sbuf) > 1: + sbuf[:] = [b''.join(sbuf)] + self.sock.settimeout(timeout) + start = time.time() + try: + while sbuf[0]: + sent = self.sock.send(sbuf[0]) + sbuf[0] = sbuf[0][sent:] + if timeout: + cur_timeout = timeout - (time.time() - start) + if cur_timeout <= 0.0: + raise socket.timeout() + self.sock.settimeout(cur_timeout) + except socket.timeout: + raise Timeout(timeout, '%s bytes unsent' % len(sbuf[0])) + return sendall = send def flush(self): "Send the contents of the internal send buffer." - self.send(b'') + with self.send_lock: + self.send(b'') + return def buffer(self, data): "Buffer *data* bytes for the next send operation." - self.sbuf.append(data) + with self.send_lock: + self.sbuf.append(data) + return class Error(socket.error, Exception): @@ -314,6 +375,19 @@ class ConnectionClosed(Error): pass +class MessageTooLong(Error): + """Only raised from :meth:`BufferedSocket.recv_until` when more than + *maxsize* bytes are read without the socket closing. + """ + def __init__(self, bytes_read=None, marker=None): + msg = 'message exceeded maximum size' + if bytes_read is not None: + msg += '%s bytes read' % (bytes_read,) + if marker is not None: + msg += '. Marker not found: %r' % (marker,) + super(MessageTooLong, self).__init__(msg) + + class Timeout(socket.timeout, Error): def __init__(self, timeout, extra=""): msg = 'socket operation timed out' @@ -324,12 +398,6 @@ class Timeout(socket.timeout, Error): super(Timeout, self).__init__(msg) -class NotFound(Error): - def __init__(self, marker, bytes_read): - msg = 'read %s bytes without finding marker: %r' % (marker, bytes_read) - super(NotFound, self).__init__(msg) - - class NetstringSocket(object): """ Reads and writes using the netstring protocol. diff --git a/tests/test_socketutils.py b/tests/test_socketutils.py index 5007674..f99ecd6 100644 --- a/tests/test_socketutils.py +++ b/tests/test_socketutils.py @@ -6,7 +6,7 @@ import threading from boltons.socketutils import (NetstringSocket, ConnectionClosed, NetstringMessageTooLong, - NotFound, + MessageTooLong, Timeout) @@ -107,9 +107,9 @@ def test_socketutils_netstring(): print("raised MessageTooLong correctly") try: client.bsock.recv_until(b'b', maxsize=4096) - raise Exception('recv_until did not raise NotFound') - except NotFound: - print("raised NotFound correctly") + raise Exception('recv_until did not raise MessageTooLong') + except MessageTooLong: + print("raised MessageTooLong correctly") assert client.bsock.recv_size(4097) == b'a' * 4096 + b',' print('correctly maintained buffer after exception raised')