Add tcp.Reader.safe_read, use it in socks and websockets

safe_read is guaranteed to raise or return a byte string of the
requested length. It's particularly useful for implementing binary
protocols.
This commit is contained in:
Aldo Cortesi 2015-05-05 10:47:02 +12:00
parent 08b2e2a6a9
commit f2bc58cdd2
7 changed files with 80 additions and 62 deletions

View File

@ -52,20 +52,6 @@ METHOD = utils.BiDi(
)
def _read(f, n):
try:
d = f.read(n)
if len(d) == n:
return d
else:
raise SocksError(
REP.GENERAL_SOCKS_SERVER_FAILURE,
"Incomplete Read"
)
except socket.error as e:
raise SocksError(REP.GENERAL_SOCKS_SERVER_FAILURE, str(e))
class ClientGreeting(object):
__slots__ = ("ver", "methods")
@ -75,9 +61,9 @@ class ClientGreeting(object):
@classmethod
def from_file(cls, f):
ver, nmethods = struct.unpack("!BB", _read(f, 2))
ver, nmethods = struct.unpack("!BB", f.safe_read(2))
methods = array.array("B")
methods.fromstring(_read(f, nmethods))
methods.fromstring(f.safe_read(nmethods))
return cls(ver, methods)
def to_file(self, f):
@ -94,7 +80,7 @@ class ServerGreeting(object):
@classmethod
def from_file(cls, f):
ver, method = struct.unpack("!BB", _read(f, 2))
ver, method = struct.unpack("!BB", f.safe_read(2))
return cls(ver, method)
def to_file(self, f):
@ -112,27 +98,27 @@ class Message(object):
@classmethod
def from_file(cls, f):
ver, msg, rsv, atyp = struct.unpack("!BBBB", _read(f, 4))
ver, msg, rsv, atyp = struct.unpack("!BBBB", f.safe_read(4))
if rsv != 0x00:
raise SocksError(REP.GENERAL_SOCKS_SERVER_FAILURE,
"Socks Request: Invalid reserved byte: %s" % rsv)
if atyp == ATYP.IPV4_ADDRESS:
# We use tnoa here as ntop is not commonly available on Windows.
host = socket.inet_ntoa(_read(f, 4))
host = socket.inet_ntoa(f.safe_read(4))
use_ipv6 = False
elif atyp == ATYP.IPV6_ADDRESS:
host = socket.inet_ntop(socket.AF_INET6, _read(f, 16))
host = socket.inet_ntop(socket.AF_INET6, f.safe_read(16))
use_ipv6 = True
elif atyp == ATYP.DOMAINNAME:
length, = struct.unpack("!B", _read(f, 1))
host = _read(f, length)
length, = struct.unpack("!B", f.safe_read(1))
host = f.safe_read(length)
use_ipv6 = False
else:
raise SocksError(REP.ADDRESS_TYPE_NOT_SUPPORTED,
"Socks Request: Unknown ATYP: %s" % atyp)
port, = struct.unpack("!H", _read(f, 2))
port, = struct.unpack("!H", f.safe_read(2))
addr = tcp.Address((host, port), use_ipv6=use_ipv6)
return cls(ver, msg, atyp, addr)

View File

@ -24,6 +24,7 @@ OP_NO_SSLv3 = SSL.OP_NO_SSLv3
class NetLibError(Exception): pass
class NetLibDisconnect(NetLibError): pass
class NetLibIncomplete(NetLibError): pass
class NetLibTimeout(NetLibError): pass
class NetLibSSLError(NetLibError): pass
@ -195,10 +196,23 @@ class Reader(_FileLike):
break
return result
def safe_read(self, length):
"""
Like .read, but is guaranteed to either return length bytes, or
raise an exception.
"""
result = self.read(length)
if length != -1 and len(result) != length:
raise NetLibIncomplete(
"Expected %s bytes, got %s"%(length, len(result))
)
return result
class Address(object):
"""
This class wraps an IPv4/IPv6 tuple to provide named attributes and ipv6 information.
This class wraps an IPv4/IPv6 tuple to provide named attributes and
ipv6 information.
"""
def __init__(self, address, use_ipv6=False):
self.address = tuple(address)
@ -247,22 +261,28 @@ def close_socket(sock):
"""
try:
# We already indicate that we close our end.
sock.shutdown(socket.SHUT_WR) # may raise "Transport endpoint is not connected" on Linux
# may raise "Transport endpoint is not connected" on Linux
sock.shutdown(socket.SHUT_WR)
# Section 4.2.2.13 of RFC 1122 tells us that a close() with any
# pending readable data could lead to an immediate RST being sent (which is the case on Windows).
# Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending
# readable data could lead to an immediate RST being sent (which is the
# case on Windows).
# http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html
#
# This in turn results in the following issue: If we send an error page to the client and then close the socket,
# the RST may be received by the client before the error page and the users sees a connection error rather than
# the error page. Thus, we try to empty the read buffer on Windows first.
# (see https://github.com/mitmproxy/mitmproxy/issues/527#issuecomment-93782988)
# This in turn results in the following issue: If we send an error page
# to the client and then close the socket, the RST may be received by
# the client before the error page and the users sees a connection
# error rather than the error page. Thus, we try to empty the read
# buffer on Windows first. (see
# https://github.com/mitmproxy/mitmproxy/issues/527#issuecomment-93782988)
#
if os.name == "nt": # pragma: no cover
# We cannot rely on the shutdown()-followed-by-read()-eof technique proposed by the page above:
# Some remote machines just don't send a TCP FIN, which would leave us in the unfortunate situation that
# recv() would block infinitely.
# As a workaround, we set a timeout here even if we are in blocking mode.
# We cannot rely on the shutdown()-followed-by-read()-eof technique
# proposed by the page above: Some remote machines just don't send
# a TCP FIN, which would leave us in the unfortunate situation that
# recv() would block infinitely. As a workaround, we set a timeout
# here even if we are in blocking mode.
sock.settimeout(sock.gettimeout() or 20)
# limit at a megabyte so that we don't read infinitely
@ -292,10 +312,10 @@ class _Connection(object):
def finish(self):
self.finished = True
# If we have an SSL connection, wfile.close == connection.close
# (We call _FileLike.set_descriptor(conn))
# Closing the socket is not our task, therefore we don't call close then.
# Closing the socket is not our task, therefore we don't call close
# then.
if type(self.connection) != SSL.Connection:
if not getattr(self.wfile, "closed", False):
try:

View File

@ -5,7 +5,7 @@ import os
import struct
import io
from . import utils, odict
from . import utils, odict, tcp
# Colleciton of utility functions that implement small portions of the RFC6455
# WebSockets Protocol Useful for building WebSocket clients and servers.
@ -217,8 +217,8 @@ class FrameHeader:
"""
read a websockets frame header
"""
first_byte = utils.bytes_to_int(fp.read(1))
second_byte = utils.bytes_to_int(fp.read(1))
first_byte = utils.bytes_to_int(fp.safe_read(1))
second_byte = utils.bytes_to_int(fp.safe_read(1))
fin = utils.getbit(first_byte, 7)
rsv1 = utils.getbit(first_byte, 6)
@ -235,13 +235,13 @@ class FrameHeader:
if length_code <= 125:
payload_length = length_code
elif length_code == 126:
payload_length = utils.bytes_to_int(fp.read(2))
payload_length = utils.bytes_to_int(fp.safe_read(2))
elif length_code == 127:
payload_length = utils.bytes_to_int(fp.read(8))
payload_length = utils.bytes_to_int(fp.safe_read(8))
# masking key only present if mask bit set
if mask_bit == 1:
masking_key = fp.read(4)
masking_key = fp.safe_read(4)
else:
masking_key = None
@ -319,7 +319,7 @@ class Frame(object):
Construct a websocket frame from an in-memory bytestring
to construct a frame from a stream of bytes, use from_file() directly
"""
return cls.from_file(io.BytesIO(bytestring))
return cls.from_file(tcp.Reader(io.BytesIO(bytestring)))
def human_readable(self):
hdr = self.header.human_readable()
@ -351,7 +351,7 @@ class Frame(object):
stream or a disk or an in memory stream reader
"""
header = FrameHeader.from_file(fp)
payload = fp.read(header.payload_length)
payload = fp.safe_read(header.payload_length)
if header.mask == 1 and header.masking_key:
payload = Masker(header.masking_key)(payload)

View File

@ -91,7 +91,7 @@ def test_read_http_body_request():
def test_read_http_body_response():
h = odict.ODictCaseless()
s = cStringIO.StringIO("testing")
s = tcp.Reader(cStringIO.StringIO("testing"))
assert http.read_http_body(s, h, None, "GET", 200, False) == "testing"
@ -135,11 +135,11 @@ def test_read_http_body():
# test no content length: limit > actual content
h = odict.ODictCaseless()
s = cStringIO.StringIO("testing")
s = tcp.Reader(cStringIO.StringIO("testing"))
assert len(http.read_http_body(s, h, 100, "GET", 200, False)) == 7
# test no content length: limit < actual content
s = cStringIO.StringIO("testing")
s = tcp.Reader(cStringIO.StringIO("testing"))
tutils.raises(
http.HttpError,
http.read_http_body,
@ -149,7 +149,7 @@ def test_read_http_body():
# test chunked
h = odict.ODictCaseless()
h["transfer-encoding"] = ["chunked"]
s = cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n")
s = tcp.Reader(cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n"))
assert http.read_http_body(s, h, 100, "GET", 200, False) == "aaaaa"

View File

@ -7,7 +7,7 @@ import tutils
def test_client_greeting():
raw = StringIO("\x05\x02\x00\xBE\xEF")
raw = tutils.treader("\x05\x02\x00\xBE\xEF")
out = StringIO()
msg = socks.ClientGreeting.from_file(raw)
msg.to_file(out)
@ -20,7 +20,7 @@ def test_client_greeting():
def test_server_greeting():
raw = StringIO("\x05\x02")
raw = tutils.treader("\x05\x02")
out = StringIO()
msg = socks.ServerGreeting.from_file(raw)
msg.to_file(out)
@ -31,7 +31,7 @@ def test_server_greeting():
def test_message():
raw = StringIO("\x05\x01\x00\x03\x0bexample.com\xDE\xAD\xBE\xEF")
raw = tutils.treader("\x05\x01\x00\x03\x0bexample.com\xDE\xAD\xBE\xEF")
out = StringIO()
msg = socks.Message.from_file(raw)
assert raw.read(2) == "\xBE\xEF"
@ -46,7 +46,7 @@ def test_message():
def test_message_ipv4():
# Test ATYP=0x01 (IPV4)
raw = StringIO("\x05\x01\x00\x01\x7f\x00\x00\x01\xDE\xAD\xBE\xEF")
raw = tutils.treader("\x05\x01\x00\x01\x7f\x00\x00\x01\xDE\xAD\xBE\xEF")
out = StringIO()
msg = socks.Message.from_file(raw)
assert raw.read(2) == "\xBE\xEF"
@ -62,7 +62,7 @@ def test_message_ipv6():
# Test ATYP=0x04 (IPV6)
ipv6_addr = "2001:db8:85a3:8d3:1319:8a2e:370:7344"
raw = StringIO("\x05\x01\x00\x04" + socket.inet_pton(socket.AF_INET6, ipv6_addr) + "\xDE\xAD\xBE\xEF")
raw = tutils.treader("\x05\x01\x00\x04" + socket.inet_pton(socket.AF_INET6, ipv6_addr) + "\xDE\xAD\xBE\xEF")
out = StringIO()
msg = socks.Message.from_file(raw)
assert raw.read(2) == "\xBE\xEF"
@ -73,12 +73,12 @@ def test_message_ipv6():
def test_message_invalid_rsv():
raw = StringIO("\x05\x01\xFF\x01\x7f\x00\x00\x01\xDE\xAD\xBE\xEF")
raw = tutils.treader("\x05\x01\xFF\x01\x7f\x00\x00\x01\xDE\xAD\xBE\xEF")
tutils.raises(socks.SocksError, socks.Message.from_file, raw)
def test_message_unknown_atyp():
raw = StringIO("\x05\x02\x00\x02\x7f\x00\x00\x01\xDE\xAD\xBE\xEF")
raw = tutils.treader("\x05\x02\x00\x02\x7f\x00\x00\x01\xDE\xAD\xBE\xEF")
tutils.raises(socks.SocksError, socks.Message.from_file, raw)
m = socks.Message(5, 1, 0x02, tcp.Address(("example.com", 5050)))
@ -93,4 +93,4 @@ def test_read():
cs = mock.Mock()
cs.read = mock.Mock(side_effect=socket.error)
tutils.raises(socks.SocksError, socks._read, cs, 4)
tutils.raises(socks.SocksError, socks._read, cs, 4)

View File

@ -1,4 +1,3 @@
import cStringIO
import os
from nose.tools import raises
@ -170,7 +169,7 @@ class TestFrameHeader:
def round(*args, **kwargs):
f = websockets.FrameHeader(*args, **kwargs)
bytes = f.to_bytes()
f2 = websockets.FrameHeader.from_file(cStringIO.StringIO(bytes))
f2 = websockets.FrameHeader.from_file(tutils.treader(bytes))
assert f == f2
round()
round(fin=1)
@ -197,7 +196,7 @@ class TestFrameHeader:
def test_funky(self):
f = websockets.FrameHeader(masking_key="test", mask=False)
bytes = f.to_bytes()
f2 = websockets.FrameHeader.from_file(cStringIO.StringIO(bytes))
f2 = websockets.FrameHeader.from_file(tutils.treader(bytes))
assert not f2.mask
def test_violations(self):
@ -221,7 +220,7 @@ class TestFrame:
def round(*args, **kwargs):
f = websockets.Frame(*args, **kwargs)
bytes = f.to_bytes()
f2 = websockets.Frame.from_file(cStringIO.StringIO(bytes))
f2 = websockets.Frame.from_file(tutils.treader(bytes))
assert f == f2
round("test")
round("test", fin=1)

View File

@ -1,7 +1,20 @@
import tempfile, os, shutil
import cStringIO
import tempfile
import os
import shutil
from contextlib import contextmanager
from libpathod import utils
from netlib import tcp
def treader(bytes):
"""
Construct a tcp.Read object from bytes.
"""
fp = cStringIO.StringIO(bytes)
return tcp.Reader(fp)
@contextmanager
def tmpdir(*args, **kwargs):