add generic TCP handler with SSL support, move StateObject into netlib
This commit is contained in:
parent
179c3ae8aa
commit
8544a5ba4b
|
@ -8,9 +8,10 @@ import types
|
||||||
import tnetstring, filt, script, utils, encoding, proxy
|
import tnetstring, filt, script, utils, encoding, proxy
|
||||||
from email.utils import parsedate_tz, formatdate, mktime_tz
|
from email.utils import parsedate_tz, formatdate, mktime_tz
|
||||||
from netlib import odict, http, certutils, wsgi
|
from netlib import odict, http, certutils, wsgi
|
||||||
import controller, version
|
import controller, version, protocol
|
||||||
import app
|
import app
|
||||||
|
|
||||||
|
|
||||||
HDR_FORM_URLENCODED = "application/x-www-form-urlencoded"
|
HDR_FORM_URLENCODED = "application/x-www-form-urlencoded"
|
||||||
CONTENT_MISSING = 0
|
CONTENT_MISSING = 0
|
||||||
|
|
||||||
|
@ -144,86 +145,6 @@ class SetHeaders:
|
||||||
f.request.headers.add(header, value)
|
f.request.headers.add(header, value)
|
||||||
|
|
||||||
|
|
||||||
class StateObject:
|
|
||||||
def _get_state(self):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def _load_state(self, state):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _from_state(cls, state):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def __eq__(self, other):
|
|
||||||
try:
|
|
||||||
return self._get_state() == other._get_state()
|
|
||||||
except AttributeError: # we may compare with something that's not a StateObject
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
class SimpleStateObject(StateObject):
|
|
||||||
"""
|
|
||||||
A StateObject with opionated conventions that tries to keep everything DRY.
|
|
||||||
|
|
||||||
Simply put, you agree on a list of attributes and their type.
|
|
||||||
Attributes can either be primitive types(str, tuple, bool, ...) or StateObject instances themselves.
|
|
||||||
SimpleStateObject uses this information for the default _get_state(), _from_state(s) and _load_state(s) methods.
|
|
||||||
Overriding _get_state or _load_state to add custom adjustments is always possible.
|
|
||||||
"""
|
|
||||||
|
|
||||||
_stateobject_attributes = None # none by default to raise an exception if definition was forgotten
|
|
||||||
"""
|
|
||||||
An attribute-name -> class-or-type dict containing all attributes that should be serialized
|
|
||||||
If the attribute is a class, this class must be a subclass of StateObject.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _get_state(self):
|
|
||||||
return {attr: self.__get_state_attr(attr, cls)
|
|
||||||
for attr, cls in self._stateobject_attributes.iteritems()}
|
|
||||||
|
|
||||||
def __get_state_attr(self, attr, cls):
|
|
||||||
"""
|
|
||||||
helper for _get_state.
|
|
||||||
returns the value of the given attribute
|
|
||||||
"""
|
|
||||||
if getattr(self, attr) is None:
|
|
||||||
return None
|
|
||||||
if isinstance(cls, types.ClassType):
|
|
||||||
return getattr(self, attr)._get_state()
|
|
||||||
else:
|
|
||||||
return getattr(self, attr)
|
|
||||||
|
|
||||||
def _load_state(self, state):
|
|
||||||
for attr, cls in self._stateobject_attributes.iteritems():
|
|
||||||
self.__load_state_attr(attr, cls, state)
|
|
||||||
|
|
||||||
def __load_state_attr(self, attr, cls, state):
|
|
||||||
"""
|
|
||||||
helper for _load_state.
|
|
||||||
loads the given attribute from the state.
|
|
||||||
"""
|
|
||||||
if state[attr] is not None: # First, catch None as value.
|
|
||||||
if isinstance(cls, types.ClassType): # Is the attribute a StateObject itself?
|
|
||||||
# FIXME: assertion doesn't hold because of odict at the moment
|
|
||||||
# assert issubclass(cls, StateObject)
|
|
||||||
curr = getattr(self, attr)
|
|
||||||
if curr: # if the attribute is already present, delegate to the objects ._load_state method.
|
|
||||||
curr._load_state(state[attr])
|
|
||||||
else: # otherwise, create a new object.
|
|
||||||
setattr(self, attr, cls._from_state(state[attr]))
|
|
||||||
else:
|
|
||||||
setattr(self, attr, cls(state[attr]))
|
|
||||||
else:
|
|
||||||
setattr(self, attr, None)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _from_state(cls, state):
|
|
||||||
f = cls() # the default implementation assumes an empty constructor. Override accordingly.
|
|
||||||
f._load_state(state)
|
|
||||||
return f
|
|
||||||
|
|
||||||
|
|
||||||
class ClientPlaybackState:
|
class ClientPlaybackState:
|
||||||
def __init__(self, flows, exit):
|
def __init__(self, flows, exit):
|
||||||
self.flows, self.exit = flows, exit
|
self.flows, self.exit = flows, exit
|
||||||
|
@ -834,7 +755,7 @@ class FlowReader:
|
||||||
v = ".".join(str(i) for i in data["version"])
|
v = ".".join(str(i) for i in data["version"])
|
||||||
raise FlowReadError("Incompatible serialized data version: %s"%v)
|
raise FlowReadError("Incompatible serialized data version: %s"%v)
|
||||||
off = self.fo.tell()
|
off = self.fo.tell()
|
||||||
yield Flow._from_state(data)
|
yield protocol.protocols[data["conntype"]]["flow"]._from_state(data)
|
||||||
except ValueError, v:
|
except ValueError, v:
|
||||||
# Error is due to EOF
|
# Error is due to EOF
|
||||||
if self.fo.tell() == off and self.fo.read() == '':
|
if self.fo.tell() == off and self.fo.read() == '':
|
||||||
|
|
|
@ -12,6 +12,7 @@ class ConnectionTypeChange(Exception):
|
||||||
class ProtocolHandler(object):
|
class ProtocolHandler(object):
|
||||||
def __init__(self, c):
|
def __init__(self, c):
|
||||||
self.c = c
|
self.c = c
|
||||||
|
"""@type : libmproxy.proxy.ConnectionHandler"""
|
||||||
|
|
||||||
def handle_messages(self):
|
def handle_messages(self):
|
||||||
"""
|
"""
|
||||||
|
@ -27,13 +28,17 @@ class ProtocolHandler(object):
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
from . import http, tcp
|
||||||
|
|
||||||
from .http import HTTPHandler
|
protocols = dict(
|
||||||
|
http = dict(handler=http.HTTPHandler, flow=http.HTTPFlow),
|
||||||
|
tcp = dict(handler=tcp.TCPHandler),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _handler(conntype, connection_handler):
|
def _handler(conntype, connection_handler):
|
||||||
if conntype == "http":
|
if conntype in protocols:
|
||||||
return HTTPHandler(connection_handler)
|
return protocols[conntype]["handler"](connection_handler)
|
||||||
|
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
|
@ -1,17 +1,11 @@
|
||||||
import Cookie
|
import Cookie, urllib, urlparse, time, copy
|
||||||
from email.utils import parsedate_tz, formatdate, mktime_tz
|
from email.utils import parsedate_tz, formatdate, mktime_tz
|
||||||
import urllib
|
|
||||||
import urlparse
|
|
||||||
import time
|
|
||||||
import copy
|
|
||||||
from ..flow import SimpleStateObject
|
|
||||||
from netlib import http, tcp, http_status
|
|
||||||
from netlib.odict import ODict, ODictCaseless
|
|
||||||
import netlib.utils
|
import netlib.utils
|
||||||
from .. import encoding, utils, version, filt, controller
|
from netlib import http, tcp, http_status, stateobject, odict
|
||||||
from ..proxy import ProxyError, ServerConnection, ClientConnection
|
from netlib.odict import ODict, ODictCaseless
|
||||||
from . import ProtocolHandler, ConnectionTypeChange, KILL
|
from . import ProtocolHandler, ConnectionTypeChange, KILL
|
||||||
import libmproxy.flow
|
from .. import encoding, utils, version, filt, controller
|
||||||
|
from ..proxy import ProxyError, ClientConnection, ServerConnection
|
||||||
|
|
||||||
HDR_FORM_URLENCODED = "application/x-www-form-urlencoded"
|
HDR_FORM_URLENCODED = "application/x-www-form-urlencoded"
|
||||||
CONTENT_MISSING = 0
|
CONTENT_MISSING = 0
|
||||||
|
@ -57,7 +51,7 @@ class decoded(object):
|
||||||
if self.ce:
|
if self.ce:
|
||||||
self.o.encode(self.ce)
|
self.o.encode(self.ce)
|
||||||
|
|
||||||
|
# FIXME: Move out of http
|
||||||
class BackreferenceMixin(object):
|
class BackreferenceMixin(object):
|
||||||
"""
|
"""
|
||||||
If an attribute from the _backrefattr tuple is set,
|
If an attribute from the _backrefattr tuple is set,
|
||||||
|
@ -73,12 +67,10 @@ class BackreferenceMixin(object):
|
||||||
def __setattr__(self, key, value):
|
def __setattr__(self, key, value):
|
||||||
super(BackreferenceMixin, self).__setattr__(key, value)
|
super(BackreferenceMixin, self).__setattr__(key, value)
|
||||||
if key in self._backrefattr and value is not None:
|
if key in self._backrefattr and value is not None:
|
||||||
# check if there is already a different object set as backref
|
|
||||||
assert (getattr(value, self._backrefname, self) or self) is self
|
|
||||||
setattr(value, self._backrefname, self)
|
setattr(value, self._backrefname, self)
|
||||||
|
|
||||||
# FIXME: Move out of http
|
# FIXME: Move out of http
|
||||||
class Error(SimpleStateObject):
|
class Error(stateobject.SimpleStateObject):
|
||||||
"""
|
"""
|
||||||
An Error.
|
An Error.
|
||||||
|
|
||||||
|
@ -107,7 +99,7 @@ class Error(SimpleStateObject):
|
||||||
return c
|
return c
|
||||||
|
|
||||||
# FIXME: Move out of http
|
# FIXME: Move out of http
|
||||||
class Flow(SimpleStateObject, BackreferenceMixin):
|
class Flow(stateobject.SimpleStateObject, BackreferenceMixin):
|
||||||
def __init__(self, conntype, client_conn, server_conn, error):
|
def __init__(self, conntype, client_conn, server_conn, error):
|
||||||
self.conntype = conntype
|
self.conntype = conntype
|
||||||
self.client_conn = client_conn
|
self.client_conn = client_conn
|
||||||
|
@ -167,7 +159,7 @@ class Flow(SimpleStateObject, BackreferenceMixin):
|
||||||
self._backup = None
|
self._backup = None
|
||||||
|
|
||||||
|
|
||||||
class HTTPMessage(SimpleStateObject):
|
class HTTPMessage(stateobject.SimpleStateObject):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.flow = None # Will usually set by backref mixin
|
self.flow = None # Will usually set by backref mixin
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,57 @@
|
||||||
|
from . import ProtocolHandler
|
||||||
|
import select, socket
|
||||||
|
from cStringIO import StringIO
|
||||||
|
|
||||||
|
|
||||||
|
class TCPHandler(ProtocolHandler):
|
||||||
|
"""
|
||||||
|
TCPHandler acts as a generic TCP forwarder.
|
||||||
|
Data will be .log()ed, but not stored any further.
|
||||||
|
"""
|
||||||
|
def handle_messages(self):
|
||||||
|
conns = [self.c.client_conn.rfile, self.c.server_conn.rfile]
|
||||||
|
while not self.c.close:
|
||||||
|
r, _, _ = select.select(conns, [], [], 10)
|
||||||
|
for rfile in r:
|
||||||
|
if self.c.client_conn.rfile == rfile:
|
||||||
|
src, dst = self.c.client_conn, self.c.server_conn
|
||||||
|
src_str, dst_str = "client", "server"
|
||||||
|
else:
|
||||||
|
dst, src = self.c.client_conn, self.c.server_conn
|
||||||
|
dst_str, src_str = "client", "server"
|
||||||
|
|
||||||
|
data = StringIO()
|
||||||
|
while range(4096):
|
||||||
|
# Do non-blocking select() to see if there is further data on in the buffer.
|
||||||
|
r, _, _ = select.select([rfile], [], [], 0)
|
||||||
|
if len(r):
|
||||||
|
d = rfile.read(1)
|
||||||
|
if d == "": # connection closed
|
||||||
|
break
|
||||||
|
data.write(d)
|
||||||
|
|
||||||
|
"""
|
||||||
|
OpenSSL Connections have an internal buffer that might contain data altough everything is read
|
||||||
|
from the socket. Thankfully, connection.pending() returns the amount of bytes in this buffer,
|
||||||
|
so we can read it completely at once.
|
||||||
|
"""
|
||||||
|
if src.ssl_established:
|
||||||
|
data.write(rfile.read(src.connection.pending()))
|
||||||
|
else: # no data left, but not closed yet
|
||||||
|
break
|
||||||
|
data = data.getvalue()
|
||||||
|
|
||||||
|
if data == "": # no data received, rfile is closed
|
||||||
|
self.c.log("Close writing connection to %s" % dst_str)
|
||||||
|
conns.remove(rfile)
|
||||||
|
if dst.ssl_established:
|
||||||
|
dst.connection.shutdown()
|
||||||
|
else:
|
||||||
|
dst.connection.shutdown(socket.SHUT_WR)
|
||||||
|
if len(conns) == 0:
|
||||||
|
self.c.close = True
|
||||||
|
break
|
||||||
|
|
||||||
|
self.c.log("%s -> %s" % (src_str, dst_str), ["\r\n" + data])
|
||||||
|
dst.wfile.write(data)
|
||||||
|
dst.wfile.flush()
|
|
@ -1,7 +1,7 @@
|
||||||
import os, socket, time, threading
|
import os, socket, time, threading
|
||||||
from OpenSSL import SSL
|
from OpenSSL import SSL
|
||||||
from netlib import tcp, http, certutils, http_auth
|
from netlib import tcp, http, certutils, http_auth, stateobject
|
||||||
import utils, flow, version, platform, controller
|
import utils, version, platform, controller
|
||||||
|
|
||||||
|
|
||||||
TRANSPARENT_SSL_PORTS = [443, 8443]
|
TRANSPARENT_SSL_PORTS = [443, 8443]
|
||||||
|
@ -34,7 +34,7 @@ class ProxyConfig:
|
||||||
self.certstore = certutils.CertStore()
|
self.certstore = certutils.CertStore()
|
||||||
|
|
||||||
|
|
||||||
class ClientConnection(tcp.BaseHandler, flow.SimpleStateObject):
|
class ClientConnection(tcp.BaseHandler, stateobject.SimpleStateObject):
|
||||||
def __init__(self, client_connection, address, server):
|
def __init__(self, client_connection, address, server):
|
||||||
tcp.BaseHandler.__init__(self, client_connection, address, server)
|
tcp.BaseHandler.__init__(self, client_connection, address, server)
|
||||||
|
|
||||||
|
@ -46,7 +46,8 @@ class ClientConnection(tcp.BaseHandler, flow.SimpleStateObject):
|
||||||
timestamp_start=float,
|
timestamp_start=float,
|
||||||
timestamp_end=float,
|
timestamp_end=float,
|
||||||
timestamp_ssl_setup=float,
|
timestamp_ssl_setup=float,
|
||||||
# FIXME: Add missing attributes
|
address=tcp.Address,
|
||||||
|
clientcert=certutils.SSLCert
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -62,7 +63,7 @@ class ClientConnection(tcp.BaseHandler, flow.SimpleStateObject):
|
||||||
self.timestamp_end = utils.timestamp()
|
self.timestamp_end = utils.timestamp()
|
||||||
|
|
||||||
|
|
||||||
class ServerConnection(tcp.TCPClient, flow.SimpleStateObject):
|
class ServerConnection(tcp.TCPClient, stateobject.SimpleStateObject):
|
||||||
def __init__(self, address):
|
def __init__(self, address):
|
||||||
tcp.TCPClient.__init__(self, address)
|
tcp.TCPClient.__init__(self, address)
|
||||||
|
|
||||||
|
@ -78,12 +79,14 @@ class ServerConnection(tcp.TCPClient, flow.SimpleStateObject):
|
||||||
timestamp_end=float,
|
timestamp_end=float,
|
||||||
timestamp_tcp_setup=float,
|
timestamp_tcp_setup=float,
|
||||||
timestamp_ssl_setup=float,
|
timestamp_ssl_setup=float,
|
||||||
# FIXME: Add missing attributes
|
address=tcp.Address,
|
||||||
|
source_address=tcp.Address,
|
||||||
|
cert=certutils.SSLCert
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _from_state(cls, state):
|
def _from_state(cls, state):
|
||||||
raise NotImplementedError # FIXME
|
raise NotImplementedError # FIXME
|
||||||
|
|
||||||
def connect(self):
|
def connect(self):
|
||||||
self.timestamp_start = utils.timestamp()
|
self.timestamp_start = utils.timestamp()
|
||||||
|
@ -172,33 +175,34 @@ class ConnectionHandler:
|
||||||
self.determine_conntype()
|
self.determine_conntype()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Can we already identify the target server and connect to it?
|
try:
|
||||||
server_address = None
|
# Can we already identify the target server and connect to it?
|
||||||
if self.config.forward_proxy:
|
server_address = None
|
||||||
server_address = self.config.forward_proxy[1:]
|
if self.config.forward_proxy:
|
||||||
else:
|
server_address = self.config.forward_proxy[1:]
|
||||||
if self.config.reverse_proxy:
|
else:
|
||||||
server_address = self.config.reverse_proxy[1:]
|
if self.config.reverse_proxy:
|
||||||
elif self.config.transparent_proxy:
|
server_address = self.config.reverse_proxy[1:]
|
||||||
server_address = self.config.transparent_proxy["resolver"].original_addr(
|
elif self.config.transparent_proxy:
|
||||||
self.client_conn.connection)
|
server_address = self.config.transparent_proxy["resolver"].original_addr(
|
||||||
if not server_address:
|
self.client_conn.connection)
|
||||||
raise ProxyError(502, "Transparent mode failure: could not resolve original destination.")
|
if not server_address:
|
||||||
self.log("transparent to %s:%s" % server_address)
|
raise ProxyError(502, "Transparent mode failure: could not resolve original destination.")
|
||||||
|
self.log("transparent to %s:%s" % server_address)
|
||||||
|
|
||||||
if server_address:
|
if server_address:
|
||||||
self.establish_server_connection(server_address)
|
self.establish_server_connection(server_address)
|
||||||
self._handle_ssl()
|
self._handle_ssl()
|
||||||
|
|
||||||
while not self.close:
|
while not self.close:
|
||||||
try:
|
try:
|
||||||
protocol.handle_messages(self.conntype, self)
|
protocol.handle_messages(self.conntype, self)
|
||||||
except protocol.ConnectionTypeChange:
|
except protocol.ConnectionTypeChange:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# FIXME: Do we want to persist errors?
|
# FIXME: Do we want to persist errors?
|
||||||
except (ProxyError, tcp.NetLibError), e:
|
except (ProxyError, tcp.NetLibError), e:
|
||||||
protocol.handle_error(self.conntype, self, e)
|
protocol.handle_error(self.conntype, self, e)
|
||||||
except Exception, e:
|
except Exception, e:
|
||||||
self.log(e.__class__)
|
self.log(e.__class__)
|
||||||
import traceback
|
import traceback
|
||||||
|
@ -250,7 +254,7 @@ class ConnectionHandler:
|
||||||
A protocol handler must raise a ConnTypeChanged exception if it detects that this is happening
|
A protocol handler must raise a ConnTypeChanged exception if it detects that this is happening
|
||||||
"""
|
"""
|
||||||
# TODO: Implement SSL pass-through handling and change conntype
|
# TODO: Implement SSL pass-through handling and change conntype
|
||||||
if self.server_conn.address.host == "ycombinator.com":
|
if self.server_conn.address.host == "news.ycombinator.com":
|
||||||
self.conntype = "tcp"
|
self.conntype = "tcp"
|
||||||
|
|
||||||
if server:
|
if server:
|
||||||
|
@ -265,8 +269,8 @@ class ConnectionHandler:
|
||||||
handle_sni=self.handle_sni)
|
handle_sni=self.handle_sni)
|
||||||
|
|
||||||
def server_reconnect(self, no_ssl=False):
|
def server_reconnect(self, no_ssl=False):
|
||||||
self.log("server reconnect")
|
|
||||||
had_ssl, sni = self.server_conn.ssl_established, self.sni
|
had_ssl, sni = self.server_conn.ssl_established, self.sni
|
||||||
|
self.log("server reconnect (ssl: %s, sni: %s)" % (had_ssl, sni))
|
||||||
self.establish_server_connection(self.server_conn.address)
|
self.establish_server_connection(self.server_conn.address)
|
||||||
if had_ssl and not no_ssl:
|
if had_ssl and not no_ssl:
|
||||||
self.sni = sni
|
self.sni = sni
|
||||||
|
|
Loading…
Reference in New Issue