More examples (#444)

* Refactor into BaseServerHandler and BaseEchoServerHandler classes

* Add connect tunnel example
This commit is contained in:
Abhinav Singh 2020-10-06 22:27:19 +05:30 committed by GitHub
parent 29e2a35091
commit 1038bb841d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 166 additions and 18 deletions

View File

@ -8,6 +8,7 @@
:copyright: (c) 2013-present by Abhinav Singh and contributors.
:license: BSD, see LICENSE for more details.
"""
from abc import abstractmethod
import socket
import selectors
@ -17,22 +18,25 @@ from proxy.core.acceptor import Work
from proxy.common.types import Readables, Writables
class BaseEchoServerHandler(Work):
"""BaseEchoServerHandler implements Work interface.
class BaseServerHandler(Work):
"""BaseServerHandler implements Work interface.
An instance of EchoServerHandler is created for each client
connection. EchoServerHandler lifecycle is controlled by
Threadless core using asyncio. Implementation must provide
get_events and handle_events method. Optionally, also implement
intialize, is_inactive and shutdown method.
An instance of BaseServerHandler is created for each client
connection. BaseServerHandler lifecycle is controlled by
Threadless core using asyncio.
Implementation must provide:
a) handle_data(data: memoryview)
c) (optionally) intialize, is_inactive and shutdown methods
"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
print('Connection accepted from {0}'.format(self.client.addr))
def initialize(self) -> None:
pass
@abstractmethod
def handle_data(self, data: memoryview) -> None:
pass # pragma: no cover
def get_events(self) -> Dict[socket.socket, int]:
# We always want to read from client
@ -58,8 +62,7 @@ class BaseEchoServerHandler(Work):
'Connection closed by client {0}'.format(
self.client.addr))
return True
# Echo data back to client
self.client.queue(data)
self.handle_data(data)
except ConnectionResetError:
print(
'Connection reset by client {0}'.format(

133
examples/connect_tunnel.py Normal file
View File

@ -0,0 +1,133 @@
# -*- coding: utf-8 -*-
"""
proxy.py
~~~~~~~~
Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
Network monitoring, controls & Application development, testing, debugging.
:copyright: (c) 2013-present by Abhinav Singh and contributors.
:license: BSD, see LICENSE for more details.
"""
import time
import socket
import selectors
from typing import Any, Optional, Dict
from proxy.core.acceptor import AcceptorPool
from proxy.core.connection import TcpServerConnection
from proxy.http.parser import HttpParser, httpParserTypes, httpParserStates
from proxy.http.codes import httpStatusCodes
from proxy.http.methods import httpMethods
from proxy.common.types import Readables, Writables
from proxy.common.utils import build_http_response, text_
from proxy.common.flags import Flags
from examples.base_server import BaseServerHandler
class ConnectTunnelHandler(BaseServerHandler): # type: ignore
"""A http CONNECT tunnel server."""
PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT = memoryview(build_http_response(
httpStatusCodes.OK,
reason=b'Connection established'
))
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.request = HttpParser(httpParserTypes.REQUEST_PARSER)
self.upstream: Optional[TcpServerConnection] = None
def initialize(self) -> None:
self.client.connection.setblocking(False)
def shutdown(self) -> None:
if self.upstream:
print('Connection closed with upstream {0}:{1}'.format(
text_(self.request.host), self.request.port))
self.upstream.close()
super().shutdown()
def handle_data(self, data: memoryview) -> None:
# Queue for upstream if connection has been established
if self.upstream and self.upstream._conn is not None:
self.upstream.queue(data)
return
# Parse client request
self.request.parse(data)
# Drop the request if not a CONNECT request
if self.request.method != httpMethods.CONNECT:
pass
# CONNECT requests are short and we need not worry about
# receiving partial request bodies here.
assert self.request.state == httpParserStates.COMPLETE
# Establish connection with upstream
assert self.request.host and self.request.port
self.upstream = TcpServerConnection(
text_(self.request.host), self.request.port)
self.upstream.connect()
print('Connection established with upstream {0}:{1}'.format(
text_(self.request.host), self.request.port))
# Queue tunnel established response to client
self.client.queue(
ConnectTunnelHandler.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT)
def get_events(self) -> Dict[socket.socket, int]:
# Get default client events
ev: Dict[socket.socket, int] = super().get_events()
# Read from server if we are connected
if self.upstream and self.upstream._conn is not None:
ev[self.upstream.connection] = selectors.EVENT_READ
# If there is pending buffer for server
# also register for EVENT_WRITE events
if self.upstream and self.upstream.has_buffer():
if self.upstream.connection in ev:
ev[self.upstream.connection] |= selectors.EVENT_WRITE
else:
ev[self.upstream.connection] = selectors.EVENT_WRITE
return ev
def handle_events(
self,
readables: Readables,
writables: Writables) -> bool:
# Handle client events
do_shutdown: bool = super().handle_events(readables, writables)
if do_shutdown:
return do_shutdown
# Handle server events
if self.upstream and self.upstream.connection in readables:
data = self.upstream.recv()
if data is None:
# Server closed connection
print('Connection closed by server')
return True
# tunnel data to client
self.client.queue(data)
if self.upstream and self.upstream.connection in writables:
self.upstream.flush()
return False
def main() -> None:
# This example requires `threadless=True`
pool = AcceptorPool(
flags=Flags(port=12345, num_workers=1, threadless=True),
work_klass=ConnectTunnelHandler)
try:
pool.setup()
while True:
time.sleep(1)
except KeyboardInterrupt:
pass
finally:
pool.shutdown()
if __name__ == '__main__':
main()

View File

@ -15,10 +15,10 @@ from proxy.core.connection import TcpClientConnection
from proxy.common.flags import Flags
from proxy.common.utils import wrap_socket
from examples.base_echo_server import BaseEchoServerHandler
from examples.base_server import BaseServerHandler
class EchoSSLServerHandler(BaseEchoServerHandler): # type: ignore
class EchoSSLServerHandler(BaseServerHandler): # type: ignore
"""Wraps client socket during initialization."""
def initialize(self) -> None:
@ -34,6 +34,10 @@ class EchoSSLServerHandler(BaseEchoServerHandler): # type: ignore
self.client = TcpClientConnection(
conn=conn, addr=self.client.addr) # type: ignore
def handle_data(self, data: memoryview) -> None:
# echo back to client
self.client.queue(data)
def main() -> None:
# This example requires `threadless=True`
@ -49,6 +53,8 @@ def main() -> None:
pool.setup()
while True:
time.sleep(1)
except KeyboardInterrupt:
pass
finally:
pool.shutdown()

View File

@ -13,15 +13,19 @@ import time
from proxy.core.acceptor import AcceptorPool
from proxy.common.flags import Flags
from examples.base_echo_server import BaseEchoServerHandler
from examples.base_server import BaseServerHandler
class EchoServerHandler(BaseEchoServerHandler): # type: ignore
class EchoServerHandler(BaseServerHandler): # type: ignore
"""Sets client socket to non-blocking during initialization."""
def initialize(self) -> None:
self.client.connection.setblocking(False)
def handle_data(self, data: memoryview) -> None:
# echo back to client
self.client.queue(data)
def main() -> None:
# This example requires `threadless=True`
@ -32,6 +36,8 @@ def main() -> None:
pool.setup()
while True:
time.sleep(1)
except KeyboardInterrupt:
pass
finally:
pool.shutdown()

View File

@ -110,7 +110,7 @@ class HttpParser:
# For CONNECT requests, request line contains
# upstream_host:upstream_port which is not complaint
# with urlsplit, which expects a fully qualified url.
if self.method == b'CONNECT':
if self.method == httpMethods.CONNECT:
url = b'https://' + url
self.url = urlparse.urlsplit(url)
self.set_line_attributes()

View File

@ -408,7 +408,7 @@ class HttpProxyPlugin(HttpProtocolHandlerPlugin):
server_host, server_port = self.server.addr if self.server else (
None, None)
connection_time_ms = (time.time() - self.start_time) * 1000
if self.request.method == b'CONNECT':
if self.request.method == httpMethods.CONNECT:
logger.info(
'%s:%s - %s %s:%s - %s bytes - %.2f ms' %
(self.client.addr[0],

View File

@ -308,7 +308,7 @@ class TestHttpParser(unittest.TestCase):
See https://github.com/abhinavsingh/py/issues/5 for details.
"""
self.parser.parse(b'CONNECT pypi.org:443 HTTP/1.0\r\n\r\n')
self.assertEqual(self.parser.method, b'CONNECT')
self.assertEqual(self.parser.method, httpMethods.CONNECT)
self.assertEqual(self.parser.version, b'HTTP/1.0')
self.assertEqual(self.parser.state, httpParserStates.COMPLETE)