add generic TCP handler with SSL support, move StateObject into netlib
This commit is contained in:
parent
179c3ae8aa
commit
8544a5ba4b
|
@ -8,9 +8,10 @@ import types
|
|||
import tnetstring, filt, script, utils, encoding, proxy
|
||||
from email.utils import parsedate_tz, formatdate, mktime_tz
|
||||
from netlib import odict, http, certutils, wsgi
|
||||
import controller, version
|
||||
import controller, version, protocol
|
||||
import app
|
||||
|
||||
|
||||
HDR_FORM_URLENCODED = "application/x-www-form-urlencoded"
|
||||
CONTENT_MISSING = 0
|
||||
|
||||
|
@ -144,86 +145,6 @@ class SetHeaders:
|
|||
f.request.headers.add(header, value)
|
||||
|
||||
|
||||
class StateObject:
|
||||
def _get_state(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def _load_state(self, state):
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def _from_state(cls, state):
|
||||
raise NotImplementedError
|
||||
|
||||
def __eq__(self, other):
|
||||
try:
|
||||
return self._get_state() == other._get_state()
|
||||
except AttributeError: # we may compare with something that's not a StateObject
|
||||
return False
|
||||
|
||||
|
||||
class SimpleStateObject(StateObject):
|
||||
"""
|
||||
A StateObject with opionated conventions that tries to keep everything DRY.
|
||||
|
||||
Simply put, you agree on a list of attributes and their type.
|
||||
Attributes can either be primitive types(str, tuple, bool, ...) or StateObject instances themselves.
|
||||
SimpleStateObject uses this information for the default _get_state(), _from_state(s) and _load_state(s) methods.
|
||||
Overriding _get_state or _load_state to add custom adjustments is always possible.
|
||||
"""
|
||||
|
||||
_stateobject_attributes = None # none by default to raise an exception if definition was forgotten
|
||||
"""
|
||||
An attribute-name -> class-or-type dict containing all attributes that should be serialized
|
||||
If the attribute is a class, this class must be a subclass of StateObject.
|
||||
"""
|
||||
|
||||
def _get_state(self):
|
||||
return {attr: self.__get_state_attr(attr, cls)
|
||||
for attr, cls in self._stateobject_attributes.iteritems()}
|
||||
|
||||
def __get_state_attr(self, attr, cls):
|
||||
"""
|
||||
helper for _get_state.
|
||||
returns the value of the given attribute
|
||||
"""
|
||||
if getattr(self, attr) is None:
|
||||
return None
|
||||
if isinstance(cls, types.ClassType):
|
||||
return getattr(self, attr)._get_state()
|
||||
else:
|
||||
return getattr(self, attr)
|
||||
|
||||
def _load_state(self, state):
|
||||
for attr, cls in self._stateobject_attributes.iteritems():
|
||||
self.__load_state_attr(attr, cls, state)
|
||||
|
||||
def __load_state_attr(self, attr, cls, state):
|
||||
"""
|
||||
helper for _load_state.
|
||||
loads the given attribute from the state.
|
||||
"""
|
||||
if state[attr] is not None: # First, catch None as value.
|
||||
if isinstance(cls, types.ClassType): # Is the attribute a StateObject itself?
|
||||
# FIXME: assertion doesn't hold because of odict at the moment
|
||||
# assert issubclass(cls, StateObject)
|
||||
curr = getattr(self, attr)
|
||||
if curr: # if the attribute is already present, delegate to the objects ._load_state method.
|
||||
curr._load_state(state[attr])
|
||||
else: # otherwise, create a new object.
|
||||
setattr(self, attr, cls._from_state(state[attr]))
|
||||
else:
|
||||
setattr(self, attr, cls(state[attr]))
|
||||
else:
|
||||
setattr(self, attr, None)
|
||||
|
||||
@classmethod
|
||||
def _from_state(cls, state):
|
||||
f = cls() # the default implementation assumes an empty constructor. Override accordingly.
|
||||
f._load_state(state)
|
||||
return f
|
||||
|
||||
|
||||
class ClientPlaybackState:
|
||||
def __init__(self, flows, exit):
|
||||
self.flows, self.exit = flows, exit
|
||||
|
@ -834,7 +755,7 @@ class FlowReader:
|
|||
v = ".".join(str(i) for i in data["version"])
|
||||
raise FlowReadError("Incompatible serialized data version: %s"%v)
|
||||
off = self.fo.tell()
|
||||
yield Flow._from_state(data)
|
||||
yield protocol.protocols[data["conntype"]]["flow"]._from_state(data)
|
||||
except ValueError, v:
|
||||
# Error is due to EOF
|
||||
if self.fo.tell() == off and self.fo.read() == '':
|
||||
|
|
|
@ -12,6 +12,7 @@ class ConnectionTypeChange(Exception):
|
|||
class ProtocolHandler(object):
|
||||
def __init__(self, c):
|
||||
self.c = c
|
||||
"""@type : libmproxy.proxy.ConnectionHandler"""
|
||||
|
||||
def handle_messages(self):
|
||||
"""
|
||||
|
@ -27,13 +28,17 @@ class ProtocolHandler(object):
|
|||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
from . import http, tcp
|
||||
|
||||
from .http import HTTPHandler
|
||||
protocols = dict(
|
||||
http = dict(handler=http.HTTPHandler, flow=http.HTTPFlow),
|
||||
tcp = dict(handler=tcp.TCPHandler),
|
||||
)
|
||||
|
||||
|
||||
def _handler(conntype, connection_handler):
|
||||
if conntype == "http":
|
||||
return HTTPHandler(connection_handler)
|
||||
if conntype in protocols:
|
||||
return protocols[conntype]["handler"](connection_handler)
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
|
@ -1,17 +1,11 @@
|
|||
import Cookie
|
||||
import Cookie, urllib, urlparse, time, copy
|
||||
from email.utils import parsedate_tz, formatdate, mktime_tz
|
||||
import urllib
|
||||
import urlparse
|
||||
import time
|
||||
import copy
|
||||
from ..flow import SimpleStateObject
|
||||
from netlib import http, tcp, http_status
|
||||
from netlib.odict import ODict, ODictCaseless
|
||||
import netlib.utils
|
||||
from .. import encoding, utils, version, filt, controller
|
||||
from ..proxy import ProxyError, ServerConnection, ClientConnection
|
||||
from netlib import http, tcp, http_status, stateobject, odict
|
||||
from netlib.odict import ODict, ODictCaseless
|
||||
from . import ProtocolHandler, ConnectionTypeChange, KILL
|
||||
import libmproxy.flow
|
||||
from .. import encoding, utils, version, filt, controller
|
||||
from ..proxy import ProxyError, ClientConnection, ServerConnection
|
||||
|
||||
HDR_FORM_URLENCODED = "application/x-www-form-urlencoded"
|
||||
CONTENT_MISSING = 0
|
||||
|
@ -57,7 +51,7 @@ class decoded(object):
|
|||
if self.ce:
|
||||
self.o.encode(self.ce)
|
||||
|
||||
|
||||
# FIXME: Move out of http
|
||||
class BackreferenceMixin(object):
|
||||
"""
|
||||
If an attribute from the _backrefattr tuple is set,
|
||||
|
@ -73,12 +67,10 @@ class BackreferenceMixin(object):
|
|||
def __setattr__(self, key, value):
|
||||
super(BackreferenceMixin, self).__setattr__(key, value)
|
||||
if key in self._backrefattr and value is not None:
|
||||
# check if there is already a different object set as backref
|
||||
assert (getattr(value, self._backrefname, self) or self) is self
|
||||
setattr(value, self._backrefname, self)
|
||||
|
||||
# FIXME: Move out of http
|
||||
class Error(SimpleStateObject):
|
||||
class Error(stateobject.SimpleStateObject):
|
||||
"""
|
||||
An Error.
|
||||
|
||||
|
@ -107,7 +99,7 @@ class Error(SimpleStateObject):
|
|||
return c
|
||||
|
||||
# FIXME: Move out of http
|
||||
class Flow(SimpleStateObject, BackreferenceMixin):
|
||||
class Flow(stateobject.SimpleStateObject, BackreferenceMixin):
|
||||
def __init__(self, conntype, client_conn, server_conn, error):
|
||||
self.conntype = conntype
|
||||
self.client_conn = client_conn
|
||||
|
@ -167,7 +159,7 @@ class Flow(SimpleStateObject, BackreferenceMixin):
|
|||
self._backup = None
|
||||
|
||||
|
||||
class HTTPMessage(SimpleStateObject):
|
||||
class HTTPMessage(stateobject.SimpleStateObject):
|
||||
def __init__(self):
|
||||
self.flow = None # Will usually set by backref mixin
|
||||
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
from . import ProtocolHandler
|
||||
import select, socket
|
||||
from cStringIO import StringIO
|
||||
|
||||
|
||||
class TCPHandler(ProtocolHandler):
|
||||
"""
|
||||
TCPHandler acts as a generic TCP forwarder.
|
||||
Data will be .log()ed, but not stored any further.
|
||||
"""
|
||||
def handle_messages(self):
|
||||
conns = [self.c.client_conn.rfile, self.c.server_conn.rfile]
|
||||
while not self.c.close:
|
||||
r, _, _ = select.select(conns, [], [], 10)
|
||||
for rfile in r:
|
||||
if self.c.client_conn.rfile == rfile:
|
||||
src, dst = self.c.client_conn, self.c.server_conn
|
||||
src_str, dst_str = "client", "server"
|
||||
else:
|
||||
dst, src = self.c.client_conn, self.c.server_conn
|
||||
dst_str, src_str = "client", "server"
|
||||
|
||||
data = StringIO()
|
||||
while range(4096):
|
||||
# Do non-blocking select() to see if there is further data on in the buffer.
|
||||
r, _, _ = select.select([rfile], [], [], 0)
|
||||
if len(r):
|
||||
d = rfile.read(1)
|
||||
if d == "": # connection closed
|
||||
break
|
||||
data.write(d)
|
||||
|
||||
"""
|
||||
OpenSSL Connections have an internal buffer that might contain data altough everything is read
|
||||
from the socket. Thankfully, connection.pending() returns the amount of bytes in this buffer,
|
||||
so we can read it completely at once.
|
||||
"""
|
||||
if src.ssl_established:
|
||||
data.write(rfile.read(src.connection.pending()))
|
||||
else: # no data left, but not closed yet
|
||||
break
|
||||
data = data.getvalue()
|
||||
|
||||
if data == "": # no data received, rfile is closed
|
||||
self.c.log("Close writing connection to %s" % dst_str)
|
||||
conns.remove(rfile)
|
||||
if dst.ssl_established:
|
||||
dst.connection.shutdown()
|
||||
else:
|
||||
dst.connection.shutdown(socket.SHUT_WR)
|
||||
if len(conns) == 0:
|
||||
self.c.close = True
|
||||
break
|
||||
|
||||
self.c.log("%s -> %s" % (src_str, dst_str), ["\r\n" + data])
|
||||
dst.wfile.write(data)
|
||||
dst.wfile.flush()
|
|
@ -1,7 +1,7 @@
|
|||
import os, socket, time, threading
|
||||
from OpenSSL import SSL
|
||||
from netlib import tcp, http, certutils, http_auth
|
||||
import utils, flow, version, platform, controller
|
||||
from netlib import tcp, http, certutils, http_auth, stateobject
|
||||
import utils, version, platform, controller
|
||||
|
||||
|
||||
TRANSPARENT_SSL_PORTS = [443, 8443]
|
||||
|
@ -34,7 +34,7 @@ class ProxyConfig:
|
|||
self.certstore = certutils.CertStore()
|
||||
|
||||
|
||||
class ClientConnection(tcp.BaseHandler, flow.SimpleStateObject):
|
||||
class ClientConnection(tcp.BaseHandler, stateobject.SimpleStateObject):
|
||||
def __init__(self, client_connection, address, server):
|
||||
tcp.BaseHandler.__init__(self, client_connection, address, server)
|
||||
|
||||
|
@ -46,7 +46,8 @@ class ClientConnection(tcp.BaseHandler, flow.SimpleStateObject):
|
|||
timestamp_start=float,
|
||||
timestamp_end=float,
|
||||
timestamp_ssl_setup=float,
|
||||
# FIXME: Add missing attributes
|
||||
address=tcp.Address,
|
||||
clientcert=certutils.SSLCert
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
@ -62,7 +63,7 @@ class ClientConnection(tcp.BaseHandler, flow.SimpleStateObject):
|
|||
self.timestamp_end = utils.timestamp()
|
||||
|
||||
|
||||
class ServerConnection(tcp.TCPClient, flow.SimpleStateObject):
|
||||
class ServerConnection(tcp.TCPClient, stateobject.SimpleStateObject):
|
||||
def __init__(self, address):
|
||||
tcp.TCPClient.__init__(self, address)
|
||||
|
||||
|
@ -78,12 +79,14 @@ class ServerConnection(tcp.TCPClient, flow.SimpleStateObject):
|
|||
timestamp_end=float,
|
||||
timestamp_tcp_setup=float,
|
||||
timestamp_ssl_setup=float,
|
||||
# FIXME: Add missing attributes
|
||||
address=tcp.Address,
|
||||
source_address=tcp.Address,
|
||||
cert=certutils.SSLCert
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_state(cls, state):
|
||||
raise NotImplementedError # FIXME
|
||||
raise NotImplementedError # FIXME
|
||||
|
||||
def connect(self):
|
||||
self.timestamp_start = utils.timestamp()
|
||||
|
@ -172,33 +175,34 @@ class ConnectionHandler:
|
|||
self.determine_conntype()
|
||||
|
||||
try:
|
||||
# Can we already identify the target server and connect to it?
|
||||
server_address = None
|
||||
if self.config.forward_proxy:
|
||||
server_address = self.config.forward_proxy[1:]
|
||||
else:
|
||||
if self.config.reverse_proxy:
|
||||
server_address = self.config.reverse_proxy[1:]
|
||||
elif self.config.transparent_proxy:
|
||||
server_address = self.config.transparent_proxy["resolver"].original_addr(
|
||||
self.client_conn.connection)
|
||||
if not server_address:
|
||||
raise ProxyError(502, "Transparent mode failure: could not resolve original destination.")
|
||||
self.log("transparent to %s:%s" % server_address)
|
||||
try:
|
||||
# Can we already identify the target server and connect to it?
|
||||
server_address = None
|
||||
if self.config.forward_proxy:
|
||||
server_address = self.config.forward_proxy[1:]
|
||||
else:
|
||||
if self.config.reverse_proxy:
|
||||
server_address = self.config.reverse_proxy[1:]
|
||||
elif self.config.transparent_proxy:
|
||||
server_address = self.config.transparent_proxy["resolver"].original_addr(
|
||||
self.client_conn.connection)
|
||||
if not server_address:
|
||||
raise ProxyError(502, "Transparent mode failure: could not resolve original destination.")
|
||||
self.log("transparent to %s:%s" % server_address)
|
||||
|
||||
if server_address:
|
||||
self.establish_server_connection(server_address)
|
||||
self._handle_ssl()
|
||||
if server_address:
|
||||
self.establish_server_connection(server_address)
|
||||
self._handle_ssl()
|
||||
|
||||
while not self.close:
|
||||
try:
|
||||
protocol.handle_messages(self.conntype, self)
|
||||
except protocol.ConnectionTypeChange:
|
||||
continue
|
||||
while not self.close:
|
||||
try:
|
||||
protocol.handle_messages(self.conntype, self)
|
||||
except protocol.ConnectionTypeChange:
|
||||
continue
|
||||
|
||||
# FIXME: Do we want to persist errors?
|
||||
except (ProxyError, tcp.NetLibError), e:
|
||||
protocol.handle_error(self.conntype, self, e)
|
||||
# FIXME: Do we want to persist errors?
|
||||
except (ProxyError, tcp.NetLibError), e:
|
||||
protocol.handle_error(self.conntype, self, e)
|
||||
except Exception, e:
|
||||
self.log(e.__class__)
|
||||
import traceback
|
||||
|
@ -250,7 +254,7 @@ class ConnectionHandler:
|
|||
A protocol handler must raise a ConnTypeChanged exception if it detects that this is happening
|
||||
"""
|
||||
# TODO: Implement SSL pass-through handling and change conntype
|
||||
if self.server_conn.address.host == "ycombinator.com":
|
||||
if self.server_conn.address.host == "news.ycombinator.com":
|
||||
self.conntype = "tcp"
|
||||
|
||||
if server:
|
||||
|
@ -265,8 +269,8 @@ class ConnectionHandler:
|
|||
handle_sni=self.handle_sni)
|
||||
|
||||
def server_reconnect(self, no_ssl=False):
|
||||
self.log("server reconnect")
|
||||
had_ssl, sni = self.server_conn.ssl_established, self.sni
|
||||
self.log("server reconnect (ssl: %s, sni: %s)" % (had_ssl, sni))
|
||||
self.establish_server_connection(self.server_conn.address)
|
||||
if had_ssl and not no_ssl:
|
||||
self.sni = sni
|
||||
|
|
Loading…
Reference in New Issue