diff --git a/mitmproxy/models/connections.py b/mitmproxy/models/connections.py index 3e1a09283..570e89a9a 100644 --- a/mitmproxy/models/connections.py +++ b/mitmproxy/models/connections.py @@ -8,7 +8,6 @@ import six from mitmproxy import stateobject from netlib import certutils -from netlib import strutils from netlib import tcp @@ -162,7 +161,7 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject): source_address=tcp.Address, ssl_established=bool, cert=certutils.SSLCert, - sni=bytes, + sni=str, timestamp_start=float, timestamp_tcp_setup=float, timestamp_ssl_setup=float, @@ -206,6 +205,8 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject): self.wfile.flush() def establish_ssl(self, clientcerts, sni, **kwargs): + if sni and not isinstance(sni, six.string_types): + raise ValueError("sni must be str, not " + type(sni).__name__) clientcert = None if clientcerts: if os.path.isfile(clientcerts): @@ -217,7 +218,7 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject): if os.path.exists(path): clientcert = path - self.convert_to_ssl(cert=clientcert, sni=strutils.always_bytes(sni), **kwargs) + self.convert_to_ssl(cert=clientcert, sni=sni, **kwargs) self.sni = sni self.timestamp_ssl_setup = time.time() diff --git a/mitmproxy/models/flow.py b/mitmproxy/models/flow.py index 0e4f80cb4..f4993b7a2 100644 --- a/mitmproxy/models/flow.py +++ b/mitmproxy/models/flow.py @@ -9,6 +9,7 @@ from mitmproxy.models.connections import ClientConnection from mitmproxy.models.connections import ServerConnection from netlib import version +from typing import Optional # noqa class Error(stateobject.StateObject): @@ -70,18 +71,13 @@ class Flow(stateobject.StateObject): def __init__(self, type, client_conn, server_conn, live=None): self.type = type self.id = str(uuid.uuid4()) - self.client_conn = client_conn - """@type: ClientConnection""" - self.server_conn = server_conn - """@type: ServerConnection""" + self.client_conn = client_conn # type: ClientConnection + self.server_conn = server_conn # type: ServerConnection self.live = live - """@type: LiveConnection""" - self.error = None - """@type: Error""" - self.intercepted = False - """@type: bool""" - self._backup = None + self.error = None # type: Error + self.intercepted = False # type: bool + self._backup = None # type: Optional[Flow] self.reply = None _stateobject_attributes = dict( diff --git a/mitmproxy/protocol/tls.py b/mitmproxy/protocol/tls.py index 9f883b2b0..8ef344930 100644 --- a/mitmproxy/protocol/tls.py +++ b/mitmproxy/protocol/tls.py @@ -10,6 +10,7 @@ import netlib.exceptions from mitmproxy import exceptions from mitmproxy.contrib.tls import _constructs from mitmproxy.protocol import base +from netlib import utils # taken from https://testssl.sh/openssl-rfc.mappping.html @@ -274,10 +275,11 @@ class TlsClientHello(object): is_valid_sni_extension = ( extension.type == 0x00 and len(extension.server_names) == 1 and - extension.server_names[0].type == 0 + extension.server_names[0].type == 0 and + utils.is_valid_host(extension.server_names[0].name) ) if is_valid_sni_extension: - return extension.server_names[0].name + return extension.server_names[0].name.decode("idna") @property def alpn_protocols(self): @@ -403,13 +405,14 @@ class TlsLayer(base.Layer): self._establish_tls_with_server() def set_server_tls(self, server_tls, sni=None): + # type: (bool, Union[six.text_type, None, False]) -> None """ Set the TLS settings for the next server connection that will be established. This function will not alter an existing connection. Args: server_tls: Shall we establish TLS with the server? - sni: ``bytes`` for a custom SNI value, + sni: ``str`` for a custom SNI value, ``None`` for the client SNI value, ``False`` if no SNI value should be sent. """ @@ -602,9 +605,9 @@ class TlsLayer(base.Layer): host = upstream_cert.cn.decode("utf8").encode("idna") # Also add SNI values. if self._client_hello.sni: - sans.add(self._client_hello.sni) + sans.add(self._client_hello.sni.encode("idna")) if self._custom_server_sni: - sans.add(self._custom_server_sni) + sans.add(self._custom_server_sni.encode("idna")) # RFC 2818: If a subjectAltName extension of type dNSName is present, that MUST be used as the identity. # In other words, the Common Name is irrelevant then. diff --git a/netlib/tcp.py b/netlib/tcp.py index 69dafc1fd..cf099eddc 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -676,7 +676,7 @@ class TCPClient(_Connection): self.connection = SSL.Connection(context, self.connection) if sni: self.sni = sni - self.connection.set_tlsext_host_name(sni) + self.connection.set_tlsext_host_name(sni.encode("idna")) self.connection.set_connect_state() try: self.connection.do_handshake() @@ -705,7 +705,7 @@ class TCPClient(_Connection): if self.cert.cn: crt["subject"] = [[["commonName", self.cert.cn.decode("ascii", "strict")]]] if sni: - hostname = sni.decode("ascii", "strict") + hostname = sni else: hostname = "no-hostname" ssl_match_hostname.match_hostname(crt, hostname) diff --git a/netlib/utils.py b/netlib/utils.py index 79340cbd6..23c16dc36 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -73,11 +73,9 @@ _label_valid = re.compile(b"(?!-)[A-Z\d-]{1,63}(? bool """ Checks if a hostname is valid. - - Args: - host (bytes): The hostname """ try: host.decode("idna") diff --git a/pathod/pathod.py b/pathod/pathod.py index 3df86aae4..7087cba6e 100644 --- a/pathod/pathod.py +++ b/pathod/pathod.py @@ -89,7 +89,10 @@ class PathodHandler(tcp.BaseHandler): self.http2_framedump = http2_framedump def handle_sni(self, connection): - self.sni = connection.get_servername() + sni = connection.get_servername() + if sni: + sni = sni.decode("idna") + self.sni = sni def http_serve_crafted(self, crafted, logctx): error, crafted = self.server.check_policy( diff --git a/test/mitmproxy/test_server.py b/test/mitmproxy/test_server.py index 1bbef975e..0ab7624e6 100644 --- a/test/mitmproxy/test_server.py +++ b/test/mitmproxy/test_server.py @@ -100,10 +100,10 @@ class CommonMixin: if not self.ssl: return - f = self.pathod("304", sni=b"testserver.com") + f = self.pathod("304", sni="testserver.com") assert f.status_code == 304 log = self.server.last_log() - assert log["request"]["sni"] == b"testserver.com" + assert log["request"]["sni"] == "testserver.com" class TcpMixin: @@ -498,7 +498,7 @@ class TestHttps2Http(tservers.ReverseProxyTest): assert p.request("get:'/p/200'").status_code == 200 def test_sni(self): - p = self.pathoc(ssl=True, sni=b"example.com") + p = self.pathoc(ssl=True, sni="example.com") assert p.request("get:'/p/200'").status_code == 200 assert all("Error in handle_sni" not in msg for msg in self.proxy.tlog) diff --git a/test/mitmproxy/tutils.py b/test/mitmproxy/tutils.py index 5aade60c5..d0a09035e 100644 --- a/test/mitmproxy/tutils.py +++ b/test/mitmproxy/tutils.py @@ -130,7 +130,7 @@ def tserver_conn(): timestamp_ssl_setup=3, timestamp_end=4, ssl_established=False, - sni=b"address", + sni="address", via=None )) c.reply = controller.DummyReply() diff --git a/test/netlib/test_tcp.py b/test/netlib/test_tcp.py index 590bcc01e..273427d51 100644 --- a/test/netlib/test_tcp.py +++ b/test/netlib/test_tcp.py @@ -169,7 +169,7 @@ class TestServerSSL(tservers.ServerTestBase): def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl(sni=b"foo.com", options=SSL.OP_ALL) + c.convert_to_ssl(sni="foo.com", options=SSL.OP_ALL) testval = b"echo!\n" c.wfile.write(testval) c.wfile.flush() @@ -179,7 +179,7 @@ class TestServerSSL(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): assert not c.get_current_cipher() - c.convert_to_ssl(sni=b"foo.com") + c.convert_to_ssl(sni="foo.com") ret = c.get_current_cipher() assert ret assert "AES" in ret[0] @@ -195,7 +195,7 @@ class TestSSLv3Only(tservers.ServerTestBase): def test_failure(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - tutils.raises(TlsException, c.convert_to_ssl, sni=b"foo.com") + tutils.raises(TlsException, c.convert_to_ssl, sni="foo.com") class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase): @@ -238,7 +238,7 @@ class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase): with c.connect(): with tutils.raises(InvalidCertificateException): c.convert_to_ssl( - sni=b"example.mitmproxy.org", + sni="example.mitmproxy.org", verify_options=SSL.VERIFY_PEER, ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted-root.crt") ) @@ -272,7 +272,7 @@ class TestSSLUpstreamCertVerificationWBadHostname(tservers.ServerTestBase): with c.connect(): with tutils.raises(InvalidCertificateException): c.convert_to_ssl( - sni=b"mitmproxy.org", + sni="mitmproxy.org", verify_options=SSL.VERIFY_PEER, ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted-root.crt") ) @@ -291,7 +291,7 @@ class TestSSLUpstreamCertVerificationWValidCertChain(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): c.convert_to_ssl( - sni=b"example.mitmproxy.org", + sni="example.mitmproxy.org", verify_options=SSL.VERIFY_PEER, ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted-root.crt") ) @@ -307,7 +307,7 @@ class TestSSLUpstreamCertVerificationWValidCertChain(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): c.convert_to_ssl( - sni=b"example.mitmproxy.org", + sni="example.mitmproxy.org", verify_options=SSL.VERIFY_PEER, ca_path=tutils.test_data.path("data/verificationcerts/") ) @@ -371,8 +371,8 @@ class TestSNI(tservers.ServerTestBase): def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl(sni=b"foo.com") - assert c.sni == b"foo.com" + c.convert_to_ssl(sni="foo.com") + assert c.sni == "foo.com" assert c.rfile.readline() == b"foo.com" @@ -385,7 +385,7 @@ class TestServerCipherList(tservers.ServerTestBase): def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl(sni=b"foo.com") + c.convert_to_ssl(sni="foo.com") assert c.rfile.readline() == b"['RC4-SHA']" @@ -405,7 +405,7 @@ class TestServerCurrentCipher(tservers.ServerTestBase): def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl(sni=b"foo.com") + c.convert_to_ssl(sni="foo.com") assert b"RC4-SHA" in c.rfile.readline() @@ -418,7 +418,7 @@ class TestServerCipherListError(tservers.ServerTestBase): def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - tutils.raises("handshake error", c.convert_to_ssl, sni=b"foo.com") + tutils.raises("handshake error", c.convert_to_ssl, sni="foo.com") class TestClientCipherListError(tservers.ServerTestBase): @@ -433,7 +433,7 @@ class TestClientCipherListError(tservers.ServerTestBase): tutils.raises( "cipher specification", c.convert_to_ssl, - sni=b"foo.com", + sni="foo.com", cipher_list="bogus" ) diff --git a/test/pathod/test_pathoc.py b/test/pathod/test_pathoc.py index 28f9f0f88..361a863bb 100644 --- a/test/pathod/test_pathoc.py +++ b/test/pathod/test_pathoc.py @@ -54,10 +54,10 @@ class TestDaemonSSL(PathocTestDaemon): def test_sni(self): self.tval( ["get:/p/200"], - sni=b"foobar.com" + sni="foobar.com" ) log = self.d.log() - assert log[0]["request"]["sni"] == b"foobar.com" + assert log[0]["request"]["sni"] == "foobar.com" def test_showssl(self): assert "certificate chain" in self.tval(["get:/p/200"], showssl=True)