diff --git a/examples/tcp_echo_client.py b/examples/tcp_echo_client.py new file mode 100644 index 00000000..decabb50 --- /dev/null +++ b/examples/tcp_echo_client.py @@ -0,0 +1,21 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. +""" +from proxy.common.utils import socket_connection +from proxy.common.constants import DEFAULT_BUFFER_SIZE + +if __name__ == '__main__': + with socket_connection(('::', 12345)) as client: + while True: + client.send(b'hello') + data = client.recv(DEFAULT_BUFFER_SIZE) + if data is None: + break + print(data) diff --git a/examples/tcp_echo_server.py b/examples/tcp_echo_server.py index 02d7ca3b..031963de 100644 --- a/examples/tcp_echo_server.py +++ b/examples/tcp_echo_server.py @@ -12,7 +12,7 @@ import time import socket import selectors -from typing import Dict +from typing import Dict, Any from proxy.core.acceptor import AcceptorPool, Work from proxy.common.flags import Flags @@ -29,6 +29,10 @@ class EchoServerHandler(Work): intialize, is_inactive and shutdown method. """ + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + print('Connection accepted from {0}'.format(self.client.addr)) + def get_events(self) -> Dict[socket.socket, int]: # We always want to read from client # Register for EVENT_READ events @@ -45,12 +49,21 @@ class EchoServerHandler(Work): writables: Writables) -> bool: """Return True to shutdown work.""" if self.client.connection in readables: - data = self.client.recv() - if data is None: - # Client closed connection, signal shutdown + try: + data = self.client.recv() + if data is None: + # Client closed connection, signal shutdown + print( + 'Connection closed by client {0}'.format( + self.client.addr)) + return True + # Echo data back to client + self.client.queue(data) + except ConnectionResetError: + print( + 'Connection reset by client {0}'.format( + self.client.addr)) return True - # Queue data back to client - self.client.queue(data) if self.client.connection in writables: self.client.flush() @@ -61,7 +74,7 @@ class EchoServerHandler(Work): def main() -> None: # This example requires `threadless=True` pool = AcceptorPool( - flags=Flags(num_workers=1, threadless=True), + flags=Flags(port=12345, num_workers=1, threadless=True), work_klass=EchoServerHandler) try: pool.setup() diff --git a/examples/websocket_client.py b/examples/websocket_client.py index 5509d10d..c87ed3e4 100644 --- a/examples/websocket_client.py +++ b/examples/websocket_client.py @@ -22,8 +22,10 @@ num_echos = 10 def on_message(frame: WebsocketFrame) -> None: """WebsocketClient on_message callback.""" global client, num_echos, last_dispatch_time - print('Received %r after %d millisec' % (frame.data, (time.time() - last_dispatch_time) * 1000)) - assert(frame.data == b'hello' and frame.opcode == websocketOpcodes.TEXT_FRAME) + print('Received %r after %d millisec' % + (frame.data, (time.time() - last_dispatch_time) * 1000)) + assert(frame.data == b'hello' and frame.opcode == + websocketOpcodes.TEXT_FRAME) if num_echos > 0: client.queue(static_frame) last_dispatch_time = time.time() @@ -34,7 +36,11 @@ def on_message(frame: WebsocketFrame) -> None: if __name__ == '__main__': # Constructor establishes socket connection - client = WebsocketClient(b'echo.websocket.org', 80, b'/', on_message=on_message) + client = WebsocketClient( + b'echo.websocket.org', + 80, + b'/', + on_message=on_message) # Perform handshake client.handshake() # Queue some data for client diff --git a/proxy/core/connection/client.py b/proxy/core/connection/client.py index 28995a58..62597a10 100644 --- a/proxy/core/connection/client.py +++ b/proxy/core/connection/client.py @@ -30,3 +30,15 @@ class TcpClientConnection(TcpConnection): if self._conn is None: raise TcpConnectionUninitializedException() return self._conn + + def wrap(self, keyfile: str, certfile: str) -> None: + self.connection.setblocking(True) + self.flush() + self._conn = ssl.wrap_socket( + self.connection, + server_side=True, + # ca_certs=self.flags.ca_cert_file, + certfile=certfile, + keyfile=keyfile, + ssl_version=ssl.PROTOCOL_TLS) + self.connection.setblocking(False) diff --git a/proxy/core/connection/server.py b/proxy/core/connection/server.py index cbb9806a..c5636e6a 100644 --- a/proxy/core/connection/server.py +++ b/proxy/core/connection/server.py @@ -8,8 +8,8 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. """ -import socket import ssl +import socket from typing import Optional, Union, Tuple from .connection import TcpConnection, tcpConnectionTypes, TcpConnectionUninitializedException @@ -34,3 +34,14 @@ class TcpServerConnection(TcpConnection): if self._conn is not None: return self._conn = new_socket_connection(self.addr) + + def wrap(self, hostname: str, ca_file: Optional[str]) -> None: + ctx = ssl.create_default_context( + ssl.Purpose.SERVER_AUTH, cafile=ca_file) + ctx.options |= ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | ssl.OP_NO_TLSv1 + ctx.check_hostname = True + self.connection.setblocking(True) + self._conn = ctx.wrap_socket( + self.connection, + server_hostname=hostname) + self.connection.setblocking(False) diff --git a/proxy/http/proxy/server.py b/proxy/http/proxy/server.py index 3fad47e7..c5c9a75f 100644 --- a/proxy/http/proxy/server.py +++ b/proxy/http/proxy/server.py @@ -453,31 +453,15 @@ class HttpProxyPlugin(HttpProtocolHandlerPlugin): def wrap_server(self) -> None: assert self.server is not None assert isinstance(self.server.connection, socket.socket) - ctx = ssl.create_default_context( - ssl.Purpose.SERVER_AUTH, cafile=self.flags.ca_file) - ctx.options |= ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | ssl.OP_NO_TLSv1 - ctx.check_hostname = True - self.server.connection.setblocking(True) - self.server._conn = ctx.wrap_socket( - self.server.connection, - server_hostname=text_(self.request.host)) - self.server.connection.setblocking(False) + self.server.wrap(text_(self.request.host), self.flags.ca_file) + assert isinstance(self.server.connection, ssl.SSLSocket) def wrap_client(self) -> None: - assert self.server is not None + assert self.server is not None and self.flags.ca_signing_key_file is not None assert isinstance(self.server.connection, ssl.SSLSocket) generated_cert = self.generate_upstream_certificate( cast(Dict[str, Any], self.server.connection.getpeercert())) - self.client.connection.setblocking(True) - self.client.flush() - self.client._conn = ssl.wrap_socket( - self.client.connection, - server_side=True, - # ca_certs=self.flags.ca_cert_file, - certfile=generated_cert, - keyfile=self.flags.ca_signing_key_file, - ssl_version=ssl.PROTOCOL_TLS) - self.client.connection.setblocking(False) + self.client.wrap(self.flags.ca_signing_key_file, generated_cert) logger.debug( 'TLS interception using %s', generated_cert) diff --git a/proxy/http/websocket/client.py b/proxy/http/websocket/client.py index 16e440de..716d0fae 100644 --- a/proxy/http/websocket/client.py +++ b/proxy/http/websocket/client.py @@ -52,7 +52,11 @@ class WebsocketClient(TcpConnection): def upgrade(self) -> None: key = base64.b64encode(secrets.token_bytes(16)) - self.sock.send(build_websocket_handshake_request(key, url=self.path, host=self.hostname)) + self.sock.send( + build_websocket_handshake_request( + key, + url=self.path, + host=self.hostname)) response = HttpParser(httpParserTypes.RESPONSE_PARSER) response.parse(self.sock.recv(DEFAULT_BUFFER_SIZE)) accept = response.header(b'Sec-Websocket-Accept') diff --git a/tests/http/test_http_proxy_tls_interception.py b/tests/http/test_http_proxy_tls_interception.py index 2790f8e5..c904e69c 100644 --- a/tests/http/test_http_proxy_tls_interception.py +++ b/tests/http/test_http_proxy_tls_interception.py @@ -17,7 +17,7 @@ import selectors from typing import Any from unittest import mock -from proxy.core.connection import TcpClientConnection +from proxy.core.connection import TcpClientConnection, TcpServerConnection from proxy.http.handler import HttpProtocolHandler from proxy.http.proxy import HttpProxyPlugin from proxy.http.methods import httpMethods @@ -71,6 +71,11 @@ class TestHttpProxyTlsInterception(unittest.TestCase): return ssl_connection return plain_connection + # Do not mock the original wrap method + self.mock_server_conn.return_value.wrap.side_effect = \ + lambda x, y: TcpServerConnection.wrap( + self.mock_server_conn.return_value, x, y) + type(self.mock_server_conn.return_value).connection = \ mock.PropertyMock(side_effect=mock_connection) diff --git a/tests/plugin/test_http_proxy_plugins_with_tls_interception.py b/tests/plugin/test_http_proxy_plugins_with_tls_interception.py index 0c1ec0a4..f3289e0d 100644 --- a/tests/plugin/test_http_proxy_plugins_with_tls_interception.py +++ b/tests/plugin/test_http_proxy_plugins_with_tls_interception.py @@ -19,7 +19,7 @@ from typing import Any, cast from proxy.common.utils import bytes_ from proxy.common.flags import Flags from proxy.common.utils import build_http_request, build_http_response -from proxy.core.connection import TcpClientConnection +from proxy.core.connection import TcpClientConnection, TcpServerConnection from proxy.http.codes import httpStatusCodes from proxy.http.methods import httpMethods from proxy.http.handler import HttpProtocolHandler @@ -98,6 +98,10 @@ class TestHttpProxyPluginExamplesWithTlsInterception(unittest.TestCase): return self.server_ssl_connection return self._conn + # Do not mock the original wrap method + self.server.wrap.side_effect = \ + lambda x, y: TcpServerConnection.wrap(self.server, x, y) + self.server.has_buffer.side_effect = has_buffer type(self.server).closed = mock.PropertyMock(side_effect=closed) type(