more socketutils cleanup, esp renaming maxbytes to maxsize. generally getting it ready for wider release.

This commit is contained in:
Mahmoud Hashemi 2016-04-27 01:51:43 -07:00
parent fbfe3bd4eb
commit e72b1887b4
2 changed files with 44 additions and 22 deletions

View File

@ -1,8 +1,5 @@
# -*- coding: utf-8 -*-
# TODO: test the settimeout(0) support on BufferedSocket (should work)
# TODO: maybe add settimeout(0) support on the netstring socket
import time
import socket
@ -15,7 +12,7 @@ except ImportError:
DEFAULT_TIMEOUT = 10 # 10 seconds
DEFAULT_MAXBYTES = 32 * 1024 # 32kb
DEFAULT_MAXSIZE = 32 * 1024 # 32kb
class BufferedSocket(object):
@ -32,13 +29,13 @@ class BufferedSocket(object):
"""
# TODO: recv_close() # receive until socket closed
def __init__(self, sock,
timeout=DEFAULT_TIMEOUT, maxbytes=DEFAULT_MAXBYTES):
timeout=DEFAULT_TIMEOUT, maxsize=DEFAULT_MAXSIZE):
self.sock = sock
self.sock.settimeout(None)
self.rbuf = b''
self.sbuf = []
self.timeout = timeout
self.maxbytes = maxbytes
self.maxsize = maxsize
def fileno(self):
return self.sock.fileno()
@ -46,8 +43,8 @@ class BufferedSocket(object):
def settimeout(self, timeout):
self.timeout = timeout
def setmaxbytes(self, maxbytes):
self.maxbytes = maxbytes
def setmaxsize(self, maxsize):
self.maxsize = maxsize
def recv(self, size, flags=0, timeout=_UNSET):
if timeout is _UNSET:
@ -67,17 +64,17 @@ class BufferedSocket(object):
return data
def peek(self, n, timeout=_UNSET):
'peek n bytes from the socket, but keep them in the buffer'
'peek n bytes from the socket and keep them in the buffer'
if len(self.rbuf) >= n:
return self.rbuf[:n]
data = self.recv_size(n, timeout=timeout)
self.rbuf = data + self.rbuf
return data
def recv_until(self, marker, timeout=_UNSET, maxbytes=_UNSET):
def recv_until(self, marker, timeout=_UNSET, maxsize=_UNSET):
'read off of socket until the marker is found'
if maxbytes is _UNSET:
maxbytes = self.maxbytes
if maxsize is _UNSET:
maxsize = self.maxsize
if timeout is _UNSET:
timeout = self.timeout
recvd = bytearray(self.rbuf)
@ -87,14 +84,14 @@ class BufferedSocket(object):
sock.settimeout(timeout)
try:
while 1:
if maxbytes is not None and len(recvd) >= maxbytes:
if maxsize is not None and len(recvd) >= maxsize:
raise NotFound(marker, len(recvd)) # check rbuf attr
if timeout:
cur_timeout = timeout - (time.time() - start)
if cur_timeout <= 0.0:
raise socket.timeout()
sock.settimeout(cur_timeout)
nxt = sock.recv(maxbytes)
nxt = sock.recv(maxsize)
if not nxt:
msg = ('connection closed after reading %s bytes without'
' finding symbol: %r' % (len(recvd), marker))
@ -145,7 +142,7 @@ class BufferedSocket(object):
msg = 'read %s of %s bytes' % (total_bytes, size)
raise Timeout(timeout, msg) # check rbuf attribute for more
except Exception:
# data is always retained, regardless of errors
# received data is still buffered in the case of errors
self.rbuf = b''.join(chunks)
raise
extra_bytes = total_bytes - size
@ -215,31 +212,51 @@ class NotFound(Error):
class NetstringSocket(object):
"""
Reads and writes using the netstring protocol.
More info: https://en.wikipedia.org/wiki/Netstring
Even more info: http://cr.yp.to/proto/netstrings.txt
"""
def __init__(self, sock, timeout=30, maxsize=32 * 1024):
self.maxlensize = len(str(maxsize)) + 1 # len(str()) == log10
self.bsock = BufferedSocket(sock)
self.timeout = timeout
self.maxsize = maxsize
self.bsock = BufferedSocket(sock)
self._msgsize_maxsize = len(str(maxsize)) + 1 # len(str()) == log10
def fileno(self):
return self.bsock.fileno()
def settimeout(self, timeout):
self.timeout = timeout
def setmaxsize(self, maxsize):
self.maxsize = maxsize
self.maxlensize = len(str(maxsize)) + 1 # len(str()) == log10
self._msgsize_maxsize = len(str(maxsize)) + 1 # len(str()) == log10
def read_ns(self, timeout=_UNSET, maxsize=_UNSET):
if timeout is _UNSET:
timeout = self.timeout
# start = time.time()
size_pref = self.bsock.recv_until(b':', self.timeout, self.maxlensize)
size = int(size_pref[:-1]) # netstrings must start with "size:"
size_prefix = self.bsock.recv_until(b':',
timeout=self.timeout,
maxsize=self._msgsize_maxsize)
size_bytes, sep, _ = size_prefix.partition(b':')
if not sep:
raise NetstringInvalidSize('netstring messages must start with'
' "size:", not %r' % size_prefix)
try:
size = int(size_bytes)
except ValueError:
raise NetstringInvalidSize('netstring message size must be valid'
' integer, not %r' % size_bytes)
if size > self.maxsize:
raise NetstringMessageTooLong(size, self.maxsize)
payload = self.bsock.recv_size(size)
if self.bsock.recv(1) != b',':
raise NetstringProtocolError("expected trailing ',' after message")
return payload
def write_ns(self, payload):
@ -254,6 +271,11 @@ class NetstringProtocolError(Error):
pass
class NetstringInvalidSize(NetstringProtocolError):
def __init__(self, msg):
super(NetstringMessageTooLong, self).__init__(msg)
class NetstringMessageTooLong(NetstringProtocolError):
def __init__(self, size, maxsize):
msg = ('netstring message length exceeds configured maxsize: %s > %s'

View File

@ -106,7 +106,7 @@ def test_socketutils_netstring():
except NetstringMessageTooLong:
print("raised MessageTooLong correctly")
try:
client.bsock.recv_until(b'b', maxbytes=4096)
client.bsock.recv_until(b'b', maxsize=4096)
raise Exception('recv_until did not raise NotFound')
except NotFound:
print("raised NotFound correctly")