add generic TCP handler with SSL support, move StateObject into netlib

This commit is contained in:
Maximilian Hils 2014-01-30 18:56:23 +01:00
parent 179c3ae8aa
commit 8544a5ba4b
5 changed files with 114 additions and 135 deletions

View File

@ -8,9 +8,10 @@ import types
import tnetstring, filt, script, utils, encoding, proxy import tnetstring, filt, script, utils, encoding, proxy
from email.utils import parsedate_tz, formatdate, mktime_tz from email.utils import parsedate_tz, formatdate, mktime_tz
from netlib import odict, http, certutils, wsgi from netlib import odict, http, certutils, wsgi
import controller, version import controller, version, protocol
import app import app
HDR_FORM_URLENCODED = "application/x-www-form-urlencoded" HDR_FORM_URLENCODED = "application/x-www-form-urlencoded"
CONTENT_MISSING = 0 CONTENT_MISSING = 0
@ -144,86 +145,6 @@ class SetHeaders:
f.request.headers.add(header, value) f.request.headers.add(header, value)
class StateObject:
def _get_state(self):
raise NotImplementedError
def _load_state(self, state):
raise NotImplementedError
@classmethod
def _from_state(cls, state):
raise NotImplementedError
def __eq__(self, other):
try:
return self._get_state() == other._get_state()
except AttributeError: # we may compare with something that's not a StateObject
return False
class SimpleStateObject(StateObject):
"""
A StateObject with opionated conventions that tries to keep everything DRY.
Simply put, you agree on a list of attributes and their type.
Attributes can either be primitive types(str, tuple, bool, ...) or StateObject instances themselves.
SimpleStateObject uses this information for the default _get_state(), _from_state(s) and _load_state(s) methods.
Overriding _get_state or _load_state to add custom adjustments is always possible.
"""
_stateobject_attributes = None # none by default to raise an exception if definition was forgotten
"""
An attribute-name -> class-or-type dict containing all attributes that should be serialized
If the attribute is a class, this class must be a subclass of StateObject.
"""
def _get_state(self):
return {attr: self.__get_state_attr(attr, cls)
for attr, cls in self._stateobject_attributes.iteritems()}
def __get_state_attr(self, attr, cls):
"""
helper for _get_state.
returns the value of the given attribute
"""
if getattr(self, attr) is None:
return None
if isinstance(cls, types.ClassType):
return getattr(self, attr)._get_state()
else:
return getattr(self, attr)
def _load_state(self, state):
for attr, cls in self._stateobject_attributes.iteritems():
self.__load_state_attr(attr, cls, state)
def __load_state_attr(self, attr, cls, state):
"""
helper for _load_state.
loads the given attribute from the state.
"""
if state[attr] is not None: # First, catch None as value.
if isinstance(cls, types.ClassType): # Is the attribute a StateObject itself?
# FIXME: assertion doesn't hold because of odict at the moment
# assert issubclass(cls, StateObject)
curr = getattr(self, attr)
if curr: # if the attribute is already present, delegate to the objects ._load_state method.
curr._load_state(state[attr])
else: # otherwise, create a new object.
setattr(self, attr, cls._from_state(state[attr]))
else:
setattr(self, attr, cls(state[attr]))
else:
setattr(self, attr, None)
@classmethod
def _from_state(cls, state):
f = cls() # the default implementation assumes an empty constructor. Override accordingly.
f._load_state(state)
return f
class ClientPlaybackState: class ClientPlaybackState:
def __init__(self, flows, exit): def __init__(self, flows, exit):
self.flows, self.exit = flows, exit self.flows, self.exit = flows, exit
@ -834,7 +755,7 @@ class FlowReader:
v = ".".join(str(i) for i in data["version"]) v = ".".join(str(i) for i in data["version"])
raise FlowReadError("Incompatible serialized data version: %s"%v) raise FlowReadError("Incompatible serialized data version: %s"%v)
off = self.fo.tell() off = self.fo.tell()
yield Flow._from_state(data) yield protocol.protocols[data["conntype"]]["flow"]._from_state(data)
except ValueError, v: except ValueError, v:
# Error is due to EOF # Error is due to EOF
if self.fo.tell() == off and self.fo.read() == '': if self.fo.tell() == off and self.fo.read() == '':

View File

@ -12,6 +12,7 @@ class ConnectionTypeChange(Exception):
class ProtocolHandler(object): class ProtocolHandler(object):
def __init__(self, c): def __init__(self, c):
self.c = c self.c = c
"""@type : libmproxy.proxy.ConnectionHandler"""
def handle_messages(self): def handle_messages(self):
""" """
@ -27,13 +28,17 @@ class ProtocolHandler(object):
""" """
raise NotImplementedError raise NotImplementedError
from . import http, tcp
from .http import HTTPHandler protocols = dict(
http = dict(handler=http.HTTPHandler, flow=http.HTTPFlow),
tcp = dict(handler=tcp.TCPHandler),
)
def _handler(conntype, connection_handler): def _handler(conntype, connection_handler):
if conntype == "http": if conntype in protocols:
return HTTPHandler(connection_handler) return protocols[conntype]["handler"](connection_handler)
raise NotImplementedError raise NotImplementedError

View File

@ -1,17 +1,11 @@
import Cookie import Cookie, urllib, urlparse, time, copy
from email.utils import parsedate_tz, formatdate, mktime_tz from email.utils import parsedate_tz, formatdate, mktime_tz
import urllib
import urlparse
import time
import copy
from ..flow import SimpleStateObject
from netlib import http, tcp, http_status
from netlib.odict import ODict, ODictCaseless
import netlib.utils import netlib.utils
from .. import encoding, utils, version, filt, controller from netlib import http, tcp, http_status, stateobject, odict
from ..proxy import ProxyError, ServerConnection, ClientConnection from netlib.odict import ODict, ODictCaseless
from . import ProtocolHandler, ConnectionTypeChange, KILL from . import ProtocolHandler, ConnectionTypeChange, KILL
import libmproxy.flow from .. import encoding, utils, version, filt, controller
from ..proxy import ProxyError, ClientConnection, ServerConnection
HDR_FORM_URLENCODED = "application/x-www-form-urlencoded" HDR_FORM_URLENCODED = "application/x-www-form-urlencoded"
CONTENT_MISSING = 0 CONTENT_MISSING = 0
@ -57,7 +51,7 @@ class decoded(object):
if self.ce: if self.ce:
self.o.encode(self.ce) self.o.encode(self.ce)
# FIXME: Move out of http
class BackreferenceMixin(object): class BackreferenceMixin(object):
""" """
If an attribute from the _backrefattr tuple is set, If an attribute from the _backrefattr tuple is set,
@ -73,12 +67,10 @@ class BackreferenceMixin(object):
def __setattr__(self, key, value): def __setattr__(self, key, value):
super(BackreferenceMixin, self).__setattr__(key, value) super(BackreferenceMixin, self).__setattr__(key, value)
if key in self._backrefattr and value is not None: if key in self._backrefattr and value is not None:
# check if there is already a different object set as backref
assert (getattr(value, self._backrefname, self) or self) is self
setattr(value, self._backrefname, self) setattr(value, self._backrefname, self)
# FIXME: Move out of http # FIXME: Move out of http
class Error(SimpleStateObject): class Error(stateobject.SimpleStateObject):
""" """
An Error. An Error.
@ -107,7 +99,7 @@ class Error(SimpleStateObject):
return c return c
# FIXME: Move out of http # FIXME: Move out of http
class Flow(SimpleStateObject, BackreferenceMixin): class Flow(stateobject.SimpleStateObject, BackreferenceMixin):
def __init__(self, conntype, client_conn, server_conn, error): def __init__(self, conntype, client_conn, server_conn, error):
self.conntype = conntype self.conntype = conntype
self.client_conn = client_conn self.client_conn = client_conn
@ -167,7 +159,7 @@ class Flow(SimpleStateObject, BackreferenceMixin):
self._backup = None self._backup = None
class HTTPMessage(SimpleStateObject): class HTTPMessage(stateobject.SimpleStateObject):
def __init__(self): def __init__(self):
self.flow = None # Will usually set by backref mixin self.flow = None # Will usually set by backref mixin

57
libmproxy/protocol/tcp.py Normal file
View File

@ -0,0 +1,57 @@
from . import ProtocolHandler
import select, socket
from cStringIO import StringIO
class TCPHandler(ProtocolHandler):
"""
TCPHandler acts as a generic TCP forwarder.
Data will be .log()ed, but not stored any further.
"""
def handle_messages(self):
conns = [self.c.client_conn.rfile, self.c.server_conn.rfile]
while not self.c.close:
r, _, _ = select.select(conns, [], [], 10)
for rfile in r:
if self.c.client_conn.rfile == rfile:
src, dst = self.c.client_conn, self.c.server_conn
src_str, dst_str = "client", "server"
else:
dst, src = self.c.client_conn, self.c.server_conn
dst_str, src_str = "client", "server"
data = StringIO()
while range(4096):
# Do non-blocking select() to see if there is further data on in the buffer.
r, _, _ = select.select([rfile], [], [], 0)
if len(r):
d = rfile.read(1)
if d == "": # connection closed
break
data.write(d)
"""
OpenSSL Connections have an internal buffer that might contain data altough everything is read
from the socket. Thankfully, connection.pending() returns the amount of bytes in this buffer,
so we can read it completely at once.
"""
if src.ssl_established:
data.write(rfile.read(src.connection.pending()))
else: # no data left, but not closed yet
break
data = data.getvalue()
if data == "": # no data received, rfile is closed
self.c.log("Close writing connection to %s" % dst_str)
conns.remove(rfile)
if dst.ssl_established:
dst.connection.shutdown()
else:
dst.connection.shutdown(socket.SHUT_WR)
if len(conns) == 0:
self.c.close = True
break
self.c.log("%s -> %s" % (src_str, dst_str), ["\r\n" + data])
dst.wfile.write(data)
dst.wfile.flush()

View File

@ -1,7 +1,7 @@
import os, socket, time, threading import os, socket, time, threading
from OpenSSL import SSL from OpenSSL import SSL
from netlib import tcp, http, certutils, http_auth from netlib import tcp, http, certutils, http_auth, stateobject
import utils, flow, version, platform, controller import utils, version, platform, controller
TRANSPARENT_SSL_PORTS = [443, 8443] TRANSPARENT_SSL_PORTS = [443, 8443]
@ -34,7 +34,7 @@ class ProxyConfig:
self.certstore = certutils.CertStore() self.certstore = certutils.CertStore()
class ClientConnection(tcp.BaseHandler, flow.SimpleStateObject): class ClientConnection(tcp.BaseHandler, stateobject.SimpleStateObject):
def __init__(self, client_connection, address, server): def __init__(self, client_connection, address, server):
tcp.BaseHandler.__init__(self, client_connection, address, server) tcp.BaseHandler.__init__(self, client_connection, address, server)
@ -46,7 +46,8 @@ class ClientConnection(tcp.BaseHandler, flow.SimpleStateObject):
timestamp_start=float, timestamp_start=float,
timestamp_end=float, timestamp_end=float,
timestamp_ssl_setup=float, timestamp_ssl_setup=float,
# FIXME: Add missing attributes address=tcp.Address,
clientcert=certutils.SSLCert
) )
@classmethod @classmethod
@ -62,7 +63,7 @@ class ClientConnection(tcp.BaseHandler, flow.SimpleStateObject):
self.timestamp_end = utils.timestamp() self.timestamp_end = utils.timestamp()
class ServerConnection(tcp.TCPClient, flow.SimpleStateObject): class ServerConnection(tcp.TCPClient, stateobject.SimpleStateObject):
def __init__(self, address): def __init__(self, address):
tcp.TCPClient.__init__(self, address) tcp.TCPClient.__init__(self, address)
@ -78,12 +79,14 @@ class ServerConnection(tcp.TCPClient, flow.SimpleStateObject):
timestamp_end=float, timestamp_end=float,
timestamp_tcp_setup=float, timestamp_tcp_setup=float,
timestamp_ssl_setup=float, timestamp_ssl_setup=float,
# FIXME: Add missing attributes address=tcp.Address,
source_address=tcp.Address,
cert=certutils.SSLCert
) )
@classmethod @classmethod
def _from_state(cls, state): def _from_state(cls, state):
raise NotImplementedError # FIXME raise NotImplementedError # FIXME
def connect(self): def connect(self):
self.timestamp_start = utils.timestamp() self.timestamp_start = utils.timestamp()
@ -172,33 +175,34 @@ class ConnectionHandler:
self.determine_conntype() self.determine_conntype()
try: try:
# Can we already identify the target server and connect to it? try:
server_address = None # Can we already identify the target server and connect to it?
if self.config.forward_proxy: server_address = None
server_address = self.config.forward_proxy[1:] if self.config.forward_proxy:
else: server_address = self.config.forward_proxy[1:]
if self.config.reverse_proxy: else:
server_address = self.config.reverse_proxy[1:] if self.config.reverse_proxy:
elif self.config.transparent_proxy: server_address = self.config.reverse_proxy[1:]
server_address = self.config.transparent_proxy["resolver"].original_addr( elif self.config.transparent_proxy:
self.client_conn.connection) server_address = self.config.transparent_proxy["resolver"].original_addr(
if not server_address: self.client_conn.connection)
raise ProxyError(502, "Transparent mode failure: could not resolve original destination.") if not server_address:
self.log("transparent to %s:%s" % server_address) raise ProxyError(502, "Transparent mode failure: could not resolve original destination.")
self.log("transparent to %s:%s" % server_address)
if server_address: if server_address:
self.establish_server_connection(server_address) self.establish_server_connection(server_address)
self._handle_ssl() self._handle_ssl()
while not self.close: while not self.close:
try: try:
protocol.handle_messages(self.conntype, self) protocol.handle_messages(self.conntype, self)
except protocol.ConnectionTypeChange: except protocol.ConnectionTypeChange:
continue continue
# FIXME: Do we want to persist errors? # FIXME: Do we want to persist errors?
except (ProxyError, tcp.NetLibError), e: except (ProxyError, tcp.NetLibError), e:
protocol.handle_error(self.conntype, self, e) protocol.handle_error(self.conntype, self, e)
except Exception, e: except Exception, e:
self.log(e.__class__) self.log(e.__class__)
import traceback import traceback
@ -250,7 +254,7 @@ class ConnectionHandler:
A protocol handler must raise a ConnTypeChanged exception if it detects that this is happening A protocol handler must raise a ConnTypeChanged exception if it detects that this is happening
""" """
# TODO: Implement SSL pass-through handling and change conntype # TODO: Implement SSL pass-through handling and change conntype
if self.server_conn.address.host == "ycombinator.com": if self.server_conn.address.host == "news.ycombinator.com":
self.conntype = "tcp" self.conntype = "tcp"
if server: if server:
@ -265,8 +269,8 @@ class ConnectionHandler:
handle_sni=self.handle_sni) handle_sni=self.handle_sni)
def server_reconnect(self, no_ssl=False): def server_reconnect(self, no_ssl=False):
self.log("server reconnect")
had_ssl, sni = self.server_conn.ssl_established, self.sni had_ssl, sni = self.server_conn.ssl_established, self.sni
self.log("server reconnect (ssl: %s, sni: %s)" % (had_ssl, sni))
self.establish_server_connection(self.server_conn.address) self.establish_server_connection(self.server_conn.address)
if had_ssl and not no_ssl: if had_ssl and not no_ssl:
self.sni = sni self.sni = sni