diff --git a/.dockerignore b/.dockerignore index c0fcf9b7..45aff3ba 100644 --- a/.dockerignore +++ b/.dockerignore @@ -3,3 +3,4 @@ # Except proxy.py !proxy.py +!requirements.txt diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index 6bab40e4..7815f0af 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -27,8 +27,8 @@ jobs: run: | # The GitHub editor is 127 chars wide # W504 screams for line break after binary operators - flake8 --ignore=W504 --max-line-length=127 proxy.py tests.py + flake8 --ignore=W504 --max-line-length=127 proxy.py plugin_examples.py tests.py setup.py # mypy compliance check - mypy --strict --ignore-missing-imports proxy.py plugin_examples.py tests.py + mypy --strict --ignore-missing-imports proxy.py plugin_examples.py tests.py setup.py - name: Run Tests run: pytest tests.py diff --git a/Dockerfile b/Dockerfile index d55321e6..291ae291 100644 --- a/Dockerfile +++ b/Dockerfile @@ -5,12 +5,13 @@ LABEL com.abhinavsingh.name="abhinavsingh/proxy.py" \ com.abhinavsingh.vcs-url="https://github.com/abhinavsingh/proxy.py" \ com.abhinavsingh.docker.cmd="docker run -it --rm -p 8899:8899 abhinavsingh/proxy.py" -RUN pip install --upgrade pip && pip install typing-extensions==3.7.4 - -COPY proxy.py /app/ -EXPOSE 8899/tcp - WORKDIR /app +COPY requirements.txt . +COPY proxy.py . + +RUN pip install --upgrade pip && pip install -r requirements.txt + +EXPOSE 8899/tcp ENTRYPOINT [ "./proxy.py" ] CMD [ "--hostname=0.0.0.0", \ "--port=8899" ] diff --git a/Makefile b/Makefile index 2cbea1f9..8cba0b7b 100644 --- a/Makefile +++ b/Makefile @@ -44,8 +44,8 @@ coverage: open htmlcov/index.html lint: - flake8 --ignore=W504 --max-line-length=127 proxy.py tests.py - mypy --strict --ignore-missing-imports proxy.py plugin_examples.py tests.py + flake8 --ignore=W504 --max-line-length=127 proxy.py plugin_examples.py tests.py setup.py + mypy --strict --ignore-missing-imports proxy.py plugin_examples.py tests.py setup.py autopep8: autopep8 --recursive --in-place --aggressive proxy.py diff --git a/README.md b/README.md index 2b14eccc..e7d9e50d 100644 --- a/README.md +++ b/README.md @@ -166,7 +166,7 @@ See [flags](#flags) for full list of available configuration options. ## Docker image - $ docker run -it -p 8899:8899 --rm abhinavsingh/proxy.py:v1.0.0 + $ docker run -it -p 8899:8899 --rm abhinavsingh/proxy.py:latest By default `docker` binary is started with IPv4 networking flags: @@ -177,7 +177,7 @@ For example, to check `proxy.py --version`: $ docker run -it \ -p 8899:8899 \ - --rm abhinavsingh/proxy.py:v1.0.0 \ + --rm abhinavsingh/proxy.py:latest \ --version [![WARNING](https://img.shields.io/static/v1?label=MacOS&message=warning&color=red)](https://github.com/moby/vpnkit/issues/469) diff --git a/proxy.py b/proxy.py index c8ef6f98..1c924dbb 100755 --- a/proxy.py +++ b/proxy.py @@ -38,8 +38,8 @@ import time from abc import ABC, abstractmethod from multiprocessing import connection from multiprocessing.reduction import send_handle, recv_handle -from typing import Any, Dict, List, Tuple, Optional, Union, NamedTuple, Callable, TYPE_CHECKING, Type from types import TracebackType +from typing import Any, Dict, List, Tuple, Optional, Union, NamedTuple, Callable, TYPE_CHECKING, Type, cast from urllib import parse as urlparse from typing_extensions import Protocol @@ -50,7 +50,7 @@ if os.name != 'nt': PROXY_PY_DIR = os.path.dirname(os.path.realpath(__file__)) PROXY_PY_START_TIME = time.time() -VERSION = (1, 1, 0) +VERSION = (1, 1, 1) __version__ = '.'.join(map(str, VERSION[0:3])) __description__ = 'Lightweight, Programmable, TLS interceptor Proxy for HTTP(S), HTTP2, ' \ 'WebSockets protocols in a single Python file.' @@ -1246,34 +1246,37 @@ class HttpProxyPlugin(ProtocolHandlerPlugin): else: return raw - def generate_upstream_certificate(self) -> Optional[str]: - if self.config.ca_cert_dir and self.config.ca_signing_key_file and \ - self.config.ca_cert_file and self.config.ca_key_file: - with self.lock: - cert_file_path = os.path.join( - self.config.ca_cert_dir, - '%s.pem' % - text_( - self.request.host)) - if not os.path.isfile(cert_file_path): - logger.debug('Generating certificates %s', cert_file_path) - # TODO: Use ssl.get_server_certificate to populate generated certificate metadata - # Currently we only set CN= field for generated certificates. - gen_cert = subprocess.Popen( - ['/usr/bin/openssl', 'req', '-new', '-key', self.config.ca_signing_key_file, '-subj', - '/CN=%s' % text_(self.request.host)], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) - sign_cert = subprocess.Popen( - ['/usr/bin/openssl', 'x509', '-req', '-days', '365', '-CA', self.config.ca_cert_file, '-CAkey', - self.config.ca_key_file, '-set_serial', str(int(time.time())), '-out', cert_file_path], - stdin=gen_cert.stdout, - stderr=subprocess.PIPE) - # TODO: Ensure sign_cert success. - sign_cert.communicate(timeout=10) - return cert_file_path - else: - return None + def generate_upstream_certificate(self, _certificate: Optional[Dict[str, Any]]) -> Optional[str]: + if not (self.config.ca_cert_dir and self.config.ca_signing_key_file and + self.config.ca_cert_file and self.config.ca_key_file): + raise ProtocolException( + f'For certificate generation all the following flags are mandatory: ' + f'--ca-cert-file:{ self.config.ca_cert_file }, ' + f'--ca-key-file:{ self.config.ca_key_file }, ' + f'--ca-signing-key-file:{ self.config.ca_signing_key_file }') + with self.lock: + cert_file_path = os.path.join( + self.config.ca_cert_dir, + '%s.pem' % + text_( + self.request.host)) + if not os.path.isfile(cert_file_path): + logger.debug('Generating certificates %s', cert_file_path) + # TODO: Parse subject from certificate + # Currently we only set CN= field for generated certificates. + gen_cert = subprocess.Popen( + ['openssl', 'req', '-new', '-key', self.config.ca_signing_key_file, '-subj', + f'/C=/ST=/L=/O=/OU=/CN={ text_(self.request.host) }'], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + sign_cert = subprocess.Popen( + ['openssl', 'x509', '-req', '-days', '365', '-CA', self.config.ca_cert_file, '-CAkey', + self.config.ca_key_file, '-set_serial', str(int(time.time())), '-out', cert_file_path], + stdin=gen_cert.stdout, + stderr=subprocess.PIPE) + # TODO: Ensure sign_cert success. + sign_cert.communicate(timeout=10) + return cert_file_path def on_request_complete(self) -> Union[socket.socket, bool]: if not self.request.has_upstream_server(): @@ -1295,27 +1298,34 @@ class HttpProxyPlugin(ProtocolHandlerPlugin): if self.request.method == httpMethods.CONNECT: self.client.queue( HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT) - # If interception is enabled, generate server certificates + # If interception is enabled if self.config.ca_key_file and self.config.ca_cert_file and self.config.ca_signing_key_file: - generated_cert = self.generate_upstream_certificate() - if generated_cert: - if not (self.config.keyfile and self.config.certfile) and \ - self.server and isinstance(self.server.connection, socket.socket): - self.client._conn = ssl.wrap_socket( - self.client.connection, - server_side=True, - keyfile=self.config.ca_signing_key_file, - certfile=generated_cert) - # Wrap our connection to upstream server connection - ctx = ssl.create_default_context( - ssl.Purpose.SERVER_AUTH) - ctx.options |= ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1 - self.server._conn = ctx.wrap_socket( - self.server.connection, - server_hostname=text_(self.request.host)) - logger.info( - 'TLS interception using %s', generated_cert) - return self.client.connection + assert self.server is not None + assert isinstance(self.server.connection, socket.socket) + # Perform SSL/TLS handshake with upstream + ctx = ssl.create_default_context( + ssl.Purpose.SERVER_AUTH) + ctx.options |= ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1 + self.server.connection.setblocking(True) + self.server._conn = ctx.wrap_socket( + self.server.connection, + server_hostname=text_(self.request.host)) + self.server.connection.setblocking(False) + assert isinstance(self.server.connection, ssl.SSLSocket) + # Generate certificate and perform handshake with client + generated_cert = self.generate_upstream_certificate( + cast(Dict[str, Any], self.server.connection.getpeercert())) + self.client.flush() + self.client.connection.setblocking(True) + self.client._conn = ssl.wrap_socket( + self.client.connection, + server_side=True, + keyfile=self.config.ca_signing_key_file, + certfile=generated_cert) + self.client.connection.setblocking(False) + logger.info( + 'TLS interception using %s', generated_cert) + return self.client.connection elif self.server: # - proxy-connection header is a mistake, it doesn't seem to be # officially documented in any specification, drop it. @@ -1351,6 +1361,7 @@ class HttpProxyPlugin(ProtocolHandlerPlugin): 'Connecting to upstream %s:%s' % (text_(host), port)) self.server.connect() + self.server.connection.setblocking(False) logger.debug( 'Connected to upstream %s:%s' % (text_(host), port)) @@ -1948,6 +1959,7 @@ class ProtocolHandler(threading.Thread): return datetime.datetime.utcnow() def initialize(self) -> None: + """Optionally upgrades connection to HTTPS, set conn in non-blocking mode and initializes plugins.""" conn = self.optionally_wrap_socket(self.client.connection) conn.setblocking(False) self.client = TcpClientConnection(conn=conn, addr=self.addr) @@ -2041,14 +2053,12 @@ class ProtocolHandler(threading.Thread): logger.debug( 'Updated client conn to %s', upgraded_sock) self.client._conn = upgraded_sock - # Update self.client.conn references for all - # plugins for plugin_ in self.plugins.values(): if plugin_ != plugin: plugin_.client._conn = upgraded_sock logger.debug( 'Upgraded client conn for plugin %s', str(plugin_)) - elif isinstance(upgraded_sock, bool) and upgraded_sock: + elif isinstance(upgraded_sock, bool) and upgraded_sock is True: return True except ProtocolException as e: logger.exception( diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 00000000..c05a7ea7 --- /dev/null +++ b/requirements.txt @@ -0,0 +1 @@ +typing-extensions==3.7.4 diff --git a/setup.py b/setup.py index a1ec7d7e..96a9a945 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ proxy.py ~~~~~~~~ ⚡⚡⚡ Fast, Lightweight, Programmable Proxy Server in a single Python file. - + :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. """ @@ -77,5 +77,5 @@ setup( license=proxy.__license__, py_modules=['proxy'], scripts=['proxy.py'], - install_requires=['typing-extensions==3.7.4'], + install_requires=open('requirements.txt', 'r').read().strip().split(), ) diff --git a/tests.py b/tests.py index cb1eab65..3ab61ee6 100644 --- a/tests.py +++ b/tests.py @@ -21,7 +21,7 @@ import unittest from contextlib import closing from http.server import HTTPServer, BaseHTTPRequestHandler from threading import Thread -from typing import Dict, Optional, Tuple, Union +from typing import Dict, Optional, Tuple, Union, Any from unittest import mock import proxy @@ -889,6 +889,13 @@ class TestHttpParser(unittest.TestCase): body=b'f\r\n{"key":"value"}\r\n0\r\n\r\n')) self.assertEqual(self.parser.body, b'{"key":"value"}') self.assertEqual(self.parser.state, proxy.httpParserStates.COMPLETE) + self.assertEqual(self.parser.build(), proxy.build_http_request( + proxy.httpMethods.POST, b'/', + headers={ + b'Transfer-Encoding': b'chunked', + b'Content-Type': b'application/json', + }, + body=b'f\r\n{"key":"value"}\r\n0\r\n\r\n')) def assertDictContainsSubset(self, subset: Dict[bytes, Tuple[bytes, bytes]], dictionary: Dict[bytes, Tuple[bytes, bytes]]) -> None: @@ -1539,6 +1546,133 @@ class TestHttpProxyPlugin(unittest.TestCase): mock_server_conn.assert_not_called() +class TestHttpProxyTlsInterception(unittest.TestCase): + + @mock.patch('ssl.wrap_socket') + @mock.patch('ssl.create_default_context') + @mock.patch('proxy.TcpServerConnection') + @mock.patch('proxy.HttpProxyPlugin.generate_upstream_certificate') + @mock.patch('selectors.DefaultSelector') + @mock.patch('socket.fromfd') + def test_e2e( + self, + mock_fromfd: mock.Mock, + mock_selector: mock.Mock, + mock_generate_certificate: mock.Mock, + mock_server_conn: mock.Mock, + mock_ssl_context: mock.Mock, + mock_ssl_wrap: mock.Mock) -> None: + self.mock_fromfd = mock_fromfd + self.mock_selector = mock_selector + self.mock_generate_certificate = mock_generate_certificate + self.mock_server_conn = mock_server_conn + self.mock_ssl_context = mock_ssl_context + self.mock_ssl_wrap = mock_ssl_wrap + + ssl_connection = mock.MagicMock(spec=ssl.SSLSocket) + self.mock_ssl_context.return_value.wrap_socket.return_value = ssl_connection + self.mock_ssl_wrap.return_value = mock.MagicMock(spec=ssl.SSLSocket) + plain_connection = mock.MagicMock(spec=socket.socket) + + def mock_connection() -> Any: + if self.mock_ssl_context.return_value.wrap_socket.called: + return ssl_connection + return plain_connection + + type(self.mock_server_conn.return_value).connection = \ + mock.PropertyMock(side_effect=mock_connection) + + self.fileno = 10 + self._addr = ('127.0.0.1', 54382) + self.config = proxy.ProtocolConfig( + ca_cert_file='ca-cert.pem', + ca_key_file='ca-key.pem', + ca_signing_key_file='ca-signing-key.pem', + ) + self.plugin = mock.MagicMock() + self.proxy_plugin = mock.MagicMock() + self.config.plugins = { + b'ProtocolHandlerPlugin': [self.plugin, proxy.HttpProxyPlugin], + b'HttpProxyBasePlugin': [self.proxy_plugin], + } + self._conn = mock_fromfd.return_value + self.proxy = proxy.ProtocolHandler( + self.fileno, self._addr, config=self.config) + self.proxy.initialize() + + self.plugin.assert_called() + self.assertEqual(self.plugin.call_args[0][0], self.config) + self.assertEqual(self.plugin.call_args[0][1].connection, self._conn) + self.proxy_plugin.assert_called() + self.assertEqual(self.proxy_plugin.call_args[0][0], self.config) + self.assertEqual(self.proxy_plugin.call_args[0][1].connection, self._conn) + + connect_request = proxy.build_http_request( + proxy.httpMethods.CONNECT, b'super.secure:443', + headers={ + b'Host': b'super.secure:443', + }) + self._conn.recv.return_value = connect_request + + # Prepare mocked ProtocolHandlerPlugin + self.plugin.return_value.get_descriptors.return_value = ([], []) + self.plugin.return_value.write_to_descriptors.return_value = False + self.plugin.return_value.read_from_descriptors.return_value = False + self.plugin.return_value.on_client_data.side_effect = lambda raw: raw + self.plugin.return_value.on_request_complete.return_value = False + self.plugin.return_value.on_response_chunk.side_effect = lambda chunk: chunk + self.plugin.return_value.on_client_connection_close.return_value = None + + # Prepare mocked HttpProxyBasePlugin + self.proxy_plugin.return_value.before_upstream_connection.return_value = False + + self.mock_selector.return_value.select.side_effect = [ + [(selectors.SelectorKey( + fileobj=self._conn, + fd=self._conn.fileno, + events=selectors.EVENT_READ, + data=None), selectors.EVENT_READ)], ] + self.proxy.run_once() + + # Assert our mocked plugin invocations + self.plugin.return_value.get_descriptors.assert_called() + self.plugin.return_value.write_to_descriptors.assert_called_with([]) + self.plugin.return_value.on_client_data.assert_called_with(connect_request) + self.plugin.return_value.on_request_complete.assert_called() + self.plugin.return_value.read_from_descriptors.assert_called_with([self._conn]) + self.proxy_plugin.return_value.before_upstream_connection.assert_called() + self.proxy_plugin.return_value.on_upstream_connection.assert_called() + + self.mock_server_conn.assert_called_with('super.secure', 443) + self.mock_server_conn.return_value.connection.setblocking.assert_called_with(False) + + self.mock_ssl_context.assert_called_with(ssl.Purpose.SERVER_AUTH) + # self.assertEqual(self.mock_ssl_context.return_value.options, + # ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1) + self.assertEqual(plain_connection.setblocking.call_count, 2) + self.mock_ssl_context.return_value.wrap_socket.assert_called_with( + plain_connection, server_hostname='super.secure') + self.assertEqual(ssl_connection.setblocking.call_count, 1) + self.assertEqual(self.mock_server_conn.return_value._conn, ssl_connection) + self.mock_generate_certificate.assert_called_with( + self.mock_server_conn.return_value.connection.getpeercert.return_value) + self._conn.send.assert_called_with(proxy.HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT) + self.mock_ssl_wrap.assert_called_with( + self._conn, + server_side=True, + keyfile=self.config.ca_signing_key_file, + certfile=self.mock_generate_certificate.return_value + ) + self.assertEqual(self._conn.setblocking.call_count, 2) + self.assertEqual(self.proxy.client.connection, self.mock_ssl_wrap.return_value) + + # Assert connection references for all other plugins is updated + self.assertEqual(self.plugin.return_value.client._conn, self.mock_ssl_wrap.return_value) + + # Currently proxy doesn't update it's own plugin + # self.assertEqual(self.proxy_plugin.return_value.client._conn, self._conn) + + class TestHttpRequestRejected(unittest.TestCase): def setUp(self) -> None: