diff --git a/mitmproxy/connections.py b/mitmproxy/connections.py index 6d7c3c769..9359b67db 100644 --- a/mitmproxy/connections.py +++ b/mitmproxy/connections.py @@ -54,14 +54,20 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject): return bool(self.connection) and not self.finished def __repr__(self): + if self.ssl_established: + tls = "[{}] ".format(self.tls_version) + else: + tls = "" + if self.alpn_proto_negotiated: alpn = "[ALPN: {}] ".format( strutils.bytes_to_escaped_str(self.alpn_proto_negotiated) ) else: alpn = "" - return "".format( - ssl="[ssl] " if self.ssl_established else "", + + return "".format( + tls=tls, alpn=alpn, host=self.address[0], port=self.address[1], @@ -71,6 +77,10 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject): def tls_established(self): return self.ssl_established + @tls_established.setter + def tls_established(self, value): + self.ssl_established = value + _stateobject_attributes = dict( address=tuple, ssl_established=bool, @@ -100,7 +110,7 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject): @classmethod def make_dummy(cls, address): return cls.from_state(dict( - address=dict(address=address, use_ipv6=False), + address=address, clientcert=None, mitmcert=None, ssl_established=False, @@ -144,6 +154,7 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject): cert: The certificate presented by the remote during the TLS handshake sni: Server Name Indication sent by the proxy during the TLS handshake alpn_proto_negotiated: The negotiated application protocol + tls_version: TLS version via: The underlying server connection (e.g. the connection to the upstream proxy in upstream proxy mode) timestamp_start: Connection start timestamp timestamp_tcp_setup: TCP ACK received timestamp @@ -155,6 +166,7 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject): tcp.TCPClient.__init__(self, address, source_address, spoof_source_address) self.alpn_proto_negotiated = None + self.tls_version = None self.via = None self.timestamp_start = None self.timestamp_end = None @@ -166,19 +178,19 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject): def __repr__(self): if self.ssl_established and self.sni: - ssl = "[ssl: {0}] ".format(self.sni) + tls = "[{}: {}] ".format(self.tls_version or "TLS", self.sni) elif self.ssl_established: - ssl = "[ssl] " + tls = "[{}] ".format(self.tls_version or "TLS") else: - ssl = "" + tls = "" if self.alpn_proto_negotiated: alpn = "[ALPN: {}] ".format( strutils.bytes_to_escaped_str(self.alpn_proto_negotiated) ) else: alpn = "" - return "".format( - ssl=ssl, + return "".format( + tls=tls, alpn=alpn, host=self.address[0], port=self.address[1], @@ -188,6 +200,10 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject): def tls_established(self): return self.ssl_established + @tls_established.setter + def tls_established(self, value): + self.ssl_established = value + _stateobject_attributes = dict( address=tuple, ip_address=tuple, @@ -196,6 +212,7 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject): cert=certs.SSLCert, sni=str, alpn_proto_negotiated=bytes, + tls_version=str, timestamp_start=float, timestamp_tcp_setup=float, timestamp_ssl_setup=float, @@ -211,12 +228,13 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject): @classmethod def make_dummy(cls, address): return cls.from_state(dict( - address=dict(address=address, use_ipv6=False), - ip_address=dict(address=address, use_ipv6=False), + address=address, + ip_address=address, cert=None, sni=None, alpn_proto_negotiated=None, - source_address=dict(address=('', 0), use_ipv6=False), + tls_version=None, + source_address=('', 0), ssl_established=False, timestamp_start=None, timestamp_tcp_setup=None, @@ -253,6 +271,7 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject): self.convert_to_ssl(cert=clientcert, sni=sni, **kwargs) self.sni = sni self.alpn_proto_negotiated = self.get_alpn_proto_negotiated() + self.tls_version = self.connection.get_protocol_version_name() self.timestamp_ssl_setup = time.time() def finish(self): diff --git a/mitmproxy/io_compat.py b/mitmproxy/io_compat.py index 4c840da57..16cbc9fea 100644 --- a/mitmproxy/io_compat.py +++ b/mitmproxy/io_compat.py @@ -99,6 +99,9 @@ def convert_100_200(data): def convert_200_300(data): data["version"] = (3, 0, 0) data["client_conn"]["mitmcert"] = None + data["server_conn"]["tls_version"] = None + if data["server_conn"]["via"]: + data["server_conn"]["via"]["tls_version"] = None return data diff --git a/mitmproxy/test/tflow.py b/mitmproxy/test/tflow.py index f30d8b6f1..fd665055e 100644 --- a/mitmproxy/test/tflow.py +++ b/mitmproxy/test/tflow.py @@ -1,3 +1,5 @@ +import io + from mitmproxy.net import websockets from mitmproxy.test import tutils from mitmproxy import tcp @@ -156,6 +158,8 @@ def tclient_conn(): tls_version="TLSv1.2", )) c.reply = controller.DummyReply() + c.rfile = io.BytesIO() + c.wfile = io.BytesIO() return c @@ -175,9 +179,12 @@ def tserver_conn(): ssl_established=False, sni="address", alpn_proto_negotiated=None, + tls_version=None, via=None, )) c.reply = controller.DummyReply() + c.rfile = io.BytesIO() + c.wfile = io.BytesIO() return c diff --git a/setup.cfg b/setup.cfg index 1825f4341..79a873180 100644 --- a/setup.cfg +++ b/setup.cfg @@ -35,7 +35,6 @@ exclude = mitmproxy/proxy/server.py mitmproxy/tools/ mitmproxy/certs.py - mitmproxy/connections.py mitmproxy/controller.py mitmproxy/export.py mitmproxy/flow.py @@ -52,7 +51,6 @@ exclude = mitmproxy/addons/onboardingapp/app.py mitmproxy/addons/termlog.py mitmproxy/certs.py - mitmproxy/connections.py mitmproxy/contentviews/base.py mitmproxy/contentviews/wbxml.py mitmproxy/contentviews/xml_html.py diff --git a/test/mitmproxy/test_connections.py b/test/mitmproxy/test_connections.py index fa23a53c0..0083f57cc 100644 --- a/test/mitmproxy/test_connections.py +++ b/test/mitmproxy/test_connections.py @@ -1,13 +1,57 @@ +import socket +import os +import threading +import ssl +import OpenSSL +import pytest from unittest import mock from mitmproxy import connections from mitmproxy import exceptions +from mitmproxy.net import tcp from mitmproxy.net.http import http1 from mitmproxy.test import tflow +from mitmproxy.test import tutils +from .net import tservers from pathod import test class TestClientConnection: + + def test_send(self): + c = tflow.tclient_conn() + c.send(b'foobar') + c.send([b'foo', b'bar']) + with pytest.raises(TypeError): + c.send('string') + with pytest.raises(TypeError): + c.send(['string', 'not']) + assert c.wfile.getvalue() == b'foobarfoobar' + + def test_repr(self): + c = tflow.tclient_conn() + assert 'address:22' in repr(c) + assert 'ALPN' in repr(c) + assert 'TLS' not in repr(c) + + c.alpn_proto_negotiated = None + c.tls_established = True + assert 'ALPN' not in repr(c) + assert 'TLS' in repr(c) + + def test_tls_established_property(self): + c = tflow.tclient_conn() + c.tls_established = True + assert c.ssl_established + assert c.tls_established + c.tls_established = False + assert not c.ssl_established + assert not c.tls_established + + def test_make_dummy(self): + c = connections.ClientConnection.make_dummy(('foobar', 1234)) + assert c.address == ('foobar', 1234) + def test_state(self): c = tflow.tclient_conn() assert connections.ClientConnection.from_state(c.get_state()).get_state() == \ @@ -24,44 +68,143 @@ class TestClientConnection: c3 = c.copy() assert c3.get_state() == c.get_state() - assert str(c) - class TestServerConnection: + def test_send(self): + c = tflow.tserver_conn() + c.send(b'foobar') + c.send([b'foo', b'bar']) + with pytest.raises(TypeError): + c.send('string') + with pytest.raises(TypeError): + c.send(['string', 'not']) + assert c.wfile.getvalue() == b'foobarfoobar' + + def test_repr(self): + c = tflow.tserver_conn() + + c.sni = 'foobar' + c.tls_established = True + c.alpn_proto_negotiated = b'h2' + assert 'address:22' in repr(c) + assert 'ALPN' in repr(c) + assert 'TLS: foobar' in repr(c) + + c.sni = None + c.tls_established = True + c.alpn_proto_negotiated = None + assert 'ALPN' not in repr(c) + assert 'TLS' in repr(c) + + c.sni = None + c.tls_established = False + assert 'TLS' not in repr(c) + + def test_tls_established_property(self): + c = tflow.tserver_conn() + c.tls_established = True + assert c.ssl_established + assert c.tls_established + c.tls_established = False + assert not c.ssl_established + assert not c.tls_established + + def test_make_dummy(self): + c = connections.ServerConnection.make_dummy(('foobar', 1234)) + assert c.address == ('foobar', 1234) + def test_simple(self): - self.d = test.Daemon() - sc = connections.ServerConnection((self.d.IFACE, self.d.port)) - sc.connect() + d = test.Daemon() + c = connections.ServerConnection((d.IFACE, d.port)) + c.connect() f = tflow.tflow() - f.server_conn = sc + f.server_conn = c f.request.path = "/p/200:da" # use this protocol just to assemble - not for actual sending - sc.wfile.write(http1.assemble_request(f.request)) - sc.wfile.flush() + c.wfile.write(http1.assemble_request(f.request)) + c.wfile.flush() - assert http1.read_response(sc.rfile, f.request, 1000) - assert self.d.last_log() + assert http1.read_response(c.rfile, f.request, 1000) + assert d.last_log() - sc.finish() - self.d.shutdown() + c.finish() + d.shutdown() def test_terminate_error(self): - self.d = test.Daemon() - sc = connections.ServerConnection((self.d.IFACE, self.d.port)) - sc.connect() - sc.connection = mock.Mock() - sc.connection.recv = mock.Mock(return_value=False) - sc.connection.flush = mock.Mock(side_effect=exceptions.TcpDisconnect) - sc.finish() - self.d.shutdown() + d = test.Daemon() + c = connections.ServerConnection((d.IFACE, d.port)) + c.connect() + c.connection = mock.Mock() + c.connection.recv = mock.Mock(return_value=False) + c.connection.flush = mock.Mock(side_effect=exceptions.TcpDisconnect) + c.finish() + d.shutdown() - def test_repr(self): - sc = tflow.tserver_conn() - assert "address:22" in repr(sc) - assert "ssl" not in repr(sc) - sc.ssl_established = True - assert "ssl" in repr(sc) - sc.sni = "foo" - assert "foo" in repr(sc) + def test_sni(self): + c = connections.ServerConnection(('', 1234)) + with pytest.raises(ValueError, matches='sni must be str, not '): + c.establish_ssl(None, b'foobar') + + +class TestClientConnectionTLS: + + @pytest.mark.parametrize("sni", [ + None, + "example.com" + ]) + def test_tls_with_sni(self, sni): + address = ('127.0.0.1', 0) + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(address) + sock.listen() + address = sock.getsockname() + + def client_run(): + ctx = ssl.create_default_context() + ctx.check_hostname = False + ctx.verify_mode = ssl.CERT_NONE + s = socket.create_connection(address) + s = ctx.wrap_socket(s, server_hostname=sni) + s.send(b'foobar') + s.shutdown(socket.SHUT_RDWR) + threading.Thread(target=client_run).start() + + connection, client_address = sock.accept() + c = connections.ClientConnection(connection, client_address, None) + + cert = tutils.test_data.path("mitmproxy/net/data/server.crt") + key = OpenSSL.crypto.load_privatekey( + OpenSSL.crypto.FILETYPE_PEM, + open(tutils.test_data.path("mitmproxy/net/data/server.key"), "rb").read()) + c.convert_to_ssl(cert, key) + assert c.connected() + assert c.sni == sni + assert c.tls_established + assert c.rfile.read(6) == b'foobar' + c.finish() + + +class TestServerConnectionTLS(tservers.ServerTestBase): + ssl = True + + class handler(tcp.BaseHandler): + def handle(self): + self.finish() + + @pytest.mark.parametrize("clientcert", [ + None, + tutils.test_data.path("mitmproxy/data/clientcert"), + os.path.join(tutils.test_data.path("mitmproxy/data/clientcert"), "client.pem"), + ]) + def test_tls(self, clientcert): + c = connections.ServerConnection(("127.0.0.1", self.port)) + c.connect() + c.establish_ssl(clientcert, "foo.com") + assert c.connected() + assert c.sni == "foo.com" + assert c.tls_established + c.close() + c.finish()