Add tcp.Reader.safe_read, use it in socks and websockets
safe_read is guaranteed to raise or return a byte string of the requested length. It's particularly useful for implementing binary protocols.
This commit is contained in:
parent
08b2e2a6a9
commit
f2bc58cdd2
|
@ -52,20 +52,6 @@ METHOD = utils.BiDi(
|
|||
)
|
||||
|
||||
|
||||
def _read(f, n):
|
||||
try:
|
||||
d = f.read(n)
|
||||
if len(d) == n:
|
||||
return d
|
||||
else:
|
||||
raise SocksError(
|
||||
REP.GENERAL_SOCKS_SERVER_FAILURE,
|
||||
"Incomplete Read"
|
||||
)
|
||||
except socket.error as e:
|
||||
raise SocksError(REP.GENERAL_SOCKS_SERVER_FAILURE, str(e))
|
||||
|
||||
|
||||
class ClientGreeting(object):
|
||||
__slots__ = ("ver", "methods")
|
||||
|
||||
|
@ -75,9 +61,9 @@ class ClientGreeting(object):
|
|||
|
||||
@classmethod
|
||||
def from_file(cls, f):
|
||||
ver, nmethods = struct.unpack("!BB", _read(f, 2))
|
||||
ver, nmethods = struct.unpack("!BB", f.safe_read(2))
|
||||
methods = array.array("B")
|
||||
methods.fromstring(_read(f, nmethods))
|
||||
methods.fromstring(f.safe_read(nmethods))
|
||||
return cls(ver, methods)
|
||||
|
||||
def to_file(self, f):
|
||||
|
@ -94,7 +80,7 @@ class ServerGreeting(object):
|
|||
|
||||
@classmethod
|
||||
def from_file(cls, f):
|
||||
ver, method = struct.unpack("!BB", _read(f, 2))
|
||||
ver, method = struct.unpack("!BB", f.safe_read(2))
|
||||
return cls(ver, method)
|
||||
|
||||
def to_file(self, f):
|
||||
|
@ -112,27 +98,27 @@ class Message(object):
|
|||
|
||||
@classmethod
|
||||
def from_file(cls, f):
|
||||
ver, msg, rsv, atyp = struct.unpack("!BBBB", _read(f, 4))
|
||||
ver, msg, rsv, atyp = struct.unpack("!BBBB", f.safe_read(4))
|
||||
if rsv != 0x00:
|
||||
raise SocksError(REP.GENERAL_SOCKS_SERVER_FAILURE,
|
||||
"Socks Request: Invalid reserved byte: %s" % rsv)
|
||||
|
||||
if atyp == ATYP.IPV4_ADDRESS:
|
||||
# We use tnoa here as ntop is not commonly available on Windows.
|
||||
host = socket.inet_ntoa(_read(f, 4))
|
||||
host = socket.inet_ntoa(f.safe_read(4))
|
||||
use_ipv6 = False
|
||||
elif atyp == ATYP.IPV6_ADDRESS:
|
||||
host = socket.inet_ntop(socket.AF_INET6, _read(f, 16))
|
||||
host = socket.inet_ntop(socket.AF_INET6, f.safe_read(16))
|
||||
use_ipv6 = True
|
||||
elif atyp == ATYP.DOMAINNAME:
|
||||
length, = struct.unpack("!B", _read(f, 1))
|
||||
host = _read(f, length)
|
||||
length, = struct.unpack("!B", f.safe_read(1))
|
||||
host = f.safe_read(length)
|
||||
use_ipv6 = False
|
||||
else:
|
||||
raise SocksError(REP.ADDRESS_TYPE_NOT_SUPPORTED,
|
||||
"Socks Request: Unknown ATYP: %s" % atyp)
|
||||
|
||||
port, = struct.unpack("!H", _read(f, 2))
|
||||
port, = struct.unpack("!H", f.safe_read(2))
|
||||
addr = tcp.Address((host, port), use_ipv6=use_ipv6)
|
||||
return cls(ver, msg, atyp, addr)
|
||||
|
||||
|
|
|
@ -24,6 +24,7 @@ OP_NO_SSLv3 = SSL.OP_NO_SSLv3
|
|||
|
||||
class NetLibError(Exception): pass
|
||||
class NetLibDisconnect(NetLibError): pass
|
||||
class NetLibIncomplete(NetLibError): pass
|
||||
class NetLibTimeout(NetLibError): pass
|
||||
class NetLibSSLError(NetLibError): pass
|
||||
|
||||
|
@ -195,10 +196,23 @@ class Reader(_FileLike):
|
|||
break
|
||||
return result
|
||||
|
||||
def safe_read(self, length):
|
||||
"""
|
||||
Like .read, but is guaranteed to either return length bytes, or
|
||||
raise an exception.
|
||||
"""
|
||||
result = self.read(length)
|
||||
if length != -1 and len(result) != length:
|
||||
raise NetLibIncomplete(
|
||||
"Expected %s bytes, got %s"%(length, len(result))
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class Address(object):
|
||||
"""
|
||||
This class wraps an IPv4/IPv6 tuple to provide named attributes and ipv6 information.
|
||||
This class wraps an IPv4/IPv6 tuple to provide named attributes and
|
||||
ipv6 information.
|
||||
"""
|
||||
def __init__(self, address, use_ipv6=False):
|
||||
self.address = tuple(address)
|
||||
|
@ -247,22 +261,28 @@ def close_socket(sock):
|
|||
"""
|
||||
try:
|
||||
# We already indicate that we close our end.
|
||||
sock.shutdown(socket.SHUT_WR) # may raise "Transport endpoint is not connected" on Linux
|
||||
# may raise "Transport endpoint is not connected" on Linux
|
||||
sock.shutdown(socket.SHUT_WR)
|
||||
|
||||
# Section 4.2.2.13 of RFC 1122 tells us that a close() with any
|
||||
# pending readable data could lead to an immediate RST being sent (which is the case on Windows).
|
||||
# Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending
|
||||
# readable data could lead to an immediate RST being sent (which is the
|
||||
# case on Windows).
|
||||
# http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html
|
||||
#
|
||||
# This in turn results in the following issue: If we send an error page to the client and then close the socket,
|
||||
# the RST may be received by the client before the error page and the users sees a connection error rather than
|
||||
# the error page. Thus, we try to empty the read buffer on Windows first.
|
||||
# (see https://github.com/mitmproxy/mitmproxy/issues/527#issuecomment-93782988)
|
||||
# This in turn results in the following issue: If we send an error page
|
||||
# to the client and then close the socket, the RST may be received by
|
||||
# the client before the error page and the users sees a connection
|
||||
# error rather than the error page. Thus, we try to empty the read
|
||||
# buffer on Windows first. (see
|
||||
# https://github.com/mitmproxy/mitmproxy/issues/527#issuecomment-93782988)
|
||||
#
|
||||
|
||||
if os.name == "nt": # pragma: no cover
|
||||
# We cannot rely on the shutdown()-followed-by-read()-eof technique proposed by the page above:
|
||||
# Some remote machines just don't send a TCP FIN, which would leave us in the unfortunate situation that
|
||||
# recv() would block infinitely.
|
||||
# As a workaround, we set a timeout here even if we are in blocking mode.
|
||||
# We cannot rely on the shutdown()-followed-by-read()-eof technique
|
||||
# proposed by the page above: Some remote machines just don't send
|
||||
# a TCP FIN, which would leave us in the unfortunate situation that
|
||||
# recv() would block infinitely. As a workaround, we set a timeout
|
||||
# here even if we are in blocking mode.
|
||||
sock.settimeout(sock.gettimeout() or 20)
|
||||
|
||||
# limit at a megabyte so that we don't read infinitely
|
||||
|
@ -292,10 +312,10 @@ class _Connection(object):
|
|||
|
||||
def finish(self):
|
||||
self.finished = True
|
||||
|
||||
# If we have an SSL connection, wfile.close == connection.close
|
||||
# (We call _FileLike.set_descriptor(conn))
|
||||
# Closing the socket is not our task, therefore we don't call close then.
|
||||
# Closing the socket is not our task, therefore we don't call close
|
||||
# then.
|
||||
if type(self.connection) != SSL.Connection:
|
||||
if not getattr(self.wfile, "closed", False):
|
||||
try:
|
||||
|
|
|
@ -5,7 +5,7 @@ import os
|
|||
import struct
|
||||
import io
|
||||
|
||||
from . import utils, odict
|
||||
from . import utils, odict, tcp
|
||||
|
||||
# Colleciton of utility functions that implement small portions of the RFC6455
|
||||
# WebSockets Protocol Useful for building WebSocket clients and servers.
|
||||
|
@ -217,8 +217,8 @@ class FrameHeader:
|
|||
"""
|
||||
read a websockets frame header
|
||||
"""
|
||||
first_byte = utils.bytes_to_int(fp.read(1))
|
||||
second_byte = utils.bytes_to_int(fp.read(1))
|
||||
first_byte = utils.bytes_to_int(fp.safe_read(1))
|
||||
second_byte = utils.bytes_to_int(fp.safe_read(1))
|
||||
|
||||
fin = utils.getbit(first_byte, 7)
|
||||
rsv1 = utils.getbit(first_byte, 6)
|
||||
|
@ -235,13 +235,13 @@ class FrameHeader:
|
|||
if length_code <= 125:
|
||||
payload_length = length_code
|
||||
elif length_code == 126:
|
||||
payload_length = utils.bytes_to_int(fp.read(2))
|
||||
payload_length = utils.bytes_to_int(fp.safe_read(2))
|
||||
elif length_code == 127:
|
||||
payload_length = utils.bytes_to_int(fp.read(8))
|
||||
payload_length = utils.bytes_to_int(fp.safe_read(8))
|
||||
|
||||
# masking key only present if mask bit set
|
||||
if mask_bit == 1:
|
||||
masking_key = fp.read(4)
|
||||
masking_key = fp.safe_read(4)
|
||||
else:
|
||||
masking_key = None
|
||||
|
||||
|
@ -319,7 +319,7 @@ class Frame(object):
|
|||
Construct a websocket frame from an in-memory bytestring
|
||||
to construct a frame from a stream of bytes, use from_file() directly
|
||||
"""
|
||||
return cls.from_file(io.BytesIO(bytestring))
|
||||
return cls.from_file(tcp.Reader(io.BytesIO(bytestring)))
|
||||
|
||||
def human_readable(self):
|
||||
hdr = self.header.human_readable()
|
||||
|
@ -351,7 +351,7 @@ class Frame(object):
|
|||
stream or a disk or an in memory stream reader
|
||||
"""
|
||||
header = FrameHeader.from_file(fp)
|
||||
payload = fp.read(header.payload_length)
|
||||
payload = fp.safe_read(header.payload_length)
|
||||
|
||||
if header.mask == 1 and header.masking_key:
|
||||
payload = Masker(header.masking_key)(payload)
|
||||
|
|
|
@ -91,7 +91,7 @@ def test_read_http_body_request():
|
|||
|
||||
def test_read_http_body_response():
|
||||
h = odict.ODictCaseless()
|
||||
s = cStringIO.StringIO("testing")
|
||||
s = tcp.Reader(cStringIO.StringIO("testing"))
|
||||
assert http.read_http_body(s, h, None, "GET", 200, False) == "testing"
|
||||
|
||||
|
||||
|
@ -135,11 +135,11 @@ def test_read_http_body():
|
|||
|
||||
# test no content length: limit > actual content
|
||||
h = odict.ODictCaseless()
|
||||
s = cStringIO.StringIO("testing")
|
||||
s = tcp.Reader(cStringIO.StringIO("testing"))
|
||||
assert len(http.read_http_body(s, h, 100, "GET", 200, False)) == 7
|
||||
|
||||
# test no content length: limit < actual content
|
||||
s = cStringIO.StringIO("testing")
|
||||
s = tcp.Reader(cStringIO.StringIO("testing"))
|
||||
tutils.raises(
|
||||
http.HttpError,
|
||||
http.read_http_body,
|
||||
|
@ -149,7 +149,7 @@ def test_read_http_body():
|
|||
# test chunked
|
||||
h = odict.ODictCaseless()
|
||||
h["transfer-encoding"] = ["chunked"]
|
||||
s = cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n")
|
||||
s = tcp.Reader(cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n"))
|
||||
assert http.read_http_body(s, h, 100, "GET", 200, False) == "aaaaa"
|
||||
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ import tutils
|
|||
|
||||
|
||||
def test_client_greeting():
|
||||
raw = StringIO("\x05\x02\x00\xBE\xEF")
|
||||
raw = tutils.treader("\x05\x02\x00\xBE\xEF")
|
||||
out = StringIO()
|
||||
msg = socks.ClientGreeting.from_file(raw)
|
||||
msg.to_file(out)
|
||||
|
@ -20,7 +20,7 @@ def test_client_greeting():
|
|||
|
||||
|
||||
def test_server_greeting():
|
||||
raw = StringIO("\x05\x02")
|
||||
raw = tutils.treader("\x05\x02")
|
||||
out = StringIO()
|
||||
msg = socks.ServerGreeting.from_file(raw)
|
||||
msg.to_file(out)
|
||||
|
@ -31,7 +31,7 @@ def test_server_greeting():
|
|||
|
||||
|
||||
def test_message():
|
||||
raw = StringIO("\x05\x01\x00\x03\x0bexample.com\xDE\xAD\xBE\xEF")
|
||||
raw = tutils.treader("\x05\x01\x00\x03\x0bexample.com\xDE\xAD\xBE\xEF")
|
||||
out = StringIO()
|
||||
msg = socks.Message.from_file(raw)
|
||||
assert raw.read(2) == "\xBE\xEF"
|
||||
|
@ -46,7 +46,7 @@ def test_message():
|
|||
|
||||
def test_message_ipv4():
|
||||
# Test ATYP=0x01 (IPV4)
|
||||
raw = StringIO("\x05\x01\x00\x01\x7f\x00\x00\x01\xDE\xAD\xBE\xEF")
|
||||
raw = tutils.treader("\x05\x01\x00\x01\x7f\x00\x00\x01\xDE\xAD\xBE\xEF")
|
||||
out = StringIO()
|
||||
msg = socks.Message.from_file(raw)
|
||||
assert raw.read(2) == "\xBE\xEF"
|
||||
|
@ -62,7 +62,7 @@ def test_message_ipv6():
|
|||
# Test ATYP=0x04 (IPV6)
|
||||
ipv6_addr = "2001:db8:85a3:8d3:1319:8a2e:370:7344"
|
||||
|
||||
raw = StringIO("\x05\x01\x00\x04" + socket.inet_pton(socket.AF_INET6, ipv6_addr) + "\xDE\xAD\xBE\xEF")
|
||||
raw = tutils.treader("\x05\x01\x00\x04" + socket.inet_pton(socket.AF_INET6, ipv6_addr) + "\xDE\xAD\xBE\xEF")
|
||||
out = StringIO()
|
||||
msg = socks.Message.from_file(raw)
|
||||
assert raw.read(2) == "\xBE\xEF"
|
||||
|
@ -73,12 +73,12 @@ def test_message_ipv6():
|
|||
|
||||
|
||||
def test_message_invalid_rsv():
|
||||
raw = StringIO("\x05\x01\xFF\x01\x7f\x00\x00\x01\xDE\xAD\xBE\xEF")
|
||||
raw = tutils.treader("\x05\x01\xFF\x01\x7f\x00\x00\x01\xDE\xAD\xBE\xEF")
|
||||
tutils.raises(socks.SocksError, socks.Message.from_file, raw)
|
||||
|
||||
|
||||
def test_message_unknown_atyp():
|
||||
raw = StringIO("\x05\x02\x00\x02\x7f\x00\x00\x01\xDE\xAD\xBE\xEF")
|
||||
raw = tutils.treader("\x05\x02\x00\x02\x7f\x00\x00\x01\xDE\xAD\xBE\xEF")
|
||||
tutils.raises(socks.SocksError, socks.Message.from_file, raw)
|
||||
|
||||
m = socks.Message(5, 1, 0x02, tcp.Address(("example.com", 5050)))
|
||||
|
@ -93,4 +93,4 @@ def test_read():
|
|||
|
||||
cs = mock.Mock()
|
||||
cs.read = mock.Mock(side_effect=socket.error)
|
||||
tutils.raises(socks.SocksError, socks._read, cs, 4)
|
||||
tutils.raises(socks.SocksError, socks._read, cs, 4)
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
import cStringIO
|
||||
import os
|
||||
|
||||
from nose.tools import raises
|
||||
|
@ -170,7 +169,7 @@ class TestFrameHeader:
|
|||
def round(*args, **kwargs):
|
||||
f = websockets.FrameHeader(*args, **kwargs)
|
||||
bytes = f.to_bytes()
|
||||
f2 = websockets.FrameHeader.from_file(cStringIO.StringIO(bytes))
|
||||
f2 = websockets.FrameHeader.from_file(tutils.treader(bytes))
|
||||
assert f == f2
|
||||
round()
|
||||
round(fin=1)
|
||||
|
@ -197,7 +196,7 @@ class TestFrameHeader:
|
|||
def test_funky(self):
|
||||
f = websockets.FrameHeader(masking_key="test", mask=False)
|
||||
bytes = f.to_bytes()
|
||||
f2 = websockets.FrameHeader.from_file(cStringIO.StringIO(bytes))
|
||||
f2 = websockets.FrameHeader.from_file(tutils.treader(bytes))
|
||||
assert not f2.mask
|
||||
|
||||
def test_violations(self):
|
||||
|
@ -221,7 +220,7 @@ class TestFrame:
|
|||
def round(*args, **kwargs):
|
||||
f = websockets.Frame(*args, **kwargs)
|
||||
bytes = f.to_bytes()
|
||||
f2 = websockets.Frame.from_file(cStringIO.StringIO(bytes))
|
||||
f2 = websockets.Frame.from_file(tutils.treader(bytes))
|
||||
assert f == f2
|
||||
round("test")
|
||||
round("test", fin=1)
|
||||
|
|
|
@ -1,7 +1,20 @@
|
|||
import tempfile, os, shutil
|
||||
import cStringIO
|
||||
import tempfile
|
||||
import os
|
||||
import shutil
|
||||
from contextlib import contextmanager
|
||||
from libpathod import utils
|
||||
|
||||
from netlib import tcp
|
||||
|
||||
|
||||
def treader(bytes):
|
||||
"""
|
||||
Construct a tcp.Read object from bytes.
|
||||
"""
|
||||
fp = cStringIO.StringIO(bytes)
|
||||
return tcp.Reader(fp)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def tmpdir(*args, **kwargs):
|
||||
|
|
Loading…
Reference in New Issue