add generic TCP handler with SSL support, move StateObject into netlib

This commit is contained in:
Maximilian Hils 2014-01-30 18:56:23 +01:00
parent 179c3ae8aa
commit 8544a5ba4b
5 changed files with 114 additions and 135 deletions

View File

@ -8,9 +8,10 @@ import types
import tnetstring, filt, script, utils, encoding, proxy
from email.utils import parsedate_tz, formatdate, mktime_tz
from netlib import odict, http, certutils, wsgi
import controller, version
import controller, version, protocol
import app
HDR_FORM_URLENCODED = "application/x-www-form-urlencoded"
CONTENT_MISSING = 0
@ -144,86 +145,6 @@ class SetHeaders:
f.request.headers.add(header, value)
class StateObject:
def _get_state(self):
raise NotImplementedError
def _load_state(self, state):
raise NotImplementedError
@classmethod
def _from_state(cls, state):
raise NotImplementedError
def __eq__(self, other):
try:
return self._get_state() == other._get_state()
except AttributeError: # we may compare with something that's not a StateObject
return False
class SimpleStateObject(StateObject):
"""
A StateObject with opionated conventions that tries to keep everything DRY.
Simply put, you agree on a list of attributes and their type.
Attributes can either be primitive types(str, tuple, bool, ...) or StateObject instances themselves.
SimpleStateObject uses this information for the default _get_state(), _from_state(s) and _load_state(s) methods.
Overriding _get_state or _load_state to add custom adjustments is always possible.
"""
_stateobject_attributes = None # none by default to raise an exception if definition was forgotten
"""
An attribute-name -> class-or-type dict containing all attributes that should be serialized
If the attribute is a class, this class must be a subclass of StateObject.
"""
def _get_state(self):
return {attr: self.__get_state_attr(attr, cls)
for attr, cls in self._stateobject_attributes.iteritems()}
def __get_state_attr(self, attr, cls):
"""
helper for _get_state.
returns the value of the given attribute
"""
if getattr(self, attr) is None:
return None
if isinstance(cls, types.ClassType):
return getattr(self, attr)._get_state()
else:
return getattr(self, attr)
def _load_state(self, state):
for attr, cls in self._stateobject_attributes.iteritems():
self.__load_state_attr(attr, cls, state)
def __load_state_attr(self, attr, cls, state):
"""
helper for _load_state.
loads the given attribute from the state.
"""
if state[attr] is not None: # First, catch None as value.
if isinstance(cls, types.ClassType): # Is the attribute a StateObject itself?
# FIXME: assertion doesn't hold because of odict at the moment
# assert issubclass(cls, StateObject)
curr = getattr(self, attr)
if curr: # if the attribute is already present, delegate to the objects ._load_state method.
curr._load_state(state[attr])
else: # otherwise, create a new object.
setattr(self, attr, cls._from_state(state[attr]))
else:
setattr(self, attr, cls(state[attr]))
else:
setattr(self, attr, None)
@classmethod
def _from_state(cls, state):
f = cls() # the default implementation assumes an empty constructor. Override accordingly.
f._load_state(state)
return f
class ClientPlaybackState:
def __init__(self, flows, exit):
self.flows, self.exit = flows, exit
@ -834,7 +755,7 @@ class FlowReader:
v = ".".join(str(i) for i in data["version"])
raise FlowReadError("Incompatible serialized data version: %s"%v)
off = self.fo.tell()
yield Flow._from_state(data)
yield protocol.protocols[data["conntype"]]["flow"]._from_state(data)
except ValueError, v:
# Error is due to EOF
if self.fo.tell() == off and self.fo.read() == '':

View File

@ -12,6 +12,7 @@ class ConnectionTypeChange(Exception):
class ProtocolHandler(object):
def __init__(self, c):
self.c = c
"""@type : libmproxy.proxy.ConnectionHandler"""
def handle_messages(self):
"""
@ -27,13 +28,17 @@ class ProtocolHandler(object):
"""
raise NotImplementedError
from . import http, tcp
from .http import HTTPHandler
protocols = dict(
http = dict(handler=http.HTTPHandler, flow=http.HTTPFlow),
tcp = dict(handler=tcp.TCPHandler),
)
def _handler(conntype, connection_handler):
if conntype == "http":
return HTTPHandler(connection_handler)
if conntype in protocols:
return protocols[conntype]["handler"](connection_handler)
raise NotImplementedError

View File

@ -1,17 +1,11 @@
import Cookie
import Cookie, urllib, urlparse, time, copy
from email.utils import parsedate_tz, formatdate, mktime_tz
import urllib
import urlparse
import time
import copy
from ..flow import SimpleStateObject
from netlib import http, tcp, http_status
from netlib.odict import ODict, ODictCaseless
import netlib.utils
from .. import encoding, utils, version, filt, controller
from ..proxy import ProxyError, ServerConnection, ClientConnection
from netlib import http, tcp, http_status, stateobject, odict
from netlib.odict import ODict, ODictCaseless
from . import ProtocolHandler, ConnectionTypeChange, KILL
import libmproxy.flow
from .. import encoding, utils, version, filt, controller
from ..proxy import ProxyError, ClientConnection, ServerConnection
HDR_FORM_URLENCODED = "application/x-www-form-urlencoded"
CONTENT_MISSING = 0
@ -57,7 +51,7 @@ class decoded(object):
if self.ce:
self.o.encode(self.ce)
# FIXME: Move out of http
class BackreferenceMixin(object):
"""
If an attribute from the _backrefattr tuple is set,
@ -73,12 +67,10 @@ class BackreferenceMixin(object):
def __setattr__(self, key, value):
super(BackreferenceMixin, self).__setattr__(key, value)
if key in self._backrefattr and value is not None:
# check if there is already a different object set as backref
assert (getattr(value, self._backrefname, self) or self) is self
setattr(value, self._backrefname, self)
# FIXME: Move out of http
class Error(SimpleStateObject):
class Error(stateobject.SimpleStateObject):
"""
An Error.
@ -107,7 +99,7 @@ class Error(SimpleStateObject):
return c
# FIXME: Move out of http
class Flow(SimpleStateObject, BackreferenceMixin):
class Flow(stateobject.SimpleStateObject, BackreferenceMixin):
def __init__(self, conntype, client_conn, server_conn, error):
self.conntype = conntype
self.client_conn = client_conn
@ -167,7 +159,7 @@ class Flow(SimpleStateObject, BackreferenceMixin):
self._backup = None
class HTTPMessage(SimpleStateObject):
class HTTPMessage(stateobject.SimpleStateObject):
def __init__(self):
self.flow = None # Will usually set by backref mixin

57
libmproxy/protocol/tcp.py Normal file
View File

@ -0,0 +1,57 @@
from . import ProtocolHandler
import select, socket
from cStringIO import StringIO
class TCPHandler(ProtocolHandler):
"""
TCPHandler acts as a generic TCP forwarder.
Data will be .log()ed, but not stored any further.
"""
def handle_messages(self):
conns = [self.c.client_conn.rfile, self.c.server_conn.rfile]
while not self.c.close:
r, _, _ = select.select(conns, [], [], 10)
for rfile in r:
if self.c.client_conn.rfile == rfile:
src, dst = self.c.client_conn, self.c.server_conn
src_str, dst_str = "client", "server"
else:
dst, src = self.c.client_conn, self.c.server_conn
dst_str, src_str = "client", "server"
data = StringIO()
while range(4096):
# Do non-blocking select() to see if there is further data on in the buffer.
r, _, _ = select.select([rfile], [], [], 0)
if len(r):
d = rfile.read(1)
if d == "": # connection closed
break
data.write(d)
"""
OpenSSL Connections have an internal buffer that might contain data altough everything is read
from the socket. Thankfully, connection.pending() returns the amount of bytes in this buffer,
so we can read it completely at once.
"""
if src.ssl_established:
data.write(rfile.read(src.connection.pending()))
else: # no data left, but not closed yet
break
data = data.getvalue()
if data == "": # no data received, rfile is closed
self.c.log("Close writing connection to %s" % dst_str)
conns.remove(rfile)
if dst.ssl_established:
dst.connection.shutdown()
else:
dst.connection.shutdown(socket.SHUT_WR)
if len(conns) == 0:
self.c.close = True
break
self.c.log("%s -> %s" % (src_str, dst_str), ["\r\n" + data])
dst.wfile.write(data)
dst.wfile.flush()

View File

@ -1,7 +1,7 @@
import os, socket, time, threading
from OpenSSL import SSL
from netlib import tcp, http, certutils, http_auth
import utils, flow, version, platform, controller
from netlib import tcp, http, certutils, http_auth, stateobject
import utils, version, platform, controller
TRANSPARENT_SSL_PORTS = [443, 8443]
@ -34,7 +34,7 @@ class ProxyConfig:
self.certstore = certutils.CertStore()
class ClientConnection(tcp.BaseHandler, flow.SimpleStateObject):
class ClientConnection(tcp.BaseHandler, stateobject.SimpleStateObject):
def __init__(self, client_connection, address, server):
tcp.BaseHandler.__init__(self, client_connection, address, server)
@ -46,7 +46,8 @@ class ClientConnection(tcp.BaseHandler, flow.SimpleStateObject):
timestamp_start=float,
timestamp_end=float,
timestamp_ssl_setup=float,
# FIXME: Add missing attributes
address=tcp.Address,
clientcert=certutils.SSLCert
)
@classmethod
@ -62,7 +63,7 @@ class ClientConnection(tcp.BaseHandler, flow.SimpleStateObject):
self.timestamp_end = utils.timestamp()
class ServerConnection(tcp.TCPClient, flow.SimpleStateObject):
class ServerConnection(tcp.TCPClient, stateobject.SimpleStateObject):
def __init__(self, address):
tcp.TCPClient.__init__(self, address)
@ -78,12 +79,14 @@ class ServerConnection(tcp.TCPClient, flow.SimpleStateObject):
timestamp_end=float,
timestamp_tcp_setup=float,
timestamp_ssl_setup=float,
# FIXME: Add missing attributes
address=tcp.Address,
source_address=tcp.Address,
cert=certutils.SSLCert
)
@classmethod
def _from_state(cls, state):
raise NotImplementedError # FIXME
raise NotImplementedError # FIXME
def connect(self):
self.timestamp_start = utils.timestamp()
@ -172,33 +175,34 @@ class ConnectionHandler:
self.determine_conntype()
try:
# Can we already identify the target server and connect to it?
server_address = None
if self.config.forward_proxy:
server_address = self.config.forward_proxy[1:]
else:
if self.config.reverse_proxy:
server_address = self.config.reverse_proxy[1:]
elif self.config.transparent_proxy:
server_address = self.config.transparent_proxy["resolver"].original_addr(
self.client_conn.connection)
if not server_address:
raise ProxyError(502, "Transparent mode failure: could not resolve original destination.")
self.log("transparent to %s:%s" % server_address)
try:
# Can we already identify the target server and connect to it?
server_address = None
if self.config.forward_proxy:
server_address = self.config.forward_proxy[1:]
else:
if self.config.reverse_proxy:
server_address = self.config.reverse_proxy[1:]
elif self.config.transparent_proxy:
server_address = self.config.transparent_proxy["resolver"].original_addr(
self.client_conn.connection)
if not server_address:
raise ProxyError(502, "Transparent mode failure: could not resolve original destination.")
self.log("transparent to %s:%s" % server_address)
if server_address:
self.establish_server_connection(server_address)
self._handle_ssl()
if server_address:
self.establish_server_connection(server_address)
self._handle_ssl()
while not self.close:
try:
protocol.handle_messages(self.conntype, self)
except protocol.ConnectionTypeChange:
continue
while not self.close:
try:
protocol.handle_messages(self.conntype, self)
except protocol.ConnectionTypeChange:
continue
# FIXME: Do we want to persist errors?
except (ProxyError, tcp.NetLibError), e:
protocol.handle_error(self.conntype, self, e)
# FIXME: Do we want to persist errors?
except (ProxyError, tcp.NetLibError), e:
protocol.handle_error(self.conntype, self, e)
except Exception, e:
self.log(e.__class__)
import traceback
@ -250,7 +254,7 @@ class ConnectionHandler:
A protocol handler must raise a ConnTypeChanged exception if it detects that this is happening
"""
# TODO: Implement SSL pass-through handling and change conntype
if self.server_conn.address.host == "ycombinator.com":
if self.server_conn.address.host == "news.ycombinator.com":
self.conntype = "tcp"
if server:
@ -265,8 +269,8 @@ class ConnectionHandler:
handle_sni=self.handle_sni)
def server_reconnect(self, no_ssl=False):
self.log("server reconnect")
had_ssl, sni = self.server_conn.ssl_established, self.sni
self.log("server reconnect (ssl: %s, sni: %s)" % (had_ssl, sni))
self.establish_server_connection(self.server_conn.address)
if had_ssl and not no_ssl:
self.sni = sni