proxy.py/tests/plugin/test_http_proxy_plugins_wit...

212 lines
7.8 KiB
Python

# -*- coding: utf-8 -*-
"""
proxy.py
~~~~~~~~
⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
Network monitoring, controls & Application development, testing, debugging.
:copyright: (c) 2013-present by Abhinav Singh and contributors.
:license: BSD, see LICENSE for more details.
"""
import unittest
import socket
import selectors
import ssl
from unittest import mock
from typing import Any, cast
from proxy.proxy import Proxy
from proxy.common.utils import bytes_
from proxy.common.utils import build_http_request, build_http_response
from proxy.core.connection import TcpClientConnection, TcpServerConnection
from proxy.http.codes import httpStatusCodes
from proxy.http.methods import httpMethods
from proxy.http.handler import HttpProtocolHandler
from proxy.http.proxy import HttpProxyPlugin
from .utils import get_plugin_by_test_name
class TestHttpProxyPluginExamplesWithTlsInterception(unittest.TestCase):
    """Run example ``HttpProxyBasePlugin`` plugins behind TLS interception.

    ``setUp`` pushes a ``CONNECT uni.corn:443`` through a fully mocked
    ``HttpProtocolHandler``. Each test therefore starts with both the client
    and the upstream legs already swapped over to (mock) SSL sockets. The
    tests then feed decrypted traffic through ``run_once`` and check what the
    plugin selected for that test did to it.
    """

    @staticmethod
    def _select_result(fileobj: Any, event: int) -> Any:
        """Build one scripted ``selector.select()`` return value.

        The result reports a single ready ``fileobj`` for ``event``, in the
        ``[(SelectorKey, mask)]`` shape that ``select()`` returns.
        """
        return [(selectors.SelectorKey(
            fileobj=fileobj,
            # SelectorKey.fd is the descriptor, i.e. the result of fileno(),
            # not the bound fileno method itself.
            fd=fileobj.fileno(),
            events=event,
            data=None), event)]

    @mock.patch('ssl.wrap_socket')
    @mock.patch('ssl.create_default_context')
    @mock.patch('proxy.http.proxy.server.TcpServerConnection')
    @mock.patch('proxy.http.proxy.server.gen_public_key')
    @mock.patch('proxy.http.proxy.server.gen_csr')
    @mock.patch('proxy.http.proxy.server.sign_csr')
    @mock.patch('selectors.DefaultSelector')
    @mock.patch('socket.fromfd')
    def setUp(self,
              mock_fromfd: mock.Mock,
              mock_selector: mock.Mock,
              mock_sign_csr: mock.Mock,
              mock_gen_csr: mock.Mock,
              mock_gen_public_key: mock.Mock,
              mock_server_conn: mock.Mock,
              mock_ssl_context: mock.Mock,
              mock_ssl_wrap: mock.Mock) -> None:
        """Build a handler and complete an intercepted CONNECT to uni.corn:443.

        The patches are only active while ``setUp`` runs. Every patched object
        (selector, upstream connection, SSL sockets) is created or consumed
        here and then reused by the test methods through ``self``.
        """
        # mock.patch decorators apply bottom-up, so the innermost patch
        # (socket.fromfd) arrives as the first mock argument.
        self.mock_fromfd = mock_fromfd
        self.mock_selector = mock_selector
        self.mock_sign_csr = mock_sign_csr
        self.mock_gen_csr = mock_gen_csr
        self.mock_gen_public_key = mock_gen_public_key
        self.mock_server_conn = mock_server_conn
        self.mock_ssl_context = mock_ssl_context
        self.mock_ssl_wrap = mock_ssl_wrap

        # Certificate helpers are stubbed; only their call counts are checked.
        self.mock_sign_csr.return_value = True
        self.mock_gen_csr.return_value = True
        self.mock_gen_public_key.return_value = True

        self._addr = ('127.0.0.1', 54382)
        # CA file paths so the proxy takes the TLS interception path; the
        # CONNECT below asserts that the stubbed certificate helpers ran.
        self.flags = Proxy.initialize(
            ca_cert_file='ca-cert.pem',
            ca_key_file='ca-key.pem',
            ca_signing_key_file='ca-signing-key.pem',)
        # The example plugin under test is looked up from the running test
        # method's name.
        plugin = get_plugin_by_test_name(self._testMethodName)
        self.flags.plugins = {
            b'HttpProtocolHandlerPlugin': [HttpProxyPlugin],
            b'HttpProxyBasePlugin': [plugin],
        }
        self._conn = mock.MagicMock(spec=socket.socket)
        mock_fromfd.return_value = self._conn
        self.protocol_handler = HttpProtocolHandler(
            TcpClientConnection(self._conn, self._addr), flags=self.flags)
        self.protocol_handler.initialize()

        # Fake upstream connection plus the two SSL sockets. The upstream one
        # comes from the SSL context's wrap_socket; the client one comes from
        # ssl.wrap_socket.
        self.server = self.mock_server_conn.return_value
        self.server_ssl_connection = mock.MagicMock(spec=ssl.SSLSocket)
        self.mock_ssl_context.return_value.wrap_socket.return_value = \
            self.server_ssl_connection
        self.client_ssl_connection = mock.MagicMock(spec=ssl.SSLSocket)
        self.mock_ssl_wrap.return_value = self.client_ssl_connection

        def has_buffer() -> bool:
            # The upstream "has data pending" as soon as anything was queued.
            return cast(bool, self.server.queue.called)

        def closed() -> bool:
            # The upstream counts as closed until connect() has been called.
            return not self.server.connect.called

        def mock_connection() -> Any:
            # Plain socket until the upstream leg has been SSL-wrapped.
            if self.mock_ssl_context.return_value.wrap_socket.called:
                return self.server_ssl_connection
            return self._conn

        # Do not mock the original wrap method
        self.server.wrap.side_effect = \
            lambda x, y: TcpServerConnection.wrap(self.server, x, y)
        self.server.has_buffer.side_effect = has_buffer
        type(self.server).closed = mock.PropertyMock(side_effect=closed)
        type(self.server).connection = mock.PropertyMock(
            side_effect=mock_connection)

        # One scripted select() result per run_once() call, in order.
        self.mock_selector.return_value.select.side_effect = [
            # setUp: CONNECT request readable on the plain client socket.
            self._select_result(self._conn, selectors.EVENT_READ),
            # Tests: decrypted request readable on the client SSL socket.
            self._select_result(
                self.client_ssl_connection, selectors.EVENT_READ),
            # Tests: upstream SSL socket writable (queued request flushed).
            self._select_result(
                self.server_ssl_connection, selectors.EVENT_WRITE),
            # Tests: upstream SSL socket readable (response arrives).
            self._select_result(
                self.server_ssl_connection, selectors.EVENT_READ),
        ]

        # Connect
        def send(raw: bytes) -> int:
            # Pretend every write to the client is fully accepted.
            return len(raw)

        self._conn.send.side_effect = send
        self._conn.recv.return_value = build_http_request(
            httpMethods.CONNECT, b'uni.corn:443'
        )
        self.protocol_handler.run_once()

        # The certificate helpers each ran exactly once.
        self.mock_sign_csr.assert_called_once()
        self.mock_gen_csr.assert_called_once()
        self.mock_gen_public_key.assert_called_once()
        self.mock_server_conn.assert_called_once_with('uni.corn', 443)
        self.server.connect.assert_called()
        # Both legs have been swapped over to their SSL sockets.
        self.assertEqual(
            self.protocol_handler.client.connection,
            self.client_ssl_connection)
        self.assertEqual(self.server.connection, self.server_ssl_connection)
        self._conn.send.assert_called_with(
            HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT
        )
        self.assertFalse(self.protocol_handler.client.has_buffer())

    def test_modify_post_data_plugin(self) -> None:
        """A form-encoded POST body is rewritten before it is queued upstream.

        The queued request carries the modified JSON body with matching
        ``Content-Length`` and ``Content-Type`` headers.
        """
        original = b'{"key": "value"}'
        modified = b'{"key": "modified"}'
        self.client_ssl_connection.recv.return_value = build_http_request(
            b'POST', b'/',
            headers={
                b'Host': b'uni.corn',
                b'Content-Type': b'application/x-www-form-urlencoded',
                b'Content-Length': bytes_(len(original)),
            },
            body=original
        )
        self.protocol_handler.run_once()
        self.server.queue.assert_called_with(
            build_http_request(
                b'POST', b'/',
                headers={
                    b'Host': b'uni.corn',
                    b'Content-Length': bytes_(len(modified)),
                    b'Content-Type': b'application/json',
                },
                body=modified
            )
        )

    def test_man_in_the_middle_plugin(self) -> None:
        """The request passes through unchanged; the response is replaced.

        The upstream's original response body is swapped for the plugin's
        canned body before it lands in the client's outgoing buffer.
        """
        request = build_http_request(
            b'GET', b'/',
            headers={
                b'Host': b'uni.corn',
            }
        )
        self.client_ssl_connection.recv.return_value = request

        # Client read
        self.protocol_handler.run_once()
        self.server.queue.assert_called_once_with(request)

        # Server write
        self.protocol_handler.run_once()
        self.server.flush.assert_called_once()

        # Server read
        self.server.recv.return_value = \
            build_http_response(
                httpStatusCodes.OK,
                reason=b'OK', body=b'Original Response From Upstream')
        self.protocol_handler.run_once()
        self.assertEqual(
            self.protocol_handler.client.buffer[0].tobytes(),
            build_http_response(
                httpStatusCodes.OK,
                reason=b'OK', body=b'Hello from man in the middle')
        )