connections: coverage++ (#2064)

This commit is contained in:
Thomas Kriechbaumer 2017-02-26 20:50:52 +01:00 committed by GitHub
parent b33d568e04
commit 9b6986ea87
5 changed files with 211 additions and 41 deletions

View File

@ -54,14 +54,20 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject):
return bool(self.connection) and not self.finished
def __repr__(self):
if self.ssl_established:
tls = "[{}] ".format(self.tls_version)
else:
tls = ""
if self.alpn_proto_negotiated:
alpn = "[ALPN: {}] ".format(
strutils.bytes_to_escaped_str(self.alpn_proto_negotiated)
)
else:
alpn = ""
return "<ClientConnection: {ssl}{alpn}{host}:{port}>".format(
ssl="[ssl] " if self.ssl_established else "",
return "<ClientConnection: {tls}{alpn}{host}:{port}>".format(
tls=tls,
alpn=alpn,
host=self.address[0],
port=self.address[1],
@ -71,6 +77,10 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject):
def tls_established(self):
return self.ssl_established
@tls_established.setter
def tls_established(self, value):
self.ssl_established = value
_stateobject_attributes = dict(
address=tuple,
ssl_established=bool,
@ -100,7 +110,7 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject):
@classmethod
def make_dummy(cls, address):
return cls.from_state(dict(
address=dict(address=address, use_ipv6=False),
address=address,
clientcert=None,
mitmcert=None,
ssl_established=False,
@ -144,6 +154,7 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
cert: The certificate presented by the remote during the TLS handshake
sni: Server Name Indication sent by the proxy during the TLS handshake
alpn_proto_negotiated: The negotiated application protocol
tls_version: TLS version
via: The underlying server connection (e.g. the connection to the upstream proxy in upstream proxy mode)
timestamp_start: Connection start timestamp
timestamp_tcp_setup: TCP ACK received timestamp
@ -155,6 +166,7 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
tcp.TCPClient.__init__(self, address, source_address, spoof_source_address)
self.alpn_proto_negotiated = None
self.tls_version = None
self.via = None
self.timestamp_start = None
self.timestamp_end = None
@ -166,19 +178,19 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
def __repr__(self):
if self.ssl_established and self.sni:
ssl = "[ssl: {0}] ".format(self.sni)
tls = "[{}: {}] ".format(self.tls_version or "TLS", self.sni)
elif self.ssl_established:
ssl = "[ssl] "
tls = "[{}] ".format(self.tls_version or "TLS")
else:
ssl = ""
tls = ""
if self.alpn_proto_negotiated:
alpn = "[ALPN: {}] ".format(
strutils.bytes_to_escaped_str(self.alpn_proto_negotiated)
)
else:
alpn = ""
return "<ServerConnection: {ssl}{alpn}{host}:{port}>".format(
ssl=ssl,
return "<ServerConnection: {tls}{alpn}{host}:{port}>".format(
tls=tls,
alpn=alpn,
host=self.address[0],
port=self.address[1],
@ -188,6 +200,10 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
def tls_established(self):
return self.ssl_established
@tls_established.setter
def tls_established(self, value):
self.ssl_established = value
_stateobject_attributes = dict(
address=tuple,
ip_address=tuple,
@ -196,6 +212,7 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
cert=certs.SSLCert,
sni=str,
alpn_proto_negotiated=bytes,
tls_version=str,
timestamp_start=float,
timestamp_tcp_setup=float,
timestamp_ssl_setup=float,
@ -211,12 +228,13 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
@classmethod
def make_dummy(cls, address):
return cls.from_state(dict(
address=dict(address=address, use_ipv6=False),
ip_address=dict(address=address, use_ipv6=False),
address=address,
ip_address=address,
cert=None,
sni=None,
alpn_proto_negotiated=None,
source_address=dict(address=('', 0), use_ipv6=False),
tls_version=None,
source_address=('', 0),
ssl_established=False,
timestamp_start=None,
timestamp_tcp_setup=None,
@ -253,6 +271,7 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
self.convert_to_ssl(cert=clientcert, sni=sni, **kwargs)
self.sni = sni
self.alpn_proto_negotiated = self.get_alpn_proto_negotiated()
self.tls_version = self.connection.get_protocol_version_name()
self.timestamp_ssl_setup = time.time()
def finish(self):

View File

@ -99,6 +99,9 @@ def convert_100_200(data):
def convert_200_300(data):
data["version"] = (3, 0, 0)
data["client_conn"]["mitmcert"] = None
data["server_conn"]["tls_version"] = None
if data["server_conn"]["via"]:
data["server_conn"]["via"]["tls_version"] = None
return data

View File

@ -1,3 +1,5 @@
import io
from mitmproxy.net import websockets
from mitmproxy.test import tutils
from mitmproxy import tcp
@ -156,6 +158,8 @@ def tclient_conn():
tls_version="TLSv1.2",
))
c.reply = controller.DummyReply()
c.rfile = io.BytesIO()
c.wfile = io.BytesIO()
return c
@ -175,9 +179,12 @@ def tserver_conn():
ssl_established=False,
sni="address",
alpn_proto_negotiated=None,
tls_version=None,
via=None,
))
c.reply = controller.DummyReply()
c.rfile = io.BytesIO()
c.wfile = io.BytesIO()
return c

View File

@ -35,7 +35,6 @@ exclude =
mitmproxy/proxy/server.py
mitmproxy/tools/
mitmproxy/certs.py
mitmproxy/connections.py
mitmproxy/controller.py
mitmproxy/export.py
mitmproxy/flow.py
@ -52,7 +51,6 @@ exclude =
mitmproxy/addons/onboardingapp/app.py
mitmproxy/addons/termlog.py
mitmproxy/certs.py
mitmproxy/connections.py
mitmproxy/contentviews/base.py
mitmproxy/contentviews/wbxml.py
mitmproxy/contentviews/xml_html.py

View File

@ -1,13 +1,57 @@
import socket
import os
import threading
import ssl
import OpenSSL
import pytest
from unittest import mock
from mitmproxy import connections
from mitmproxy import exceptions
from mitmproxy.net import tcp
from mitmproxy.net.http import http1
from mitmproxy.test import tflow
from mitmproxy.test import tutils
from .net import tservers
from pathod import test
class TestClientConnection:
def test_send(self):
c = tflow.tclient_conn()
c.send(b'foobar')
c.send([b'foo', b'bar'])
with pytest.raises(TypeError):
c.send('string')
with pytest.raises(TypeError):
c.send(['string', 'not'])
assert c.wfile.getvalue() == b'foobarfoobar'
def test_repr(self):
c = tflow.tclient_conn()
assert 'address:22' in repr(c)
assert 'ALPN' in repr(c)
assert 'TLS' not in repr(c)
c.alpn_proto_negotiated = None
c.tls_established = True
assert 'ALPN' not in repr(c)
assert 'TLS' in repr(c)
def test_tls_established_property(self):
c = tflow.tclient_conn()
c.tls_established = True
assert c.ssl_established
assert c.tls_established
c.tls_established = False
assert not c.ssl_established
assert not c.tls_established
def test_make_dummy(self):
c = connections.ClientConnection.make_dummy(('foobar', 1234))
assert c.address == ('foobar', 1234)
def test_state(self):
c = tflow.tclient_conn()
assert connections.ClientConnection.from_state(c.get_state()).get_state() == \
@ -24,44 +68,143 @@ class TestClientConnection:
c3 = c.copy()
assert c3.get_state() == c.get_state()
assert str(c)
class TestServerConnection:
def test_send(self):
c = tflow.tserver_conn()
c.send(b'foobar')
c.send([b'foo', b'bar'])
with pytest.raises(TypeError):
c.send('string')
with pytest.raises(TypeError):
c.send(['string', 'not'])
assert c.wfile.getvalue() == b'foobarfoobar'
def test_repr(self):
c = tflow.tserver_conn()
c.sni = 'foobar'
c.tls_established = True
c.alpn_proto_negotiated = b'h2'
assert 'address:22' in repr(c)
assert 'ALPN' in repr(c)
assert 'TLS: foobar' in repr(c)
c.sni = None
c.tls_established = True
c.alpn_proto_negotiated = None
assert 'ALPN' not in repr(c)
assert 'TLS' in repr(c)
c.sni = None
c.tls_established = False
assert 'TLS' not in repr(c)
def test_tls_established_property(self):
c = tflow.tserver_conn()
c.tls_established = True
assert c.ssl_established
assert c.tls_established
c.tls_established = False
assert not c.ssl_established
assert not c.tls_established
def test_make_dummy(self):
c = connections.ServerConnection.make_dummy(('foobar', 1234))
assert c.address == ('foobar', 1234)
def test_simple(self):
self.d = test.Daemon()
sc = connections.ServerConnection((self.d.IFACE, self.d.port))
sc.connect()
d = test.Daemon()
c = connections.ServerConnection((d.IFACE, d.port))
c.connect()
f = tflow.tflow()
f.server_conn = sc
f.server_conn = c
f.request.path = "/p/200:da"
# use this protocol just to assemble - not for actual sending
sc.wfile.write(http1.assemble_request(f.request))
sc.wfile.flush()
c.wfile.write(http1.assemble_request(f.request))
c.wfile.flush()
assert http1.read_response(sc.rfile, f.request, 1000)
assert self.d.last_log()
assert http1.read_response(c.rfile, f.request, 1000)
assert d.last_log()
sc.finish()
self.d.shutdown()
c.finish()
d.shutdown()
def test_terminate_error(self):
self.d = test.Daemon()
sc = connections.ServerConnection((self.d.IFACE, self.d.port))
sc.connect()
sc.connection = mock.Mock()
sc.connection.recv = mock.Mock(return_value=False)
sc.connection.flush = mock.Mock(side_effect=exceptions.TcpDisconnect)
sc.finish()
self.d.shutdown()
d = test.Daemon()
c = connections.ServerConnection((d.IFACE, d.port))
c.connect()
c.connection = mock.Mock()
c.connection.recv = mock.Mock(return_value=False)
c.connection.flush = mock.Mock(side_effect=exceptions.TcpDisconnect)
c.finish()
d.shutdown()
def test_repr(self):
sc = tflow.tserver_conn()
assert "address:22" in repr(sc)
assert "ssl" not in repr(sc)
sc.ssl_established = True
assert "ssl" in repr(sc)
sc.sni = "foo"
assert "foo" in repr(sc)
def test_sni(self):
c = connections.ServerConnection(('', 1234))
with pytest.raises(ValueError, matches='sni must be str, not '):
c.establish_ssl(None, b'foobar')
class TestClientConnectionTLS:
@pytest.mark.parametrize("sni", [
None,
"example.com"
])
def test_tls_with_sni(self, sni):
address = ('127.0.0.1', 0)
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind(address)
sock.listen()
address = sock.getsockname()
def client_run():
ctx = ssl.create_default_context()
ctx.check_hostname = False
ctx.verify_mode = ssl.CERT_NONE
s = socket.create_connection(address)
s = ctx.wrap_socket(s, server_hostname=sni)
s.send(b'foobar')
s.shutdown(socket.SHUT_RDWR)
threading.Thread(target=client_run).start()
connection, client_address = sock.accept()
c = connections.ClientConnection(connection, client_address, None)
cert = tutils.test_data.path("mitmproxy/net/data/server.crt")
key = OpenSSL.crypto.load_privatekey(
OpenSSL.crypto.FILETYPE_PEM,
open(tutils.test_data.path("mitmproxy/net/data/server.key"), "rb").read())
c.convert_to_ssl(cert, key)
assert c.connected()
assert c.sni == sni
assert c.tls_established
assert c.rfile.read(6) == b'foobar'
c.finish()
class TestServerConnectionTLS(tservers.ServerTestBase):
ssl = True
class handler(tcp.BaseHandler):
def handle(self):
self.finish()
@pytest.mark.parametrize("clientcert", [
None,
tutils.test_data.path("mitmproxy/data/clientcert"),
os.path.join(tutils.test_data.path("mitmproxy/data/clientcert"), "client.pem"),
])
def test_tls(self, clientcert):
c = connections.ServerConnection(("127.0.0.1", self.port))
c.connect()
c.establish_ssl(clientcert, "foo.com")
assert c.connected()
assert c.sni == "foo.com"
assert c.tls_established
c.close()
c.finish()