connections: coverage++ (#2064)

This commit is contained in:
Thomas Kriechbaumer 2017-02-26 20:50:52 +01:00 committed by GitHub
parent b33d568e04
commit 9b6986ea87
5 changed files with 211 additions and 41 deletions

View File

@ -54,14 +54,20 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject):
return bool(self.connection) and not self.finished return bool(self.connection) and not self.finished
def __repr__(self): def __repr__(self):
if self.ssl_established:
tls = "[{}] ".format(self.tls_version)
else:
tls = ""
if self.alpn_proto_negotiated: if self.alpn_proto_negotiated:
alpn = "[ALPN: {}] ".format( alpn = "[ALPN: {}] ".format(
strutils.bytes_to_escaped_str(self.alpn_proto_negotiated) strutils.bytes_to_escaped_str(self.alpn_proto_negotiated)
) )
else: else:
alpn = "" alpn = ""
return "<ClientConnection: {ssl}{alpn}{host}:{port}>".format(
ssl="[ssl] " if self.ssl_established else "", return "<ClientConnection: {tls}{alpn}{host}:{port}>".format(
tls=tls,
alpn=alpn, alpn=alpn,
host=self.address[0], host=self.address[0],
port=self.address[1], port=self.address[1],
@ -71,6 +77,10 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject):
def tls_established(self): def tls_established(self):
return self.ssl_established return self.ssl_established
@tls_established.setter
def tls_established(self, value):
self.ssl_established = value
_stateobject_attributes = dict( _stateobject_attributes = dict(
address=tuple, address=tuple,
ssl_established=bool, ssl_established=bool,
@ -100,7 +110,7 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject):
@classmethod @classmethod
def make_dummy(cls, address): def make_dummy(cls, address):
return cls.from_state(dict( return cls.from_state(dict(
address=dict(address=address, use_ipv6=False), address=address,
clientcert=None, clientcert=None,
mitmcert=None, mitmcert=None,
ssl_established=False, ssl_established=False,
@ -144,6 +154,7 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
cert: The certificate presented by the remote during the TLS handshake cert: The certificate presented by the remote during the TLS handshake
sni: Server Name Indication sent by the proxy during the TLS handshake sni: Server Name Indication sent by the proxy during the TLS handshake
alpn_proto_negotiated: The negotiated application protocol alpn_proto_negotiated: The negotiated application protocol
tls_version: TLS version
via: The underlying server connection (e.g. the connection to the upstream proxy in upstream proxy mode) via: The underlying server connection (e.g. the connection to the upstream proxy in upstream proxy mode)
timestamp_start: Connection start timestamp timestamp_start: Connection start timestamp
timestamp_tcp_setup: TCP ACK received timestamp timestamp_tcp_setup: TCP ACK received timestamp
@ -155,6 +166,7 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
tcp.TCPClient.__init__(self, address, source_address, spoof_source_address) tcp.TCPClient.__init__(self, address, source_address, spoof_source_address)
self.alpn_proto_negotiated = None self.alpn_proto_negotiated = None
self.tls_version = None
self.via = None self.via = None
self.timestamp_start = None self.timestamp_start = None
self.timestamp_end = None self.timestamp_end = None
@ -166,19 +178,19 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
def __repr__(self): def __repr__(self):
if self.ssl_established and self.sni: if self.ssl_established and self.sni:
ssl = "[ssl: {0}] ".format(self.sni) tls = "[{}: {}] ".format(self.tls_version or "TLS", self.sni)
elif self.ssl_established: elif self.ssl_established:
ssl = "[ssl] " tls = "[{}] ".format(self.tls_version or "TLS")
else: else:
ssl = "" tls = ""
if self.alpn_proto_negotiated: if self.alpn_proto_negotiated:
alpn = "[ALPN: {}] ".format( alpn = "[ALPN: {}] ".format(
strutils.bytes_to_escaped_str(self.alpn_proto_negotiated) strutils.bytes_to_escaped_str(self.alpn_proto_negotiated)
) )
else: else:
alpn = "" alpn = ""
return "<ServerConnection: {ssl}{alpn}{host}:{port}>".format( return "<ServerConnection: {tls}{alpn}{host}:{port}>".format(
ssl=ssl, tls=tls,
alpn=alpn, alpn=alpn,
host=self.address[0], host=self.address[0],
port=self.address[1], port=self.address[1],
@ -188,6 +200,10 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
def tls_established(self): def tls_established(self):
return self.ssl_established return self.ssl_established
@tls_established.setter
def tls_established(self, value):
self.ssl_established = value
_stateobject_attributes = dict( _stateobject_attributes = dict(
address=tuple, address=tuple,
ip_address=tuple, ip_address=tuple,
@ -196,6 +212,7 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
cert=certs.SSLCert, cert=certs.SSLCert,
sni=str, sni=str,
alpn_proto_negotiated=bytes, alpn_proto_negotiated=bytes,
tls_version=str,
timestamp_start=float, timestamp_start=float,
timestamp_tcp_setup=float, timestamp_tcp_setup=float,
timestamp_ssl_setup=float, timestamp_ssl_setup=float,
@ -211,12 +228,13 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
@classmethod @classmethod
def make_dummy(cls, address): def make_dummy(cls, address):
return cls.from_state(dict( return cls.from_state(dict(
address=dict(address=address, use_ipv6=False), address=address,
ip_address=dict(address=address, use_ipv6=False), ip_address=address,
cert=None, cert=None,
sni=None, sni=None,
alpn_proto_negotiated=None, alpn_proto_negotiated=None,
source_address=dict(address=('', 0), use_ipv6=False), tls_version=None,
source_address=('', 0),
ssl_established=False, ssl_established=False,
timestamp_start=None, timestamp_start=None,
timestamp_tcp_setup=None, timestamp_tcp_setup=None,
@ -253,6 +271,7 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
self.convert_to_ssl(cert=clientcert, sni=sni, **kwargs) self.convert_to_ssl(cert=clientcert, sni=sni, **kwargs)
self.sni = sni self.sni = sni
self.alpn_proto_negotiated = self.get_alpn_proto_negotiated() self.alpn_proto_negotiated = self.get_alpn_proto_negotiated()
self.tls_version = self.connection.get_protocol_version_name()
self.timestamp_ssl_setup = time.time() self.timestamp_ssl_setup = time.time()
def finish(self): def finish(self):

View File

@ -99,6 +99,9 @@ def convert_100_200(data):
def convert_200_300(data): def convert_200_300(data):
data["version"] = (3, 0, 0) data["version"] = (3, 0, 0)
data["client_conn"]["mitmcert"] = None data["client_conn"]["mitmcert"] = None
data["server_conn"]["tls_version"] = None
if data["server_conn"]["via"]:
data["server_conn"]["via"]["tls_version"] = None
return data return data

View File

@ -1,3 +1,5 @@
import io
from mitmproxy.net import websockets from mitmproxy.net import websockets
from mitmproxy.test import tutils from mitmproxy.test import tutils
from mitmproxy import tcp from mitmproxy import tcp
@ -156,6 +158,8 @@ def tclient_conn():
tls_version="TLSv1.2", tls_version="TLSv1.2",
)) ))
c.reply = controller.DummyReply() c.reply = controller.DummyReply()
c.rfile = io.BytesIO()
c.wfile = io.BytesIO()
return c return c
@ -175,9 +179,12 @@ def tserver_conn():
ssl_established=False, ssl_established=False,
sni="address", sni="address",
alpn_proto_negotiated=None, alpn_proto_negotiated=None,
tls_version=None,
via=None, via=None,
)) ))
c.reply = controller.DummyReply() c.reply = controller.DummyReply()
c.rfile = io.BytesIO()
c.wfile = io.BytesIO()
return c return c

View File

@ -35,7 +35,6 @@ exclude =
mitmproxy/proxy/server.py mitmproxy/proxy/server.py
mitmproxy/tools/ mitmproxy/tools/
mitmproxy/certs.py mitmproxy/certs.py
mitmproxy/connections.py
mitmproxy/controller.py mitmproxy/controller.py
mitmproxy/export.py mitmproxy/export.py
mitmproxy/flow.py mitmproxy/flow.py
@ -52,7 +51,6 @@ exclude =
mitmproxy/addons/onboardingapp/app.py mitmproxy/addons/onboardingapp/app.py
mitmproxy/addons/termlog.py mitmproxy/addons/termlog.py
mitmproxy/certs.py mitmproxy/certs.py
mitmproxy/connections.py
mitmproxy/contentviews/base.py mitmproxy/contentviews/base.py
mitmproxy/contentviews/wbxml.py mitmproxy/contentviews/wbxml.py
mitmproxy/contentviews/xml_html.py mitmproxy/contentviews/xml_html.py

View File

@ -1,13 +1,57 @@
import socket
import os
import threading
import ssl
import OpenSSL
import pytest
from unittest import mock from unittest import mock
from mitmproxy import connections from mitmproxy import connections
from mitmproxy import exceptions from mitmproxy import exceptions
from mitmproxy.net import tcp
from mitmproxy.net.http import http1 from mitmproxy.net.http import http1
from mitmproxy.test import tflow from mitmproxy.test import tflow
from mitmproxy.test import tutils
from .net import tservers
from pathod import test from pathod import test
class TestClientConnection: class TestClientConnection:
def test_send(self):
c = tflow.tclient_conn()
c.send(b'foobar')
c.send([b'foo', b'bar'])
with pytest.raises(TypeError):
c.send('string')
with pytest.raises(TypeError):
c.send(['string', 'not'])
assert c.wfile.getvalue() == b'foobarfoobar'
def test_repr(self):
c = tflow.tclient_conn()
assert 'address:22' in repr(c)
assert 'ALPN' in repr(c)
assert 'TLS' not in repr(c)
c.alpn_proto_negotiated = None
c.tls_established = True
assert 'ALPN' not in repr(c)
assert 'TLS' in repr(c)
def test_tls_established_property(self):
c = tflow.tclient_conn()
c.tls_established = True
assert c.ssl_established
assert c.tls_established
c.tls_established = False
assert not c.ssl_established
assert not c.tls_established
def test_make_dummy(self):
c = connections.ClientConnection.make_dummy(('foobar', 1234))
assert c.address == ('foobar', 1234)
def test_state(self): def test_state(self):
c = tflow.tclient_conn() c = tflow.tclient_conn()
assert connections.ClientConnection.from_state(c.get_state()).get_state() == \ assert connections.ClientConnection.from_state(c.get_state()).get_state() == \
@ -24,44 +68,143 @@ class TestClientConnection:
c3 = c.copy() c3 = c.copy()
assert c3.get_state() == c.get_state() assert c3.get_state() == c.get_state()
assert str(c)
class TestServerConnection: class TestServerConnection:
def test_send(self):
c = tflow.tserver_conn()
c.send(b'foobar')
c.send([b'foo', b'bar'])
with pytest.raises(TypeError):
c.send('string')
with pytest.raises(TypeError):
c.send(['string', 'not'])
assert c.wfile.getvalue() == b'foobarfoobar'
def test_repr(self):
c = tflow.tserver_conn()
c.sni = 'foobar'
c.tls_established = True
c.alpn_proto_negotiated = b'h2'
assert 'address:22' in repr(c)
assert 'ALPN' in repr(c)
assert 'TLS: foobar' in repr(c)
c.sni = None
c.tls_established = True
c.alpn_proto_negotiated = None
assert 'ALPN' not in repr(c)
assert 'TLS' in repr(c)
c.sni = None
c.tls_established = False
assert 'TLS' not in repr(c)
def test_tls_established_property(self):
c = tflow.tserver_conn()
c.tls_established = True
assert c.ssl_established
assert c.tls_established
c.tls_established = False
assert not c.ssl_established
assert not c.tls_established
def test_make_dummy(self):
c = connections.ServerConnection.make_dummy(('foobar', 1234))
assert c.address == ('foobar', 1234)
def test_simple(self): def test_simple(self):
self.d = test.Daemon() d = test.Daemon()
sc = connections.ServerConnection((self.d.IFACE, self.d.port)) c = connections.ServerConnection((d.IFACE, d.port))
sc.connect() c.connect()
f = tflow.tflow() f = tflow.tflow()
f.server_conn = sc f.server_conn = c
f.request.path = "/p/200:da" f.request.path = "/p/200:da"
# use this protocol just to assemble - not for actual sending # use this protocol just to assemble - not for actual sending
sc.wfile.write(http1.assemble_request(f.request)) c.wfile.write(http1.assemble_request(f.request))
sc.wfile.flush() c.wfile.flush()
assert http1.read_response(sc.rfile, f.request, 1000) assert http1.read_response(c.rfile, f.request, 1000)
assert self.d.last_log() assert d.last_log()
sc.finish() c.finish()
self.d.shutdown() d.shutdown()
def test_terminate_error(self): def test_terminate_error(self):
self.d = test.Daemon() d = test.Daemon()
sc = connections.ServerConnection((self.d.IFACE, self.d.port)) c = connections.ServerConnection((d.IFACE, d.port))
sc.connect() c.connect()
sc.connection = mock.Mock() c.connection = mock.Mock()
sc.connection.recv = mock.Mock(return_value=False) c.connection.recv = mock.Mock(return_value=False)
sc.connection.flush = mock.Mock(side_effect=exceptions.TcpDisconnect) c.connection.flush = mock.Mock(side_effect=exceptions.TcpDisconnect)
sc.finish() c.finish()
self.d.shutdown() d.shutdown()
def test_repr(self): def test_sni(self):
sc = tflow.tserver_conn() c = connections.ServerConnection(('', 1234))
assert "address:22" in repr(sc) with pytest.raises(ValueError, matches='sni must be str, not '):
assert "ssl" not in repr(sc) c.establish_ssl(None, b'foobar')
sc.ssl_established = True
assert "ssl" in repr(sc)
sc.sni = "foo" class TestClientConnectionTLS:
assert "foo" in repr(sc)
@pytest.mark.parametrize("sni", [
None,
"example.com"
])
def test_tls_with_sni(self, sni):
address = ('127.0.0.1', 0)
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind(address)
sock.listen()
address = sock.getsockname()
def client_run():
ctx = ssl.create_default_context()
ctx.check_hostname = False
ctx.verify_mode = ssl.CERT_NONE
s = socket.create_connection(address)
s = ctx.wrap_socket(s, server_hostname=sni)
s.send(b'foobar')
s.shutdown(socket.SHUT_RDWR)
threading.Thread(target=client_run).start()
connection, client_address = sock.accept()
c = connections.ClientConnection(connection, client_address, None)
cert = tutils.test_data.path("mitmproxy/net/data/server.crt")
key = OpenSSL.crypto.load_privatekey(
OpenSSL.crypto.FILETYPE_PEM,
open(tutils.test_data.path("mitmproxy/net/data/server.key"), "rb").read())
c.convert_to_ssl(cert, key)
assert c.connected()
assert c.sni == sni
assert c.tls_established
assert c.rfile.read(6) == b'foobar'
c.finish()
class TestServerConnectionTLS(tservers.ServerTestBase):
ssl = True
class handler(tcp.BaseHandler):
def handle(self):
self.finish()
@pytest.mark.parametrize("clientcert", [
None,
tutils.test_data.path("mitmproxy/data/clientcert"),
os.path.join(tutils.test_data.path("mitmproxy/data/clientcert"), "client.pem"),
])
def test_tls(self, clientcert):
c = connections.ServerConnection(("127.0.0.1", self.port))
c.connect()
c.establish_ssl(clientcert, "foo.com")
assert c.connected()
assert c.sni == "foo.com"
assert c.tls_established
c.close()
c.finish()