connections: coverage++ (#2064)
This commit is contained in:
parent
b33d568e04
commit
9b6986ea87
|
@ -54,14 +54,20 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject):
|
||||||
return bool(self.connection) and not self.finished
|
return bool(self.connection) and not self.finished
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
|
if self.ssl_established:
|
||||||
|
tls = "[{}] ".format(self.tls_version)
|
||||||
|
else:
|
||||||
|
tls = ""
|
||||||
|
|
||||||
if self.alpn_proto_negotiated:
|
if self.alpn_proto_negotiated:
|
||||||
alpn = "[ALPN: {}] ".format(
|
alpn = "[ALPN: {}] ".format(
|
||||||
strutils.bytes_to_escaped_str(self.alpn_proto_negotiated)
|
strutils.bytes_to_escaped_str(self.alpn_proto_negotiated)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
alpn = ""
|
alpn = ""
|
||||||
return "<ClientConnection: {ssl}{alpn}{host}:{port}>".format(
|
|
||||||
ssl="[ssl] " if self.ssl_established else "",
|
return "<ClientConnection: {tls}{alpn}{host}:{port}>".format(
|
||||||
|
tls=tls,
|
||||||
alpn=alpn,
|
alpn=alpn,
|
||||||
host=self.address[0],
|
host=self.address[0],
|
||||||
port=self.address[1],
|
port=self.address[1],
|
||||||
|
@ -71,6 +77,10 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject):
|
||||||
def tls_established(self):
|
def tls_established(self):
|
||||||
return self.ssl_established
|
return self.ssl_established
|
||||||
|
|
||||||
|
@tls_established.setter
|
||||||
|
def tls_established(self, value):
|
||||||
|
self.ssl_established = value
|
||||||
|
|
||||||
_stateobject_attributes = dict(
|
_stateobject_attributes = dict(
|
||||||
address=tuple,
|
address=tuple,
|
||||||
ssl_established=bool,
|
ssl_established=bool,
|
||||||
|
@ -100,7 +110,7 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject):
|
||||||
@classmethod
|
@classmethod
|
||||||
def make_dummy(cls, address):
|
def make_dummy(cls, address):
|
||||||
return cls.from_state(dict(
|
return cls.from_state(dict(
|
||||||
address=dict(address=address, use_ipv6=False),
|
address=address,
|
||||||
clientcert=None,
|
clientcert=None,
|
||||||
mitmcert=None,
|
mitmcert=None,
|
||||||
ssl_established=False,
|
ssl_established=False,
|
||||||
|
@ -144,6 +154,7 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
|
||||||
cert: The certificate presented by the remote during the TLS handshake
|
cert: The certificate presented by the remote during the TLS handshake
|
||||||
sni: Server Name Indication sent by the proxy during the TLS handshake
|
sni: Server Name Indication sent by the proxy during the TLS handshake
|
||||||
alpn_proto_negotiated: The negotiated application protocol
|
alpn_proto_negotiated: The negotiated application protocol
|
||||||
|
tls_version: TLS version
|
||||||
via: The underlying server connection (e.g. the connection to the upstream proxy in upstream proxy mode)
|
via: The underlying server connection (e.g. the connection to the upstream proxy in upstream proxy mode)
|
||||||
timestamp_start: Connection start timestamp
|
timestamp_start: Connection start timestamp
|
||||||
timestamp_tcp_setup: TCP ACK received timestamp
|
timestamp_tcp_setup: TCP ACK received timestamp
|
||||||
|
@ -155,6 +166,7 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
|
||||||
tcp.TCPClient.__init__(self, address, source_address, spoof_source_address)
|
tcp.TCPClient.__init__(self, address, source_address, spoof_source_address)
|
||||||
|
|
||||||
self.alpn_proto_negotiated = None
|
self.alpn_proto_negotiated = None
|
||||||
|
self.tls_version = None
|
||||||
self.via = None
|
self.via = None
|
||||||
self.timestamp_start = None
|
self.timestamp_start = None
|
||||||
self.timestamp_end = None
|
self.timestamp_end = None
|
||||||
|
@ -166,19 +178,19 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
if self.ssl_established and self.sni:
|
if self.ssl_established and self.sni:
|
||||||
ssl = "[ssl: {0}] ".format(self.sni)
|
tls = "[{}: {}] ".format(self.tls_version or "TLS", self.sni)
|
||||||
elif self.ssl_established:
|
elif self.ssl_established:
|
||||||
ssl = "[ssl] "
|
tls = "[{}] ".format(self.tls_version or "TLS")
|
||||||
else:
|
else:
|
||||||
ssl = ""
|
tls = ""
|
||||||
if self.alpn_proto_negotiated:
|
if self.alpn_proto_negotiated:
|
||||||
alpn = "[ALPN: {}] ".format(
|
alpn = "[ALPN: {}] ".format(
|
||||||
strutils.bytes_to_escaped_str(self.alpn_proto_negotiated)
|
strutils.bytes_to_escaped_str(self.alpn_proto_negotiated)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
alpn = ""
|
alpn = ""
|
||||||
return "<ServerConnection: {ssl}{alpn}{host}:{port}>".format(
|
return "<ServerConnection: {tls}{alpn}{host}:{port}>".format(
|
||||||
ssl=ssl,
|
tls=tls,
|
||||||
alpn=alpn,
|
alpn=alpn,
|
||||||
host=self.address[0],
|
host=self.address[0],
|
||||||
port=self.address[1],
|
port=self.address[1],
|
||||||
|
@ -188,6 +200,10 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
|
||||||
def tls_established(self):
|
def tls_established(self):
|
||||||
return self.ssl_established
|
return self.ssl_established
|
||||||
|
|
||||||
|
@tls_established.setter
|
||||||
|
def tls_established(self, value):
|
||||||
|
self.ssl_established = value
|
||||||
|
|
||||||
_stateobject_attributes = dict(
|
_stateobject_attributes = dict(
|
||||||
address=tuple,
|
address=tuple,
|
||||||
ip_address=tuple,
|
ip_address=tuple,
|
||||||
|
@ -196,6 +212,7 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
|
||||||
cert=certs.SSLCert,
|
cert=certs.SSLCert,
|
||||||
sni=str,
|
sni=str,
|
||||||
alpn_proto_negotiated=bytes,
|
alpn_proto_negotiated=bytes,
|
||||||
|
tls_version=str,
|
||||||
timestamp_start=float,
|
timestamp_start=float,
|
||||||
timestamp_tcp_setup=float,
|
timestamp_tcp_setup=float,
|
||||||
timestamp_ssl_setup=float,
|
timestamp_ssl_setup=float,
|
||||||
|
@ -211,12 +228,13 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
|
||||||
@classmethod
|
@classmethod
|
||||||
def make_dummy(cls, address):
|
def make_dummy(cls, address):
|
||||||
return cls.from_state(dict(
|
return cls.from_state(dict(
|
||||||
address=dict(address=address, use_ipv6=False),
|
address=address,
|
||||||
ip_address=dict(address=address, use_ipv6=False),
|
ip_address=address,
|
||||||
cert=None,
|
cert=None,
|
||||||
sni=None,
|
sni=None,
|
||||||
alpn_proto_negotiated=None,
|
alpn_proto_negotiated=None,
|
||||||
source_address=dict(address=('', 0), use_ipv6=False),
|
tls_version=None,
|
||||||
|
source_address=('', 0),
|
||||||
ssl_established=False,
|
ssl_established=False,
|
||||||
timestamp_start=None,
|
timestamp_start=None,
|
||||||
timestamp_tcp_setup=None,
|
timestamp_tcp_setup=None,
|
||||||
|
@ -253,6 +271,7 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
|
||||||
self.convert_to_ssl(cert=clientcert, sni=sni, **kwargs)
|
self.convert_to_ssl(cert=clientcert, sni=sni, **kwargs)
|
||||||
self.sni = sni
|
self.sni = sni
|
||||||
self.alpn_proto_negotiated = self.get_alpn_proto_negotiated()
|
self.alpn_proto_negotiated = self.get_alpn_proto_negotiated()
|
||||||
|
self.tls_version = self.connection.get_protocol_version_name()
|
||||||
self.timestamp_ssl_setup = time.time()
|
self.timestamp_ssl_setup = time.time()
|
||||||
|
|
||||||
def finish(self):
|
def finish(self):
|
||||||
|
|
|
@ -99,6 +99,9 @@ def convert_100_200(data):
|
||||||
def convert_200_300(data):
|
def convert_200_300(data):
|
||||||
data["version"] = (3, 0, 0)
|
data["version"] = (3, 0, 0)
|
||||||
data["client_conn"]["mitmcert"] = None
|
data["client_conn"]["mitmcert"] = None
|
||||||
|
data["server_conn"]["tls_version"] = None
|
||||||
|
if data["server_conn"]["via"]:
|
||||||
|
data["server_conn"]["via"]["tls_version"] = None
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
import io
|
||||||
|
|
||||||
from mitmproxy.net import websockets
|
from mitmproxy.net import websockets
|
||||||
from mitmproxy.test import tutils
|
from mitmproxy.test import tutils
|
||||||
from mitmproxy import tcp
|
from mitmproxy import tcp
|
||||||
|
@ -156,6 +158,8 @@ def tclient_conn():
|
||||||
tls_version="TLSv1.2",
|
tls_version="TLSv1.2",
|
||||||
))
|
))
|
||||||
c.reply = controller.DummyReply()
|
c.reply = controller.DummyReply()
|
||||||
|
c.rfile = io.BytesIO()
|
||||||
|
c.wfile = io.BytesIO()
|
||||||
return c
|
return c
|
||||||
|
|
||||||
|
|
||||||
|
@ -175,9 +179,12 @@ def tserver_conn():
|
||||||
ssl_established=False,
|
ssl_established=False,
|
||||||
sni="address",
|
sni="address",
|
||||||
alpn_proto_negotiated=None,
|
alpn_proto_negotiated=None,
|
||||||
|
tls_version=None,
|
||||||
via=None,
|
via=None,
|
||||||
))
|
))
|
||||||
c.reply = controller.DummyReply()
|
c.reply = controller.DummyReply()
|
||||||
|
c.rfile = io.BytesIO()
|
||||||
|
c.wfile = io.BytesIO()
|
||||||
return c
|
return c
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -35,7 +35,6 @@ exclude =
|
||||||
mitmproxy/proxy/server.py
|
mitmproxy/proxy/server.py
|
||||||
mitmproxy/tools/
|
mitmproxy/tools/
|
||||||
mitmproxy/certs.py
|
mitmproxy/certs.py
|
||||||
mitmproxy/connections.py
|
|
||||||
mitmproxy/controller.py
|
mitmproxy/controller.py
|
||||||
mitmproxy/export.py
|
mitmproxy/export.py
|
||||||
mitmproxy/flow.py
|
mitmproxy/flow.py
|
||||||
|
@ -52,7 +51,6 @@ exclude =
|
||||||
mitmproxy/addons/onboardingapp/app.py
|
mitmproxy/addons/onboardingapp/app.py
|
||||||
mitmproxy/addons/termlog.py
|
mitmproxy/addons/termlog.py
|
||||||
mitmproxy/certs.py
|
mitmproxy/certs.py
|
||||||
mitmproxy/connections.py
|
|
||||||
mitmproxy/contentviews/base.py
|
mitmproxy/contentviews/base.py
|
||||||
mitmproxy/contentviews/wbxml.py
|
mitmproxy/contentviews/wbxml.py
|
||||||
mitmproxy/contentviews/xml_html.py
|
mitmproxy/contentviews/xml_html.py
|
||||||
|
|
|
@ -1,13 +1,57 @@
|
||||||
|
import socket
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
|
import ssl
|
||||||
|
import OpenSSL
|
||||||
|
import pytest
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
from mitmproxy import connections
|
from mitmproxy import connections
|
||||||
from mitmproxy import exceptions
|
from mitmproxy import exceptions
|
||||||
|
from mitmproxy.net import tcp
|
||||||
from mitmproxy.net.http import http1
|
from mitmproxy.net.http import http1
|
||||||
from mitmproxy.test import tflow
|
from mitmproxy.test import tflow
|
||||||
|
from mitmproxy.test import tutils
|
||||||
|
from .net import tservers
|
||||||
from pathod import test
|
from pathod import test
|
||||||
|
|
||||||
|
|
||||||
class TestClientConnection:
|
class TestClientConnection:
|
||||||
|
|
||||||
|
def test_send(self):
|
||||||
|
c = tflow.tclient_conn()
|
||||||
|
c.send(b'foobar')
|
||||||
|
c.send([b'foo', b'bar'])
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
c.send('string')
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
c.send(['string', 'not'])
|
||||||
|
assert c.wfile.getvalue() == b'foobarfoobar'
|
||||||
|
|
||||||
|
def test_repr(self):
|
||||||
|
c = tflow.tclient_conn()
|
||||||
|
assert 'address:22' in repr(c)
|
||||||
|
assert 'ALPN' in repr(c)
|
||||||
|
assert 'TLS' not in repr(c)
|
||||||
|
|
||||||
|
c.alpn_proto_negotiated = None
|
||||||
|
c.tls_established = True
|
||||||
|
assert 'ALPN' not in repr(c)
|
||||||
|
assert 'TLS' in repr(c)
|
||||||
|
|
||||||
|
def test_tls_established_property(self):
|
||||||
|
c = tflow.tclient_conn()
|
||||||
|
c.tls_established = True
|
||||||
|
assert c.ssl_established
|
||||||
|
assert c.tls_established
|
||||||
|
c.tls_established = False
|
||||||
|
assert not c.ssl_established
|
||||||
|
assert not c.tls_established
|
||||||
|
|
||||||
|
def test_make_dummy(self):
|
||||||
|
c = connections.ClientConnection.make_dummy(('foobar', 1234))
|
||||||
|
assert c.address == ('foobar', 1234)
|
||||||
|
|
||||||
def test_state(self):
|
def test_state(self):
|
||||||
c = tflow.tclient_conn()
|
c = tflow.tclient_conn()
|
||||||
assert connections.ClientConnection.from_state(c.get_state()).get_state() == \
|
assert connections.ClientConnection.from_state(c.get_state()).get_state() == \
|
||||||
|
@ -24,44 +68,143 @@ class TestClientConnection:
|
||||||
c3 = c.copy()
|
c3 = c.copy()
|
||||||
assert c3.get_state() == c.get_state()
|
assert c3.get_state() == c.get_state()
|
||||||
|
|
||||||
assert str(c)
|
|
||||||
|
|
||||||
|
|
||||||
class TestServerConnection:
|
class TestServerConnection:
|
||||||
|
|
||||||
|
def test_send(self):
|
||||||
|
c = tflow.tserver_conn()
|
||||||
|
c.send(b'foobar')
|
||||||
|
c.send([b'foo', b'bar'])
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
c.send('string')
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
c.send(['string', 'not'])
|
||||||
|
assert c.wfile.getvalue() == b'foobarfoobar'
|
||||||
|
|
||||||
|
def test_repr(self):
|
||||||
|
c = tflow.tserver_conn()
|
||||||
|
|
||||||
|
c.sni = 'foobar'
|
||||||
|
c.tls_established = True
|
||||||
|
c.alpn_proto_negotiated = b'h2'
|
||||||
|
assert 'address:22' in repr(c)
|
||||||
|
assert 'ALPN' in repr(c)
|
||||||
|
assert 'TLS: foobar' in repr(c)
|
||||||
|
|
||||||
|
c.sni = None
|
||||||
|
c.tls_established = True
|
||||||
|
c.alpn_proto_negotiated = None
|
||||||
|
assert 'ALPN' not in repr(c)
|
||||||
|
assert 'TLS' in repr(c)
|
||||||
|
|
||||||
|
c.sni = None
|
||||||
|
c.tls_established = False
|
||||||
|
assert 'TLS' not in repr(c)
|
||||||
|
|
||||||
|
def test_tls_established_property(self):
|
||||||
|
c = tflow.tserver_conn()
|
||||||
|
c.tls_established = True
|
||||||
|
assert c.ssl_established
|
||||||
|
assert c.tls_established
|
||||||
|
c.tls_established = False
|
||||||
|
assert not c.ssl_established
|
||||||
|
assert not c.tls_established
|
||||||
|
|
||||||
|
def test_make_dummy(self):
|
||||||
|
c = connections.ServerConnection.make_dummy(('foobar', 1234))
|
||||||
|
assert c.address == ('foobar', 1234)
|
||||||
|
|
||||||
def test_simple(self):
|
def test_simple(self):
|
||||||
self.d = test.Daemon()
|
d = test.Daemon()
|
||||||
sc = connections.ServerConnection((self.d.IFACE, self.d.port))
|
c = connections.ServerConnection((d.IFACE, d.port))
|
||||||
sc.connect()
|
c.connect()
|
||||||
f = tflow.tflow()
|
f = tflow.tflow()
|
||||||
f.server_conn = sc
|
f.server_conn = c
|
||||||
f.request.path = "/p/200:da"
|
f.request.path = "/p/200:da"
|
||||||
|
|
||||||
# use this protocol just to assemble - not for actual sending
|
# use this protocol just to assemble - not for actual sending
|
||||||
sc.wfile.write(http1.assemble_request(f.request))
|
c.wfile.write(http1.assemble_request(f.request))
|
||||||
sc.wfile.flush()
|
c.wfile.flush()
|
||||||
|
|
||||||
assert http1.read_response(sc.rfile, f.request, 1000)
|
assert http1.read_response(c.rfile, f.request, 1000)
|
||||||
assert self.d.last_log()
|
assert d.last_log()
|
||||||
|
|
||||||
sc.finish()
|
c.finish()
|
||||||
self.d.shutdown()
|
d.shutdown()
|
||||||
|
|
||||||
def test_terminate_error(self):
|
def test_terminate_error(self):
|
||||||
self.d = test.Daemon()
|
d = test.Daemon()
|
||||||
sc = connections.ServerConnection((self.d.IFACE, self.d.port))
|
c = connections.ServerConnection((d.IFACE, d.port))
|
||||||
sc.connect()
|
c.connect()
|
||||||
sc.connection = mock.Mock()
|
c.connection = mock.Mock()
|
||||||
sc.connection.recv = mock.Mock(return_value=False)
|
c.connection.recv = mock.Mock(return_value=False)
|
||||||
sc.connection.flush = mock.Mock(side_effect=exceptions.TcpDisconnect)
|
c.connection.flush = mock.Mock(side_effect=exceptions.TcpDisconnect)
|
||||||
sc.finish()
|
c.finish()
|
||||||
self.d.shutdown()
|
d.shutdown()
|
||||||
|
|
||||||
def test_repr(self):
|
def test_sni(self):
|
||||||
sc = tflow.tserver_conn()
|
c = connections.ServerConnection(('', 1234))
|
||||||
assert "address:22" in repr(sc)
|
with pytest.raises(ValueError, matches='sni must be str, not '):
|
||||||
assert "ssl" not in repr(sc)
|
c.establish_ssl(None, b'foobar')
|
||||||
sc.ssl_established = True
|
|
||||||
assert "ssl" in repr(sc)
|
|
||||||
sc.sni = "foo"
|
class TestClientConnectionTLS:
|
||||||
assert "foo" in repr(sc)
|
|
||||||
|
@pytest.mark.parametrize("sni", [
|
||||||
|
None,
|
||||||
|
"example.com"
|
||||||
|
])
|
||||||
|
def test_tls_with_sni(self, sni):
|
||||||
|
address = ('127.0.0.1', 0)
|
||||||
|
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
|
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||||
|
sock.bind(address)
|
||||||
|
sock.listen()
|
||||||
|
address = sock.getsockname()
|
||||||
|
|
||||||
|
def client_run():
|
||||||
|
ctx = ssl.create_default_context()
|
||||||
|
ctx.check_hostname = False
|
||||||
|
ctx.verify_mode = ssl.CERT_NONE
|
||||||
|
s = socket.create_connection(address)
|
||||||
|
s = ctx.wrap_socket(s, server_hostname=sni)
|
||||||
|
s.send(b'foobar')
|
||||||
|
s.shutdown(socket.SHUT_RDWR)
|
||||||
|
threading.Thread(target=client_run).start()
|
||||||
|
|
||||||
|
connection, client_address = sock.accept()
|
||||||
|
c = connections.ClientConnection(connection, client_address, None)
|
||||||
|
|
||||||
|
cert = tutils.test_data.path("mitmproxy/net/data/server.crt")
|
||||||
|
key = OpenSSL.crypto.load_privatekey(
|
||||||
|
OpenSSL.crypto.FILETYPE_PEM,
|
||||||
|
open(tutils.test_data.path("mitmproxy/net/data/server.key"), "rb").read())
|
||||||
|
c.convert_to_ssl(cert, key)
|
||||||
|
assert c.connected()
|
||||||
|
assert c.sni == sni
|
||||||
|
assert c.tls_established
|
||||||
|
assert c.rfile.read(6) == b'foobar'
|
||||||
|
c.finish()
|
||||||
|
|
||||||
|
|
||||||
|
class TestServerConnectionTLS(tservers.ServerTestBase):
|
||||||
|
ssl = True
|
||||||
|
|
||||||
|
class handler(tcp.BaseHandler):
|
||||||
|
def handle(self):
|
||||||
|
self.finish()
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("clientcert", [
|
||||||
|
None,
|
||||||
|
tutils.test_data.path("mitmproxy/data/clientcert"),
|
||||||
|
os.path.join(tutils.test_data.path("mitmproxy/data/clientcert"), "client.pem"),
|
||||||
|
])
|
||||||
|
def test_tls(self, clientcert):
|
||||||
|
c = connections.ServerConnection(("127.0.0.1", self.port))
|
||||||
|
c.connect()
|
||||||
|
c.establish_ssl(clientcert, "foo.com")
|
||||||
|
assert c.connected()
|
||||||
|
assert c.sni == "foo.com"
|
||||||
|
assert c.tls_established
|
||||||
|
c.close()
|
||||||
|
c.finish()
|
||||||
|
|
Loading…
Reference in New Issue