diff --git a/examples/base_echo_server.py b/examples/base_server.py similarity index 77% rename from examples/base_echo_server.py rename to examples/base_server.py index 70e875bd..19b3c7b6 100644 --- a/examples/base_echo_server.py +++ b/examples/base_server.py @@ -8,6 +8,7 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. """ +from abc import abstractmethod import socket import selectors @@ -17,22 +18,25 @@ from proxy.core.acceptor import Work from proxy.common.types import Readables, Writables -class BaseEchoServerHandler(Work): - """BaseEchoServerHandler implements Work interface. +class BaseServerHandler(Work): + """BaseServerHandler implements Work interface. - An instance of EchoServerHandler is created for each client - connection. EchoServerHandler lifecycle is controlled by - Threadless core using asyncio. Implementation must provide - get_events and handle_events method. Optionally, also implement - intialize, is_inactive and shutdown method. + An instance of BaseServerHandler is created for each client + connection. BaseServerHandler lifecycle is controlled by + Threadless core using asyncio. + + Implementation must provide: + a) handle_data(data: memoryview) + c) (optionally) intialize, is_inactive and shutdown methods """ def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) print('Connection accepted from {0}'.format(self.client.addr)) - def initialize(self) -> None: - pass + @abstractmethod + def handle_data(self, data: memoryview) -> None: + pass # pragma: no cover def get_events(self) -> Dict[socket.socket, int]: # We always want to read from client @@ -58,8 +62,7 @@ class BaseEchoServerHandler(Work): 'Connection closed by client {0}'.format( self.client.addr)) return True - # Echo data back to client - self.client.queue(data) + self.handle_data(data) except ConnectionResetError: print( 'Connection reset by client {0}'.format( diff --git a/examples/connect_tunnel.py b/examples/connect_tunnel.py new file mode 100644 index 00000000..2b1aeefa --- /dev/null +++ b/examples/connect_tunnel.py @@ -0,0 +1,133 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. +""" +import time +import socket +import selectors +from typing import Any, Optional, Dict + +from proxy.core.acceptor import AcceptorPool +from proxy.core.connection import TcpServerConnection +from proxy.http.parser import HttpParser, httpParserTypes, httpParserStates +from proxy.http.codes import httpStatusCodes +from proxy.http.methods import httpMethods +from proxy.common.types import Readables, Writables +from proxy.common.utils import build_http_response, text_ +from proxy.common.flags import Flags + +from examples.base_server import BaseServerHandler + + +class ConnectTunnelHandler(BaseServerHandler): # type: ignore + """A http CONNECT tunnel server.""" + + PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT = memoryview(build_http_response( + httpStatusCodes.OK, + reason=b'Connection established' + )) + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.request = HttpParser(httpParserTypes.REQUEST_PARSER) + self.upstream: Optional[TcpServerConnection] = None + + def initialize(self) -> None: + self.client.connection.setblocking(False) + + def shutdown(self) -> None: + if self.upstream: + print('Connection closed with upstream {0}:{1}'.format( + text_(self.request.host), self.request.port)) + self.upstream.close() + super().shutdown() + + def handle_data(self, data: memoryview) -> None: + # Queue for upstream if connection has been established + if self.upstream and self.upstream._conn is not None: + self.upstream.queue(data) + return + + # Parse client request + self.request.parse(data) + + # Drop the request if not a CONNECT request + if self.request.method != httpMethods.CONNECT: + pass + + # CONNECT requests are short and we need not worry about + # receiving partial request bodies here. + assert self.request.state == httpParserStates.COMPLETE + + # Establish connection with upstream + assert self.request.host and self.request.port + self.upstream = TcpServerConnection( + text_(self.request.host), self.request.port) + self.upstream.connect() + print('Connection established with upstream {0}:{1}'.format( + text_(self.request.host), self.request.port)) + + # Queue tunnel established response to client + self.client.queue( + ConnectTunnelHandler.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT) + + def get_events(self) -> Dict[socket.socket, int]: + # Get default client events + ev: Dict[socket.socket, int] = super().get_events() + # Read from server if we are connected + if self.upstream and self.upstream._conn is not None: + ev[self.upstream.connection] = selectors.EVENT_READ + # If there is pending buffer for server + # also register for EVENT_WRITE events + if self.upstream and self.upstream.has_buffer(): + if self.upstream.connection in ev: + ev[self.upstream.connection] |= selectors.EVENT_WRITE + else: + ev[self.upstream.connection] = selectors.EVENT_WRITE + return ev + + def handle_events( + self, + readables: Readables, + writables: Writables) -> bool: + # Handle client events + do_shutdown: bool = super().handle_events(readables, writables) + if do_shutdown: + return do_shutdown + # Handle server events + if self.upstream and self.upstream.connection in readables: + data = self.upstream.recv() + if data is None: + # Server closed connection + print('Connection closed by server') + return True + # tunnel data to client + self.client.queue(data) + if self.upstream and self.upstream.connection in writables: + self.upstream.flush() + return False + + +def main() -> None: + # This example requires `threadless=True` + pool = AcceptorPool( + flags=Flags(port=12345, num_workers=1, threadless=True), + work_klass=ConnectTunnelHandler) + try: + pool.setup() + while True: + time.sleep(1) + except KeyboardInterrupt: + pass + finally: + pool.shutdown() + + +if __name__ == '__main__': + main() diff --git a/examples/ssl_echo_server.py b/examples/ssl_echo_server.py index 1bb401bf..5e660dfa 100644 --- a/examples/ssl_echo_server.py +++ b/examples/ssl_echo_server.py @@ -15,10 +15,10 @@ from proxy.core.connection import TcpClientConnection from proxy.common.flags import Flags from proxy.common.utils import wrap_socket -from examples.base_echo_server import BaseEchoServerHandler +from examples.base_server import BaseServerHandler -class EchoSSLServerHandler(BaseEchoServerHandler): # type: ignore +class EchoSSLServerHandler(BaseServerHandler): # type: ignore """Wraps client socket during initialization.""" def initialize(self) -> None: @@ -34,6 +34,10 @@ class EchoSSLServerHandler(BaseEchoServerHandler): # type: ignore self.client = TcpClientConnection( conn=conn, addr=self.client.addr) # type: ignore + def handle_data(self, data: memoryview) -> None: + # echo back to client + self.client.queue(data) + def main() -> None: # This example requires `threadless=True` @@ -49,6 +53,8 @@ def main() -> None: pool.setup() while True: time.sleep(1) + except KeyboardInterrupt: + pass finally: pool.shutdown() diff --git a/examples/tcp_echo_server.py b/examples/tcp_echo_server.py index fa5da58a..f5b0aa50 100644 --- a/examples/tcp_echo_server.py +++ b/examples/tcp_echo_server.py @@ -13,15 +13,19 @@ import time from proxy.core.acceptor import AcceptorPool from proxy.common.flags import Flags -from examples.base_echo_server import BaseEchoServerHandler +from examples.base_server import BaseServerHandler -class EchoServerHandler(BaseEchoServerHandler): # type: ignore +class EchoServerHandler(BaseServerHandler): # type: ignore """Sets client socket to non-blocking during initialization.""" def initialize(self) -> None: self.client.connection.setblocking(False) + def handle_data(self, data: memoryview) -> None: + # echo back to client + self.client.queue(data) + def main() -> None: # This example requires `threadless=True` @@ -32,6 +36,8 @@ def main() -> None: pool.setup() while True: time.sleep(1) + except KeyboardInterrupt: + pass finally: pool.shutdown() diff --git a/proxy/http/parser.py b/proxy/http/parser.py index d40885ca..638959ac 100644 --- a/proxy/http/parser.py +++ b/proxy/http/parser.py @@ -110,7 +110,7 @@ class HttpParser: # For CONNECT requests, request line contains # upstream_host:upstream_port which is not complaint # with urlsplit, which expects a fully qualified url. - if self.method == b'CONNECT': + if self.method == httpMethods.CONNECT: url = b'https://' + url self.url = urlparse.urlsplit(url) self.set_line_attributes() diff --git a/proxy/http/proxy/server.py b/proxy/http/proxy/server.py index ef8f5ab1..d0297ea8 100644 --- a/proxy/http/proxy/server.py +++ b/proxy/http/proxy/server.py @@ -408,7 +408,7 @@ class HttpProxyPlugin(HttpProtocolHandlerPlugin): server_host, server_port = self.server.addr if self.server else ( None, None) connection_time_ms = (time.time() - self.start_time) * 1000 - if self.request.method == b'CONNECT': + if self.request.method == httpMethods.CONNECT: logger.info( '%s:%s - %s %s:%s - %s bytes - %.2f ms' % (self.client.addr[0], diff --git a/tests/http/test_http_parser.py b/tests/http/test_http_parser.py index e8d78cbe..253f2b7f 100644 --- a/tests/http/test_http_parser.py +++ b/tests/http/test_http_parser.py @@ -308,7 +308,7 @@ class TestHttpParser(unittest.TestCase): See https://github.com/abhinavsingh/py/issues/5 for details. """ self.parser.parse(b'CONNECT pypi.org:443 HTTP/1.0\r\n\r\n') - self.assertEqual(self.parser.method, b'CONNECT') + self.assertEqual(self.parser.method, httpMethods.CONNECT) self.assertEqual(self.parser.version, b'HTTP/1.0') self.assertEqual(self.parser.state, httpParserStates.COMPLETE)