unify ipv4/ipv6 address handling
This commit is contained in:
parent
94e530ec4f
commit
17f09aa0af
|
@ -5,9 +5,9 @@ python:
|
|||
install:
|
||||
- "pip install coveralls --use-mirrors"
|
||||
- "pip install nose-cov --use-mirrors"
|
||||
- "pip install --upgrade git+https://github.com/mitmproxy/netlib.git"
|
||||
- "pip install --upgrade git+https://github.com/mitmproxy/netlib.git@tcp_proxy"
|
||||
- "pip install -r requirements.txt --use-mirrors"
|
||||
- "pip install --upgrade git+https://github.com/mitmproxy/pathod.git"
|
||||
- "pip install --upgrade git+https://github.com/mitmproxy/pathod.git@tcp_proxy"
|
||||
# command to run tests, e.g. python setup.py test
|
||||
script:
|
||||
- "nosetests --with-cov --cov-report term-missing"
|
||||
|
|
|
@ -115,7 +115,7 @@ class HTTPResponse(HTTPMessage):
|
|||
|
||||
class HTTPRequest(HTTPMessage):
|
||||
def __init__(self, form_in, method, scheme, host, port, path, httpversion, headers, content,
|
||||
timestamp_start, timestamp_end, form_out=None, ip=None):
|
||||
timestamp_start, timestamp_end, form_out=None):
|
||||
self.form_in = form_in
|
||||
self.method = method
|
||||
self.scheme = scheme
|
||||
|
@ -129,7 +129,6 @@ class HTTPRequest(HTTPMessage):
|
|||
self.timestamp_end = timestamp_end
|
||||
|
||||
self.form_out = form_out or self.form_in
|
||||
self.ip = ip # resolved ip address
|
||||
assert isinstance(headers, ODictCaseless)
|
||||
|
||||
#FIXME: Compatibility Fix
|
||||
|
@ -352,7 +351,7 @@ class HTTPHandler(ProtocolHandler):
|
|||
if request.form_in == "authority":
|
||||
directly_addressed_at_mitmproxy = (self.c.mode == "regular") and not self.c.config.forward_proxy
|
||||
if directly_addressed_at_mitmproxy:
|
||||
self.c.establish_server_connection(request.host, request.port)
|
||||
self.c.establish_server_connection((request.host, request.port))
|
||||
self.c.client_conn.wfile.write(
|
||||
'HTTP/1.1 200 Connection established\r\n' +
|
||||
('Proxy-agent: %s\r\n' % self.c.server_version) +
|
||||
|
@ -369,7 +368,7 @@ class HTTPHandler(ProtocolHandler):
|
|||
request.form_out = "origin"
|
||||
if ((not self.c.server_conn) or
|
||||
(self.c.server_conn.address != (request.host, request.port))):
|
||||
self.c.establish_server_connection(request.host, request.port)
|
||||
self.c.establish_server_connection((request.host, request.port))
|
||||
else:
|
||||
raise http.HttpError(400, "Invalid Request")
|
||||
|
||||
|
|
|
@ -40,18 +40,13 @@ class ProxyConfig:
|
|||
|
||||
|
||||
class ClientConnection(tcp.BaseHandler):
|
||||
def __init__(self, client_connection, host, port):
|
||||
tcp.BaseHandler.__init__(self, client_connection)
|
||||
self.host, self.port = host, port
|
||||
def __init__(self, client_connection, address):
|
||||
tcp.BaseHandler.__init__(self, client_connection, address)
|
||||
|
||||
self.timestamp_start = utils.timestamp()
|
||||
self.timestamp_end = None
|
||||
self.timestamp_ssl_setup = None
|
||||
|
||||
@property
|
||||
def address(self):
|
||||
return self.host, self.port
|
||||
|
||||
def convert_to_ssl(self, *args, **kwargs):
|
||||
tcp.BaseHandler.convert_to_ssl(self, *args, **kwargs)
|
||||
self.timestamp_ssl_setup = utils.timestamp()
|
||||
|
@ -62,21 +57,19 @@ class ClientConnection(tcp.BaseHandler):
|
|||
|
||||
|
||||
class ServerConnection(tcp.TCPClient):
|
||||
def __init__(self, host, port):
|
||||
tcp.TCPClient.__init__(self, host, port)
|
||||
def __init__(self, address):
|
||||
tcp.TCPClient.__init__(self, address)
|
||||
|
||||
self.peername = None
|
||||
self.timestamp_start = None
|
||||
self.timestamp_end = None
|
||||
self.timestamp_tcp_setup = None
|
||||
self.timestamp_ssl_setup = None
|
||||
|
||||
@property
|
||||
def address(self):
|
||||
return self.host, self.port
|
||||
|
||||
def connect(self):
|
||||
self.timestamp_start = utils.timestamp()
|
||||
tcp.TCPClient.connect(self)
|
||||
self.peername = self.connection.getpeername()
|
||||
self.timestamp_tcp_setup = utils.timestamp()
|
||||
|
||||
def establish_ssl(self, clientcerts, sni):
|
||||
|
@ -125,7 +118,7 @@ class RequestReplayThread(threading.Thread):
|
|||
class ConnectionHandler:
|
||||
def __init__(self, config, client_connection, client_address, server, channel, server_version):
|
||||
self.config = config
|
||||
self.client_conn = ClientConnection(client_connection, *client_address)
|
||||
self.client_conn = ClientConnection(client_connection, client_address)
|
||||
self.server_conn = None
|
||||
self.channel, self.server_version = channel, server_version
|
||||
|
||||
|
@ -140,9 +133,9 @@ class ConnectionHandler:
|
|||
self.mode = "transparent"
|
||||
|
||||
def del_server_connection(self):
|
||||
if self.server_conn and self.server_conn.connection:
|
||||
if self.server_conn and self.server_conn.connection:
|
||||
self.server_conn.finish()
|
||||
self.log("serverdisconnect", ["%s:%s" % (self.server_conn.host, self.server_conn.port)])
|
||||
self.log("serverdisconnect", ["%s:%s" % self.server_conn.address])
|
||||
self.channel.tell("serverdisconnect", self)
|
||||
self.server_conn = None
|
||||
self.sni = None
|
||||
|
@ -169,7 +162,7 @@ class ConnectionHandler:
|
|||
self.determine_conntype()
|
||||
|
||||
if server_address:
|
||||
self.establish_server_connection(*server_address)
|
||||
self.establish_server_connection(server_address)
|
||||
self._handle_ssl()
|
||||
|
||||
while not self.close:
|
||||
|
@ -191,7 +184,7 @@ class ConnectionHandler:
|
|||
Check if we can already identify SSL connections.
|
||||
"""
|
||||
if self.config.transparent_proxy:
|
||||
client_ssl = server_ssl = (self.server_conn.port in self.config.transparent_proxy["sslports"])
|
||||
client_ssl = server_ssl = (self.server_conn.address.port in self.config.transparent_proxy["sslports"])
|
||||
elif self.config.reverse_proxy:
|
||||
client_ssl = server_ssl = (self.config.reverse_proxy[0] == "https")
|
||||
# TODO: Make protocol generic (as with transparent proxies)
|
||||
|
@ -205,18 +198,18 @@ class ConnectionHandler:
|
|||
#TODO: Add ruleset to select correct protocol depending on mode/target port etc.
|
||||
self.conntype = "http"
|
||||
|
||||
def establish_server_connection(self, host, port):
|
||||
def establish_server_connection(self, address):
|
||||
"""
|
||||
Establishes a new server connection to the given server
|
||||
If there is already an existing server connection, it will be killed.
|
||||
"""
|
||||
self.del_server_connection()
|
||||
self.server_conn = ServerConnection(host, port)
|
||||
self.server_conn = ServerConnection(address)
|
||||
try:
|
||||
self.server_conn.connect()
|
||||
except tcp.NetLibError, v:
|
||||
raise ProxyError(502, v)
|
||||
self.log("serverconnect", ["%s:%s" % (host, port)])
|
||||
self.log("serverconnect", ["%s:%s" % address])
|
||||
self.channel.tell("serverconnect", self)
|
||||
|
||||
def establish_ssl(self, client=False, server=False):
|
||||
|
@ -227,7 +220,7 @@ class ConnectionHandler:
|
|||
A protocol handler must raise a ConnTypeChanged exception if it detects that this is happening
|
||||
"""
|
||||
# TODO: Implement SSL pass-through handling and change conntype
|
||||
if self.server_conn.host == "ycombinator.com":
|
||||
if self.server_conn.address.host == "ycombinator.com":
|
||||
self.conntype = "tcp"
|
||||
|
||||
if server:
|
||||
|
@ -244,14 +237,14 @@ class ConnectionHandler:
|
|||
def server_reconnect(self, no_ssl=False):
|
||||
self.log("server reconnect")
|
||||
had_ssl, sni = self.server_conn.ssl_established, self.sni
|
||||
self.establish_server_connection(*self.server_conn.address)
|
||||
self.establish_server_connection(self.server_conn.address)
|
||||
if had_ssl and not no_ssl:
|
||||
self.sni = sni
|
||||
self.establish_ssl(server=True)
|
||||
|
||||
def log(self, msg, subs=()):
|
||||
msg = [
|
||||
"%s:%s: %s" % (self.client_conn.host, self.client_conn.port, msg)
|
||||
"%s:%s: %s" % (self.client_conn.address.host, self.client_conn.address.port, msg)
|
||||
]
|
||||
for i in subs:
|
||||
msg.append(" -> " + i)
|
||||
|
@ -263,7 +256,7 @@ class ConnectionHandler:
|
|||
with open(self.config.certfile, "rb") as f:
|
||||
return certutils.SSLCert.from_pem(f.read())
|
||||
else:
|
||||
host = self.server_conn.host
|
||||
host = self.server_conn.address.host
|
||||
sans = []
|
||||
if not self.config.no_upstream_cert or not self.server_conn.ssl_established:
|
||||
upstream_cert = self.server_conn.cert
|
||||
|
@ -307,14 +300,14 @@ class ProxyServer(tcp.TCPServer):
|
|||
allow_reuse_address = True
|
||||
bound = True
|
||||
|
||||
def __init__(self, config, port, address='', server_version=version.NAMEVERSION):
|
||||
def __init__(self, config, port, host='', server_version=version.NAMEVERSION):
|
||||
"""
|
||||
Raises ProxyServerError if there's a startup problem.
|
||||
"""
|
||||
self.config, self.port, self.address = config, port, address
|
||||
self.config = config
|
||||
self.server_version = server_version
|
||||
try:
|
||||
tcp.TCPServer.__init__(self, (address, port))
|
||||
tcp.TCPServer.__init__(self, (host, port))
|
||||
except socket.error, v:
|
||||
raise ProxyServerError('Error starting proxy server: ' + v.strerror)
|
||||
self.channel = None
|
||||
|
|
|
@ -46,7 +46,7 @@ class CommonMixin:
|
|||
assert l.response.code == 304
|
||||
|
||||
def test_invalid_http(self):
|
||||
t = tcp.TCPClient("127.0.0.1", self.proxy.port)
|
||||
t = tcp.TCPClient(("127.0.0.1", self.proxy.address.port))
|
||||
t.connect()
|
||||
t.wfile.write("invalid\r\n\r\n")
|
||||
t.wfile.flush()
|
||||
|
@ -70,7 +70,7 @@ class TestHTTP(tservers.HTTPProxTest, CommonMixin, AppMixin):
|
|||
assert "ValueError" in ret.content
|
||||
|
||||
def test_invalid_connect(self):
|
||||
t = tcp.TCPClient("127.0.0.1", self.proxy.port)
|
||||
t = tcp.TCPClient(("127.0.0.1", self.proxy.address.port))
|
||||
t.connect()
|
||||
t.wfile.write("CONNECT invalid\n\n")
|
||||
t.wfile.flush()
|
||||
|
|
Loading…
Reference in New Issue