mirror of https://github.com/mahmoud/boltons.git
more socketutils cleanup, esp renaming maxbytes to maxsize. generally getting it ready for wider release.
This commit is contained in:
parent
fbfe3bd4eb
commit
e72b1887b4
|
@ -1,8 +1,5 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
# TODO: test the settimeout(0) support on BufferedSocket (should work)
|
||||
# TODO: maybe add settimeout(0) support on the netstring socket
|
||||
|
||||
import time
|
||||
import socket
|
||||
|
||||
|
@ -15,7 +12,7 @@ except ImportError:
|
|||
|
||||
|
||||
DEFAULT_TIMEOUT = 10 # 10 seconds
|
||||
DEFAULT_MAXBYTES = 32 * 1024 # 32kb
|
||||
DEFAULT_MAXSIZE = 32 * 1024 # 32kb
|
||||
|
||||
|
||||
class BufferedSocket(object):
|
||||
|
@ -32,13 +29,13 @@ class BufferedSocket(object):
|
|||
"""
|
||||
# TODO: recv_close() # receive until socket closed
|
||||
def __init__(self, sock,
|
||||
timeout=DEFAULT_TIMEOUT, maxbytes=DEFAULT_MAXBYTES):
|
||||
timeout=DEFAULT_TIMEOUT, maxsize=DEFAULT_MAXSIZE):
|
||||
self.sock = sock
|
||||
self.sock.settimeout(None)
|
||||
self.rbuf = b''
|
||||
self.sbuf = []
|
||||
self.timeout = timeout
|
||||
self.maxbytes = maxbytes
|
||||
self.maxsize = maxsize
|
||||
|
||||
def fileno(self):
|
||||
return self.sock.fileno()
|
||||
|
@ -46,8 +43,8 @@ class BufferedSocket(object):
|
|||
def settimeout(self, timeout):
|
||||
self.timeout = timeout
|
||||
|
||||
def setmaxbytes(self, maxbytes):
|
||||
self.maxbytes = maxbytes
|
||||
def setmaxsize(self, maxsize):
|
||||
self.maxsize = maxsize
|
||||
|
||||
def recv(self, size, flags=0, timeout=_UNSET):
|
||||
if timeout is _UNSET:
|
||||
|
@ -67,17 +64,17 @@ class BufferedSocket(object):
|
|||
return data
|
||||
|
||||
def peek(self, n, timeout=_UNSET):
|
||||
'peek n bytes from the socket, but keep them in the buffer'
|
||||
'peek n bytes from the socket and keep them in the buffer'
|
||||
if len(self.rbuf) >= n:
|
||||
return self.rbuf[:n]
|
||||
data = self.recv_size(n, timeout=timeout)
|
||||
self.rbuf = data + self.rbuf
|
||||
return data
|
||||
|
||||
def recv_until(self, marker, timeout=_UNSET, maxbytes=_UNSET):
|
||||
def recv_until(self, marker, timeout=_UNSET, maxsize=_UNSET):
|
||||
'read off of socket until the marker is found'
|
||||
if maxbytes is _UNSET:
|
||||
maxbytes = self.maxbytes
|
||||
if maxsize is _UNSET:
|
||||
maxsize = self.maxsize
|
||||
if timeout is _UNSET:
|
||||
timeout = self.timeout
|
||||
recvd = bytearray(self.rbuf)
|
||||
|
@ -87,14 +84,14 @@ class BufferedSocket(object):
|
|||
sock.settimeout(timeout)
|
||||
try:
|
||||
while 1:
|
||||
if maxbytes is not None and len(recvd) >= maxbytes:
|
||||
if maxsize is not None and len(recvd) >= maxsize:
|
||||
raise NotFound(marker, len(recvd)) # check rbuf attr
|
||||
if timeout:
|
||||
cur_timeout = timeout - (time.time() - start)
|
||||
if cur_timeout <= 0.0:
|
||||
raise socket.timeout()
|
||||
sock.settimeout(cur_timeout)
|
||||
nxt = sock.recv(maxbytes)
|
||||
nxt = sock.recv(maxsize)
|
||||
if not nxt:
|
||||
msg = ('connection closed after reading %s bytes without'
|
||||
' finding symbol: %r' % (len(recvd), marker))
|
||||
|
@ -145,7 +142,7 @@ class BufferedSocket(object):
|
|||
msg = 'read %s of %s bytes' % (total_bytes, size)
|
||||
raise Timeout(timeout, msg) # check rbuf attribute for more
|
||||
except Exception:
|
||||
# data is always retained, regardless of errors
|
||||
# received data is still buffered in the case of errors
|
||||
self.rbuf = b''.join(chunks)
|
||||
raise
|
||||
extra_bytes = total_bytes - size
|
||||
|
@ -215,31 +212,51 @@ class NotFound(Error):
|
|||
class NetstringSocket(object):
|
||||
"""
|
||||
Reads and writes using the netstring protocol.
|
||||
|
||||
More info: https://en.wikipedia.org/wiki/Netstring
|
||||
Even more info: http://cr.yp.to/proto/netstrings.txt
|
||||
"""
|
||||
def __init__(self, sock, timeout=30, maxsize=32 * 1024):
|
||||
self.maxlensize = len(str(maxsize)) + 1 # len(str()) == log10
|
||||
self.bsock = BufferedSocket(sock)
|
||||
self.timeout = timeout
|
||||
self.maxsize = maxsize
|
||||
self.bsock = BufferedSocket(sock)
|
||||
self._msgsize_maxsize = len(str(maxsize)) + 1 # len(str()) == log10
|
||||
|
||||
def fileno(self):
|
||||
return self.bsock.fileno()
|
||||
|
||||
def settimeout(self, timeout):
|
||||
self.timeout = timeout
|
||||
|
||||
def setmaxsize(self, maxsize):
|
||||
self.maxsize = maxsize
|
||||
self.maxlensize = len(str(maxsize)) + 1 # len(str()) == log10
|
||||
self._msgsize_maxsize = len(str(maxsize)) + 1 # len(str()) == log10
|
||||
|
||||
def read_ns(self, timeout=_UNSET, maxsize=_UNSET):
|
||||
if timeout is _UNSET:
|
||||
timeout = self.timeout
|
||||
# start = time.time()
|
||||
size_pref = self.bsock.recv_until(b':', self.timeout, self.maxlensize)
|
||||
size = int(size_pref[:-1]) # netstrings must start with "size:"
|
||||
|
||||
size_prefix = self.bsock.recv_until(b':',
|
||||
timeout=self.timeout,
|
||||
maxsize=self._msgsize_maxsize)
|
||||
|
||||
size_bytes, sep, _ = size_prefix.partition(b':')
|
||||
|
||||
if not sep:
|
||||
raise NetstringInvalidSize('netstring messages must start with'
|
||||
' "size:", not %r' % size_prefix)
|
||||
try:
|
||||
size = int(size_bytes)
|
||||
except ValueError:
|
||||
raise NetstringInvalidSize('netstring message size must be valid'
|
||||
' integer, not %r' % size_bytes)
|
||||
|
||||
if size > self.maxsize:
|
||||
raise NetstringMessageTooLong(size, self.maxsize)
|
||||
payload = self.bsock.recv_size(size)
|
||||
if self.bsock.recv(1) != b',':
|
||||
raise NetstringProtocolError("expected trailing ',' after message")
|
||||
|
||||
return payload
|
||||
|
||||
def write_ns(self, payload):
|
||||
|
@ -254,6 +271,11 @@ class NetstringProtocolError(Error):
|
|||
pass
|
||||
|
||||
|
||||
class NetstringInvalidSize(NetstringProtocolError):
|
||||
def __init__(self, msg):
|
||||
super(NetstringMessageTooLong, self).__init__(msg)
|
||||
|
||||
|
||||
class NetstringMessageTooLong(NetstringProtocolError):
|
||||
def __init__(self, size, maxsize):
|
||||
msg = ('netstring message length exceeds configured maxsize: %s > %s'
|
||||
|
|
|
@ -106,7 +106,7 @@ def test_socketutils_netstring():
|
|||
except NetstringMessageTooLong:
|
||||
print("raised MessageTooLong correctly")
|
||||
try:
|
||||
client.bsock.recv_until(b'b', maxbytes=4096)
|
||||
client.bsock.recv_until(b'b', maxsize=4096)
|
||||
raise Exception('recv_until did not raise NotFound')
|
||||
except NotFound:
|
||||
print("raised NotFound correctly")
|
||||
|
|
Loading…
Reference in New Issue