Review note: this patch was restored from a newline-collapsed copy, so the
whitespace of context lines is reconstructed and may need
`git apply --ignore-whitespace`. Two added lines differ from the original patch:
NetstringInvalidSize now passes its own class to super(), and read_ns forwards
the resolved per-call timeout to recv_until. The post-image blob hash on the
first index line therefore predates these fixes.

diff --git a/boltons/socketutils.py b/boltons/socketutils.py
index e0803bb..97792aa 100644
--- a/boltons/socketutils.py
+++ b/boltons/socketutils.py
@@ -1,8 +1,5 @@
 # -*- coding: utf-8 -*-
 
-# TODO: test the settimeout(0) support on BufferedSocket (should work)
-# TODO: maybe add settimeout(0) support on the netstring socket
-
 import time
 import socket
 
@@ -15,7 +12,7 @@ except ImportError:
 
 
 DEFAULT_TIMEOUT = 10  # 10 seconds
-DEFAULT_MAXBYTES = 32 * 1024  # 32kb
+DEFAULT_MAXSIZE = 32 * 1024  # 32kb
 
 
 class BufferedSocket(object):
@@ -32,13 +29,13 @@ class BufferedSocket(object):
     """
     # TODO: recv_close()  # receive until socket closed
     def __init__(self, sock,
-                 timeout=DEFAULT_TIMEOUT, maxbytes=DEFAULT_MAXBYTES):
+                 timeout=DEFAULT_TIMEOUT, maxsize=DEFAULT_MAXSIZE):
         self.sock = sock
         self.sock.settimeout(None)
         self.rbuf = b''
         self.sbuf = []
         self.timeout = timeout
-        self.maxbytes = maxbytes
+        self.maxsize = maxsize
 
     def fileno(self):
         return self.sock.fileno()
@@ -46,8 +43,8 @@ class BufferedSocket(object):
     def settimeout(self, timeout):
         self.timeout = timeout
 
-    def setmaxbytes(self, maxbytes):
-        self.maxbytes = maxbytes
+    def setmaxsize(self, maxsize):
+        self.maxsize = maxsize
 
     def recv(self, size, flags=0, timeout=_UNSET):
         if timeout is _UNSET:
@@ -67,17 +64,17 @@ class BufferedSocket(object):
         return data
 
     def peek(self, n, timeout=_UNSET):
-        'peek n bytes from the socket, but keep them in the buffer'
+        'peek n bytes from the socket and keep them in the buffer'
         if len(self.rbuf) >= n:
             return self.rbuf[:n]
         data = self.recv_size(n, timeout=timeout)
         self.rbuf = data + self.rbuf
         return data
 
-    def recv_until(self, marker, timeout=_UNSET, maxbytes=_UNSET):
+    def recv_until(self, marker, timeout=_UNSET, maxsize=_UNSET):
         'read off of socket until the marker is found'
-        if maxbytes is _UNSET:
-            maxbytes = self.maxbytes
+        if maxsize is _UNSET:
+            maxsize = self.maxsize
         if timeout is _UNSET:
             timeout = self.timeout
         recvd = bytearray(self.rbuf)
@@ -87,14 +84,14 @@ class BufferedSocket(object):
             sock.settimeout(timeout)
         try:
             while 1:
-                if maxbytes is not None and len(recvd) >= maxbytes:
+                if maxsize is not None and len(recvd) >= maxsize:
                     raise NotFound(marker, len(recvd))  # check rbuf attr
                 if timeout:
                     cur_timeout = timeout - (time.time() - start)
                     if cur_timeout <= 0.0:
                         raise socket.timeout()
                     sock.settimeout(cur_timeout)
-                nxt = sock.recv(maxbytes)
+                nxt = sock.recv(maxsize)
                 if not nxt:
                     msg = ('connection closed after reading %s bytes without'
                            ' finding symbol: %r' % (len(recvd), marker))
@@ -145,7 +142,7 @@ class BufferedSocket(object):
             msg = 'read %s of %s bytes' % (total_bytes, size)
             raise Timeout(timeout, msg)  # check rbuf attribute for more
         except Exception:
-            # data is always retained, regardless of errors
+            # received data is still buffered in the case of errors
             self.rbuf = b''.join(chunks)
             raise
         extra_bytes = total_bytes - size
@@ -215,31 +212,51 @@ class NotFound(Error):
 class NetstringSocket(object):
     """
     Reads and writes using the netstring protocol.
+
+    More info: https://en.wikipedia.org/wiki/Netstring
+    Even more info: http://cr.yp.to/proto/netstrings.txt
     """
     def __init__(self, sock, timeout=30, maxsize=32 * 1024):
-        self.maxlensize = len(str(maxsize)) + 1  # len(str()) == log10
+        self.bsock = BufferedSocket(sock)
         self.timeout = timeout
         self.maxsize = maxsize
-        self.bsock = BufferedSocket(sock)
+        self._msgsize_maxsize = len(str(maxsize)) + 1  # len(str()) == log10
+
+    def fileno(self):
+        return self.bsock.fileno()
 
     def settimeout(self, timeout):
         self.timeout = timeout
 
     def setmaxsize(self, maxsize):
         self.maxsize = maxsize
-        self.maxlensize = len(str(maxsize)) + 1  # len(str()) == log10
+        self._msgsize_maxsize = len(str(maxsize)) + 1  # len(str()) == log10
 
     def read_ns(self, timeout=_UNSET, maxsize=_UNSET):
         if timeout is _UNSET:
             timeout = self.timeout
-        # start = time.time()
-        size_pref = self.bsock.recv_until(b':', self.timeout, self.maxlensize)
-        size = int(size_pref[:-1])  # netstrings must start with "size:"
+
+        size_prefix = self.bsock.recv_until(b':',
+                                            timeout=timeout,
+                                            maxsize=self._msgsize_maxsize)
+
+        size_bytes, sep, _ = size_prefix.partition(b':')
+
+        if not sep:
+            raise NetstringInvalidSize('netstring messages must start with'
+                                       ' "size:", not %r' % size_prefix)
+        try:
+            size = int(size_bytes)
+        except ValueError:
+            raise NetstringInvalidSize('netstring message size must be valid'
+                                       ' integer, not %r' % size_bytes)
+
         if size > self.maxsize:
             raise NetstringMessageTooLong(size, self.maxsize)
         payload = self.bsock.recv_size(size)
         if self.bsock.recv(1) != b',':
             raise NetstringProtocolError("expected trailing ',' after message")
+
         return payload
 
     def write_ns(self, payload):
@@ -254,6 +271,11 @@ class NetstringProtocolError(Error):
     pass
 
 
+class NetstringInvalidSize(NetstringProtocolError):
+    def __init__(self, msg):
+        super(NetstringInvalidSize, self).__init__(msg)
+
+
 class NetstringMessageTooLong(NetstringProtocolError):
     def __init__(self, size, maxsize):
         msg = ('netstring message length exceeds configured maxsize: %s > %s'
diff --git a/tests/test_socketutils.py b/tests/test_socketutils.py
index f9d867d..5007674 100644
--- a/tests/test_socketutils.py
+++ b/tests/test_socketutils.py
@@ -106,7 +106,7 @@ def test_socketutils_netstring():
     except NetstringMessageTooLong:
         print("raised MessageTooLong correctly")
     try:
-        client.bsock.recv_until(b'b', maxbytes=4096)
+        client.bsock.recv_until(b'b', maxsize=4096)
         raise Exception('recv_until did not raise NotFound')
     except NotFound:
         print("raised NotFound correctly")