Defer SSL Wrap (#100)
* Always deduce address family as we have a structure * Convert TcpConnection into an ABC. * Fix tests * Rename MultiCoreRequestDispatcher to WorkerPool * Consistent naming refactor * Make WorkerPool independent of protocol config object * Fix tests * Proper fix for test_on_client_connection_called_on_teardown * TLS interception * better logging
This commit is contained in:
parent
e38531b4c3
commit
716f211a2b
30
README.md
30
README.md
|
@ -42,7 +42,7 @@ Table of Contents
|
|||
* [ManInTheMiddlePlugin](#maninthemiddleplugin)
|
||||
* [Plugin Ordering](#plugin-ordering)
|
||||
* [End-to-End Encryption](#end-to-end-encryption)
|
||||
* [TLS Encryption](#tls-interception)
|
||||
* [TLS Interception](#tls-interception)
|
||||
* [Plugin Developer and Contributor Guide](#plugin-developer-and-contributor-guide)
|
||||
* [Everything is a plugin](#everything-is-a-plugin)
|
||||
* [proxy.py Internals](#proxypy-internals)
|
||||
|
@ -230,7 +230,7 @@ Above `418 I'm a tea pot` is sent by our plugin.
|
|||
Verify the same by inspecting logs for `proxy.py`:
|
||||
|
||||
```
|
||||
2019-09-24 19:21:37,893 - ERROR - pid:50074 - handle_readables:1347 - HttpProtocolException type raised
|
||||
2019-09-24 19:21:37,893 - ERROR - pid:50074 - handle_readables:1347 - ProtocolException type raised
|
||||
Traceback (most recent call last):
|
||||
... [redacted] ...
|
||||
2019-09-24 19:21:37,897 - INFO - pid:50074 - access_log:1157 - ::1:49911 - GET None:None/ - None None - 0 bytes
|
||||
|
@ -485,38 +485,38 @@ As you might have guessed by now, in `proxy.py` everything is a plugin.
|
|||
Example, [FilterByUpstreamHostPlugin](#filterbyupstreamhostplugin).
|
||||
|
||||
- We also enabled inbuilt web server using `--enable-web-server`.
|
||||
Inbuilt web server implements `HttpProtocolBasePlugin` plugin.
|
||||
See documentation of [HttpProtocolBasePlugin](https://github.com/abhinavsingh/proxy.py/blob/b03629fa0df1595eb4995427bc601063be7fdca9/proxy.py#L793-L850)
|
||||
for available lifecycle hooks. Use `HttpProtocolBasePlugin` to add
|
||||
Inbuilt web server implements `ProtocolHandlerPlugin` plugin.
|
||||
See documentation of [ProtocolHandlerPlugin](https://github.com/abhinavsingh/proxy.py/blob/b03629fa0df1595eb4995427bc601063be7fdca9/proxy.py#L793-L850)
|
||||
for available lifecycle hooks. Use `ProtocolHandlerPlugin` to add
|
||||
new features for http(s) clients. Example,
|
||||
[HttpWebServerPlugin](https://github.com/abhinavsingh/proxy.py/blob/b03629fa0df1595eb4995427bc601063be7fdca9/proxy.py#L1185-L1260).
|
||||
|
||||
- There also is a `--disable-http-proxy` flag. It disables inbuilt proxy server.
|
||||
Use this flag with `--enable-web-server` flag to run `proxy.py` as a programmable
|
||||
http(s) server. [HttpProxyPlugin](https://github.com/abhinavsingh/proxy.py/blob/b03629fa0df1595eb4995427bc601063be7fdca9/proxy.py#L941-L1182)
|
||||
also implements `HttpProtocolBasePlugin`.
|
||||
also implements `ProtocolHandlerPlugin`.
|
||||
|
||||
## proxy.py Internals
|
||||
|
||||
- [HttpProtocolHandler](https://github.com/abhinavsingh/proxy.py/blob/b03629fa0df1595eb4995427bc601063be7fdca9/proxy.py#L1263-L1440)
|
||||
- [ProtocolHandler](https://github.com/abhinavsingh/proxy.py/blob/b03629fa0df1595eb4995427bc601063be7fdca9/proxy.py#L1263-L1440)
|
||||
thread is started with the accepted [TcpClientConnection](https://github.com/abhinavsingh/proxy.py/blob/b03629fa0df1595eb4995427bc601063be7fdca9/proxy.py#L230-L237).
|
||||
`HttpProtocolHandler` is responsible for parsing incoming client request and invoking
|
||||
`HttpProtocolBasePlugin` lifecycle hooks.
|
||||
`ProtocolHandler` is responsible for parsing incoming client request and invoking
|
||||
`ProtocolHandlerPlugin` lifecycle hooks.
|
||||
|
||||
- `HttpProxyPlugin` which implements `HttpProtocolBasePlugin` also has its own plugin
|
||||
- `HttpProxyPlugin` which implements `ProtocolHandlerPlugin` also has its own plugin
|
||||
mechanism. Its responsibility is to establish connection between client and
|
||||
upstream [TcpServerConnection](https://github.com/abhinavsingh/proxy.py/blob/b03629fa0df1595eb4995427bc601063be7fdca9/proxy.py#L204-L227)
|
||||
and invoke `HttpProxyBasePlugin` lifecycle hooks.
|
||||
|
||||
- `HttpProtocolHandler` threads are started by [Worker](https://github.com/abhinavsingh/proxy.py/blob/b03629fa0df1595eb4995427bc601063be7fdca9/proxy.py#L424-L472)
|
||||
- `ProtocolHandler` threads are started by [Worker](https://github.com/abhinavsingh/proxy.py/blob/b03629fa0df1595eb4995427bc601063be7fdca9/proxy.py#L424-L472)
|
||||
processes.
|
||||
|
||||
- `--num-workers` `Worker` processes are started by
|
||||
[MultiCoreRequestDispatcher](https://github.com/abhinavsingh/proxy.py/blob/b03629fa0df1595eb4995427bc601063be7fdca9/proxy.py#L368-L421)
|
||||
on start-up. `Worker` processes receives `TcpClientConnection` over a pipe from `MultiCoreRequestDispatcher`.
|
||||
[WorkerPool](https://github.com/abhinavsingh/proxy.py/blob/b03629fa0df1595eb4995427bc601063be7fdca9/proxy.py#L368-L421)
|
||||
on start-up. `Worker` processes receives `TcpClientConnection` over a pipe from `WorkerPool`.
|
||||
|
||||
- `MultiCoreRequestDispatcher` implements [TcpServer](https://github.com/abhinavsingh/proxy.py/blob/b03629fa0df1595eb4995427bc601063be7fdca9/proxy.py#L240-L302)
|
||||
abstract class. `TcpServer` accepts `TcpClientConnection`. `MultiCoreRequestDispatcher`
|
||||
- `WorkerPool` implements [TcpServer](https://github.com/abhinavsingh/proxy.py/blob/b03629fa0df1595eb4995427bc601063be7fdca9/proxy.py#L240-L302)
|
||||
abstract class. `TcpServer` accepts `TcpClientConnection`. `WorkerPool`
|
||||
ensures full utilization of available CPU cores, for which it dispatches
|
||||
accepted `TcpClientConnection` to `Worker` processes in a round-robin fashion.
|
||||
|
||||
|
|
|
@ -66,7 +66,7 @@ class CacheResponsesPlugin(proxy.HttpProxyBasePlugin):
|
|||
|
||||
CACHE_DIR = tempfile.gettempdir()
|
||||
|
||||
def __init__(self, config: proxy.HttpProtocolConfig, client: proxy.TcpClientConnection,
|
||||
def __init__(self, config: proxy.ProtocolConfig, client: proxy.TcpClientConnection,
|
||||
request: proxy.HttpParser) -> None:
|
||||
super().__init__(config, client, request)
|
||||
self.cache_file_path: Optional[str] = None
|
||||
|
|
427
proxy.py
427
proxy.py
|
@ -150,26 +150,31 @@ HttpProtocolTypes = NamedTuple('HttpProtocolTypes', [
|
|||
httpProtocolTypes = HttpProtocolTypes(1, 2)
|
||||
|
||||
|
||||
class TcpConnection:
|
||||
class TcpConnectionUninitializedException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class TcpConnection(ABC):
|
||||
"""TCP server/client connection abstraction."""
|
||||
|
||||
def __init__(self, tag: int):
|
||||
self.conn: Optional[Union[ssl.SSLSocket, socket.socket]] = None
|
||||
self.buffer: bytes = b''
|
||||
self.closed: bool = False
|
||||
self.tag: str = 'server' if tag == tcpConnectionTypes.SERVER else 'client'
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def connection(self) -> Union[ssl.SSLSocket, socket.socket]:
|
||||
"""Must return the socket connection to use in this class."""
|
||||
raise TcpConnectionUninitializedException()
|
||||
|
||||
def send(self, data: bytes) -> int:
|
||||
"""Users must handle BrokenPipeError exceptions"""
|
||||
if not self.conn:
|
||||
raise KeyError('conn is None')
|
||||
return self.conn.send(data)
|
||||
return self.connection.send(data)
|
||||
|
||||
def recv(self, buffer_size: int = DEFAULT_BUFFER_SIZE) -> Optional[bytes]:
|
||||
if not self.conn:
|
||||
raise KeyError('conn is None')
|
||||
try:
|
||||
data: bytes = self.conn.recv(buffer_size)
|
||||
data: bytes = self.connection.recv(buffer_size)
|
||||
if len(data) > 0:
|
||||
logger.debug(
|
||||
'received %d bytes from %s' %
|
||||
|
@ -181,14 +186,12 @@ class TcpConnection:
|
|||
else:
|
||||
logger.exception(
|
||||
'Exception while receiving from connection %s %r with reason %r' %
|
||||
(self.tag, self.conn, e))
|
||||
(self.tag, self.connection, e))
|
||||
return None
|
||||
|
||||
def close(self) -> bool:
|
||||
if not self.conn:
|
||||
raise KeyError('conn is None')
|
||||
if not self.closed:
|
||||
self.conn.close()
|
||||
self.connection.close()
|
||||
self.closed = True
|
||||
return self.closed
|
||||
|
||||
|
@ -210,31 +213,37 @@ class TcpConnection:
|
|||
|
||||
|
||||
class TcpServerConnection(TcpConnection):
|
||||
"""Establishes connection to destination server."""
|
||||
"""Establishes connection to upstream server."""
|
||||
|
||||
def __init__(self, host: str, port: int):
|
||||
super().__init__(tcpConnectionTypes.SERVER)
|
||||
self.addr: Tuple[str, int] = (host, int(port))
|
||||
self._conn: Optional[Union[ssl.SSLSocket, socket.socket]] = None
|
||||
|
||||
def __del__(self) -> None:
|
||||
if self.conn:
|
||||
self.close()
|
||||
@property
|
||||
def connection(self) -> Union[ssl.SSLSocket, socket.socket]:
|
||||
if self._conn is None:
|
||||
raise TcpConnectionUninitializedException()
|
||||
return self._conn
|
||||
|
||||
def connect(self) -> None:
|
||||
if self._conn is not None:
|
||||
return
|
||||
|
||||
try:
|
||||
ip = ipaddress.ip_address(text_(self.addr[0]))
|
||||
if ip.version == 4:
|
||||
self.conn = socket.socket(
|
||||
self._conn = socket.socket(
|
||||
socket.AF_INET, socket.SOCK_STREAM, 0)
|
||||
self.conn.connect((self.addr[0], self.addr[1]))
|
||||
self._conn.connect((self.addr[0], self.addr[1]))
|
||||
else:
|
||||
self.conn = socket.socket(
|
||||
self._conn = socket.socket(
|
||||
socket.AF_INET6, socket.SOCK_STREAM, 0)
|
||||
self.conn.connect((self.addr[0], self.addr[1], 0, 0))
|
||||
self._conn.connect((self.addr[0], self.addr[1], 0, 0))
|
||||
except ValueError:
|
||||
# Not a valid IP address, most likely its a domain name,
|
||||
# try to establish dual stack IPv4/IPv6 connection.
|
||||
self.conn = socket.create_connection((self.addr[0], self.addr[1]))
|
||||
self._conn = socket.create_connection((self.addr[0], self.addr[1]))
|
||||
|
||||
|
||||
class TcpClientConnection(TcpConnection):
|
||||
|
@ -243,9 +252,15 @@ class TcpClientConnection(TcpConnection):
|
|||
def __init__(self, conn: Union[ssl.SSLSocket,
|
||||
socket.socket], addr: Tuple[str, int]):
|
||||
super().__init__(tcpConnectionTypes.CLIENT)
|
||||
self.conn: Union[ssl.SSLSocket, socket.socket] = conn
|
||||
self._conn: Optional[Union[ssl.SSLSocket, socket.socket]] = conn
|
||||
self.addr: Tuple[str, int] = addr
|
||||
|
||||
@property
|
||||
def connection(self) -> Union[ssl.SSLSocket, socket.socket]:
|
||||
if self._conn is None:
|
||||
raise TcpConnectionUninitializedException()
|
||||
return self._conn
|
||||
|
||||
|
||||
class TcpServer(ABC):
|
||||
"""TcpServer server implementation.
|
||||
|
@ -259,13 +274,12 @@ class TcpServer(ABC):
|
|||
hostname: Union[ipaddress.IPv4Address,
|
||||
ipaddress.IPv6Address] = DEFAULT_IPV6_HOSTNAME,
|
||||
port: int = DEFAULT_PORT,
|
||||
backlog: int = DEFAULT_BACKLOG,
|
||||
family: socket.AddressFamily = socket.AF_INET6):
|
||||
backlog: int = DEFAULT_BACKLOG):
|
||||
self.port: int = port
|
||||
self.backlog: int = backlog
|
||||
self.socket: Optional[socket.socket] = None
|
||||
self.running: bool = False
|
||||
self.family: socket.AddressFamily = family
|
||||
self.family: socket.AddressFamily = socket.AF_INET6 if hostname.version == 6 else socket.AF_INET
|
||||
self.hostname: Union[ipaddress.IPv4Address,
|
||||
ipaddress.IPv6Address] = hostname
|
||||
|
||||
|
@ -315,100 +329,40 @@ class TcpServer(ABC):
|
|||
self.socket.close()
|
||||
|
||||
|
||||
class HttpProtocolConfig:
|
||||
"""Holds various configuration values applicable to HttpProtocolHandler.
|
||||
|
||||
This config class helps us avoid passing around bunch of key/value pairs across methods.
|
||||
"""
|
||||
|
||||
ROOT_DATA_DIR_NAME = '.proxy.py'
|
||||
GENERATED_CERTS_DIR_NAME = 'certificates'
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
auth_code: Optional[bytes] = DEFAULT_BASIC_AUTH,
|
||||
server_recvbuf_size: int = DEFAULT_SERVER_RECVBUF_SIZE,
|
||||
client_recvbuf_size: int = DEFAULT_CLIENT_RECVBUF_SIZE,
|
||||
pac_file: Optional[str] = DEFAULT_PAC_FILE,
|
||||
pac_file_url_path: Optional[bytes] = DEFAULT_PAC_FILE_URL_PATH,
|
||||
plugins: Optional[Dict[bytes, List[type]]] = None,
|
||||
disable_headers: Optional[List[bytes]] = None,
|
||||
certfile: Optional[str] = None,
|
||||
keyfile: Optional[str] = None,
|
||||
ca_cert_dir: Optional[str] = None,
|
||||
ca_key_file: Optional[str] = None,
|
||||
ca_cert_file: Optional[str] = None,
|
||||
ca_signing_key_file: Optional[str] = None,
|
||||
num_workers: int = 0,
|
||||
hostname: Union[ipaddress.IPv4Address,
|
||||
ipaddress.IPv6Address] = DEFAULT_IPV6_HOSTNAME,
|
||||
port: int = DEFAULT_PORT,
|
||||
backlog: int = DEFAULT_BACKLOG) -> None:
|
||||
self.auth_code = auth_code
|
||||
self.server_recvbuf_size = server_recvbuf_size
|
||||
self.client_recvbuf_size = client_recvbuf_size
|
||||
self.pac_file = pac_file
|
||||
self.pac_file_url_path = pac_file_url_path
|
||||
if plugins is None:
|
||||
plugins = {}
|
||||
self.plugins: Dict[bytes, List[type]] = plugins
|
||||
if disable_headers is None:
|
||||
disable_headers = DEFAULT_DISABLE_HEADERS
|
||||
self.disable_headers = disable_headers
|
||||
self.certfile: Optional[str] = certfile
|
||||
self.keyfile: Optional[str] = keyfile
|
||||
self.ca_key_file: Optional[str] = ca_key_file
|
||||
self.ca_cert_file: Optional[str] = ca_cert_file
|
||||
self.ca_signing_key_file: Optional[str] = ca_signing_key_file
|
||||
self.num_workers: int = num_workers
|
||||
self.hostname: Union[ipaddress.IPv4Address,
|
||||
ipaddress.IPv6Address] = hostname
|
||||
self.port: int = port
|
||||
self.backlog: int = backlog
|
||||
self.family: socket.AddressFamily = socket.AF_INET if hostname.version == 4 else socket.AF_INET6
|
||||
|
||||
self.proxy_py_data_dir = os.path.join(
|
||||
str(pathlib.Path.home()), self.ROOT_DATA_DIR_NAME)
|
||||
os.makedirs(self.proxy_py_data_dir, exist_ok=True)
|
||||
|
||||
self.ca_cert_dir: Optional[str] = ca_cert_dir
|
||||
if self.ca_cert_dir is None:
|
||||
self.ca_cert_dir = os.path.join(
|
||||
self.proxy_py_data_dir, self.GENERATED_CERTS_DIR_NAME)
|
||||
os.makedirs(self.ca_cert_dir, exist_ok=True)
|
||||
|
||||
|
||||
class MultiCoreRequestDispatcher(TcpServer):
|
||||
"""MultiCoreRequestDispatcher.
|
||||
class WorkerPool(TcpServer):
|
||||
"""WorkerPool.
|
||||
|
||||
Pre-spawns worker process to utilize all cores available on the system. Accepted `TcpClientConnection` is
|
||||
dispatched over a queue to workers. One of the worker picks up the work and starts a new thread to handle the
|
||||
client request.
|
||||
"""
|
||||
|
||||
def __init__(self, config: HttpProtocolConfig) -> None:
|
||||
super().__init__(
|
||||
hostname=config.hostname,
|
||||
port=config.port,
|
||||
backlog=config.backlog,
|
||||
family=config.family)
|
||||
def __init__(self, hostname: Union[ipaddress.IPv4Address,
|
||||
ipaddress.IPv6Address],
|
||||
port: int, backlog: int, num_workers: int,
|
||||
work_klass: type, **kwargs: Any) -> None:
|
||||
super().__init__(hostname=hostname, port=port, backlog=backlog)
|
||||
self.num_workers = num_workers
|
||||
|
||||
self.workers: List[Worker] = []
|
||||
self.work_queues: List[Tuple[connection.Connection,
|
||||
connection.Connection]] = []
|
||||
self.current_worker_id = 0
|
||||
self.config: HttpProtocolConfig = config
|
||||
|
||||
self.work_klass = work_klass
|
||||
self.kwargs = kwargs
|
||||
|
||||
def setup(self) -> None:
|
||||
for worker_id in range(self.config.num_workers):
|
||||
for worker_id in range(self.num_workers):
|
||||
work_queue = multiprocessing.Pipe()
|
||||
|
||||
worker = Worker(work_queue[1], self.config)
|
||||
worker = Worker(work_queue[1], self.work_klass, **self.kwargs)
|
||||
worker.daemon = True
|
||||
worker.start()
|
||||
|
||||
self.workers.append(worker)
|
||||
self.work_queues.append(work_queue)
|
||||
logger.info('Started %d workers' % self.config.num_workers)
|
||||
logger.info('Started %d workers' % self.num_workers)
|
||||
|
||||
def handle(self, client: TcpClientConnection) -> None:
|
||||
# Dispatch in round robin fashion
|
||||
|
@ -418,15 +372,15 @@ class MultiCoreRequestDispatcher(TcpServer):
|
|||
self.current_worker_id)
|
||||
# Dispatch non-socket data first, followed by fileno using reduction
|
||||
work_queue[0].send((workerOperations.HTTP_PROTOCOL, client.addr))
|
||||
send_handle(work_queue[0], client.conn.fileno(),
|
||||
send_handle(work_queue[0], client.connection.fileno(),
|
||||
self.workers[self.current_worker_id].pid)
|
||||
# Close parent handler
|
||||
client.close()
|
||||
self.current_worker_id += 1
|
||||
self.current_worker_id %= self.config.num_workers
|
||||
self.current_worker_id %= self.num_workers
|
||||
|
||||
def shutdown(self) -> None:
|
||||
logger.info('Shutting down %d workers' % self.config.num_workers)
|
||||
logger.info('Shutting down %d workers' % self.num_workers)
|
||||
for work_queue in self.work_queues:
|
||||
work_queue[0].send((workerOperations.SHUTDOWN, None))
|
||||
work_queue[0].close()
|
||||
|
@ -444,40 +398,20 @@ class Worker(multiprocessing.Process):
|
|||
def __init__(
|
||||
self,
|
||||
work_queue: connection.Connection,
|
||||
config: HttpProtocolConfig):
|
||||
work_klass: type,
|
||||
**kwargs: Any):
|
||||
super().__init__()
|
||||
self.work_queue: connection.Connection = work_queue
|
||||
self.config: HttpProtocolConfig = config
|
||||
self.work_klass = work_klass
|
||||
self.kwargs = kwargs
|
||||
|
||||
def run_once(self) -> bool:
|
||||
try:
|
||||
op, payload = self.work_queue.recv()
|
||||
if op == workerOperations.HTTP_PROTOCOL:
|
||||
fileno = recv_handle(self.work_queue)
|
||||
conn = socket.fromfd(
|
||||
fileno, family=self.config.family, type=socket.SOCK_STREAM)
|
||||
# TODO(abhinavsingh): Move handshake logic within
|
||||
# HttpProtocolHandler or should this go under TcpServer directly?
|
||||
# Rationale behind deferring ssl wrap is that plugins can custom wrap
|
||||
# sockets if necessary.
|
||||
if self.config.certfile and self.config.keyfile:
|
||||
try:
|
||||
ctx = ssl.create_default_context(
|
||||
ssl.Purpose.CLIENT_AUTH)
|
||||
ctx.options |= ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1
|
||||
ctx.verify_mode = ssl.CERT_NONE
|
||||
ctx.load_cert_chain(
|
||||
certfile=self.config.certfile,
|
||||
keyfile=self.config.keyfile)
|
||||
conn = ctx.wrap_socket(conn, server_side=True)
|
||||
except OSError as e:
|
||||
logger.exception(
|
||||
'OSError encountered while ssl wrapping the client socket', exc_info=e)
|
||||
conn.close()
|
||||
return False
|
||||
proxy = HttpProtocolHandler(
|
||||
TcpClientConnection(conn=conn, addr=payload),
|
||||
config=self.config)
|
||||
proxy = self.work_klass(
|
||||
fileno=fileno, addr=payload, **self.kwargs)
|
||||
proxy.setDaemon(True)
|
||||
proxy.start()
|
||||
elif op == workerOperations.SHUTDOWN:
|
||||
|
@ -701,7 +635,7 @@ class HttpParser:
|
|||
else:
|
||||
self.version = line[0]
|
||||
self.code = line[1]
|
||||
self.reason = b' '.join(line[2:])
|
||||
self.reason = WHITESPACE.join(line[2:])
|
||||
self.set_host_port()
|
||||
|
||||
def process_header(self, raw: bytes) -> None:
|
||||
|
@ -789,7 +723,7 @@ class HttpParser:
|
|||
|
||||
##########################################################################
|
||||
# HttpParser was originally written to parse the incoming raw Http requests.
|
||||
# Since request / response objects passed to HttpProtocolBasePlugin methods
|
||||
# Since request / response objects passed to ProtocolHandlerPlugin methods
|
||||
# are also HttpParser objects, methods below were added to simplify developer API.
|
||||
##########################################################################
|
||||
|
||||
|
@ -798,18 +732,18 @@ class HttpParser:
|
|||
return True if self.host is not None else False
|
||||
|
||||
|
||||
class HttpProtocolException(Exception):
|
||||
"""Top level HttpProtocolException exception class.
|
||||
class ProtocolException(Exception):
|
||||
"""Top level ProtocolException exception class.
|
||||
|
||||
All exceptions raised during execution of Http request lifecycle MUST
|
||||
inherit HttpProtocolException base class. Implement response() method
|
||||
inherit ProtocolException base class. Implement response() method
|
||||
to optionally return custom response to client."""
|
||||
|
||||
def response(self, request: HttpParser) -> Optional[bytes]:
|
||||
pass # pragma: no cover
|
||||
|
||||
|
||||
class HttpRequestRejected(HttpProtocolException):
|
||||
class HttpRequestRejected(ProtocolException):
|
||||
"""Generic exception that can be used to reject the client requests.
|
||||
|
||||
Connections can either be dropped/closed or optionally an
|
||||
|
@ -829,7 +763,7 @@ class HttpRequestRejected(HttpProtocolException):
|
|||
if self.status_code is not None:
|
||||
line = b'HTTP/1.1 ' + bytes_(str(self.status_code))
|
||||
if self.reason:
|
||||
line += b' ' + self.reason
|
||||
line += WHITESPACE + self.reason
|
||||
pkt.append(line)
|
||||
pkt.append(PROXY_AGENT_HEADER)
|
||||
if self.body:
|
||||
|
@ -842,17 +776,79 @@ class HttpRequestRejected(HttpProtocolException):
|
|||
return CRLF.join(pkt) if len(pkt) > 0 else None
|
||||
|
||||
|
||||
class HttpProtocolBasePlugin(ABC):
|
||||
"""Base HttpProtocolHandler Plugin class.
|
||||
class ProtocolConfig:
|
||||
"""Holds various configuration values applicable to ProtocolHandler.
|
||||
|
||||
This config class helps us avoid passing around bunch of key/value pairs across methods.
|
||||
"""
|
||||
|
||||
ROOT_DATA_DIR_NAME = '.proxy.py'
|
||||
GENERATED_CERTS_DIR_NAME = 'certificates'
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
auth_code: Optional[bytes] = DEFAULT_BASIC_AUTH,
|
||||
server_recvbuf_size: int = DEFAULT_SERVER_RECVBUF_SIZE,
|
||||
client_recvbuf_size: int = DEFAULT_CLIENT_RECVBUF_SIZE,
|
||||
pac_file: Optional[str] = DEFAULT_PAC_FILE,
|
||||
pac_file_url_path: Optional[bytes] = DEFAULT_PAC_FILE_URL_PATH,
|
||||
plugins: Optional[Dict[bytes, List[type]]] = None,
|
||||
disable_headers: Optional[List[bytes]] = None,
|
||||
certfile: Optional[str] = None,
|
||||
keyfile: Optional[str] = None,
|
||||
ca_cert_dir: Optional[str] = None,
|
||||
ca_key_file: Optional[str] = None,
|
||||
ca_cert_file: Optional[str] = None,
|
||||
ca_signing_key_file: Optional[str] = None,
|
||||
num_workers: int = 0,
|
||||
hostname: Union[ipaddress.IPv4Address,
|
||||
ipaddress.IPv6Address] = DEFAULT_IPV6_HOSTNAME,
|
||||
port: int = DEFAULT_PORT,
|
||||
backlog: int = DEFAULT_BACKLOG) -> None:
|
||||
self.auth_code = auth_code
|
||||
self.server_recvbuf_size = server_recvbuf_size
|
||||
self.client_recvbuf_size = client_recvbuf_size
|
||||
self.pac_file = pac_file
|
||||
self.pac_file_url_path = pac_file_url_path
|
||||
if plugins is None:
|
||||
plugins = {}
|
||||
self.plugins: Dict[bytes, List[type]] = plugins
|
||||
if disable_headers is None:
|
||||
disable_headers = DEFAULT_DISABLE_HEADERS
|
||||
self.disable_headers = disable_headers
|
||||
self.certfile: Optional[str] = certfile
|
||||
self.keyfile: Optional[str] = keyfile
|
||||
self.ca_key_file: Optional[str] = ca_key_file
|
||||
self.ca_cert_file: Optional[str] = ca_cert_file
|
||||
self.ca_signing_key_file: Optional[str] = ca_signing_key_file
|
||||
self.num_workers: int = num_workers
|
||||
self.hostname: Union[ipaddress.IPv4Address,
|
||||
ipaddress.IPv6Address] = hostname
|
||||
self.port: int = port
|
||||
self.backlog: int = backlog
|
||||
|
||||
self.proxy_py_data_dir = os.path.join(
|
||||
str(pathlib.Path.home()), self.ROOT_DATA_DIR_NAME)
|
||||
os.makedirs(self.proxy_py_data_dir, exist_ok=True)
|
||||
|
||||
self.ca_cert_dir: Optional[str] = ca_cert_dir
|
||||
if self.ca_cert_dir is None:
|
||||
self.ca_cert_dir = os.path.join(
|
||||
self.proxy_py_data_dir, self.GENERATED_CERTS_DIR_NAME)
|
||||
os.makedirs(self.ca_cert_dir, exist_ok=True)
|
||||
|
||||
|
||||
class ProtocolHandlerPlugin(ABC):
|
||||
"""Base ProtocolHandler Plugin class.
|
||||
|
||||
Implement various lifecycle event methods to customize behavior."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: HttpProtocolConfig,
|
||||
config: ProtocolConfig,
|
||||
client: TcpClientConnection,
|
||||
request: HttpParser):
|
||||
self.config: HttpProtocolConfig = config
|
||||
self.config: ProtocolConfig = config
|
||||
self.client: TcpClientConnection = client
|
||||
self.request: HttpParser = request
|
||||
super().__init__()
|
||||
|
@ -902,7 +898,7 @@ class HttpProtocolBasePlugin(ABC):
|
|||
pass # pragma: no cover
|
||||
|
||||
|
||||
class ProxyConnectionFailed(HttpProtocolException):
|
||||
class ProxyConnectionFailed(ProtocolException):
|
||||
"""Exception raised when HttpProxyPlugin is unable to establish connection to upstream server."""
|
||||
|
||||
RESPONSE_PKT = HttpParser.build_response(
|
||||
|
@ -925,7 +921,7 @@ class ProxyConnectionFailed(HttpProtocolException):
|
|||
self.host, self.port, self.reason)
|
||||
|
||||
|
||||
class ProxyAuthenticationFailed(HttpProtocolException):
|
||||
class ProxyAuthenticationFailed(ProtocolException):
|
||||
"""Exception raised when Http Proxy auth is enabled and
|
||||
incoming request doesn't present necessary credentials."""
|
||||
|
||||
|
@ -947,7 +943,7 @@ class HttpProxyBasePlugin(ABC):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
config: HttpProtocolConfig,
|
||||
config: ProtocolConfig,
|
||||
client: TcpClientConnection,
|
||||
request: HttpParser):
|
||||
self.config = config
|
||||
|
@ -987,8 +983,8 @@ class HttpProxyBasePlugin(ABC):
|
|||
pass # pragma: no cover
|
||||
|
||||
|
||||
class HttpProxyPlugin(HttpProtocolBasePlugin):
|
||||
"""HttpProtocolHandler plugin which implements HttpProxy specifications."""
|
||||
class HttpProxyPlugin(ProtocolHandlerPlugin):
|
||||
"""ProtocolHandler plugin which implements HttpProxy specifications."""
|
||||
|
||||
PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT = HttpParser.build_response(
|
||||
200, reason=b'Connection established'
|
||||
|
@ -1000,7 +996,7 @@ class HttpProxyPlugin(HttpProtocolBasePlugin):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
config: HttpProtocolConfig,
|
||||
config: ProtocolConfig,
|
||||
client: TcpClientConnection,
|
||||
request: HttpParser):
|
||||
super().__init__(config, client, request)
|
||||
|
@ -1020,16 +1016,16 @@ class HttpProxyPlugin(HttpProtocolBasePlugin):
|
|||
|
||||
r: List[socket.socket] = []
|
||||
w: List[socket.socket] = []
|
||||
if self.server and not self.server.closed and self.server.conn:
|
||||
r.append(self.server.conn)
|
||||
if self.server and not self.server.closed and self.server.has_buffer() and self.server.conn:
|
||||
w.append(self.server.conn)
|
||||
if self.server and not self.server.closed and self.server.connection:
|
||||
r.append(self.server.connection)
|
||||
if self.server and not self.server.closed and self.server.has_buffer() and self.server.connection:
|
||||
w.append(self.server.connection)
|
||||
return r, w, []
|
||||
|
||||
def flush_to_descriptors(self, w: List[socket.socket]) -> bool:
|
||||
if self.request.has_upstream_server() and \
|
||||
self.server and not self.server.closed and self.server.conn in w:
|
||||
logger.debug('Server is ready for writes, flushing server buffer')
|
||||
self.server and not self.server.closed and self.server.connection in w:
|
||||
logger.debug('Server is write ready, flushing buffer')
|
||||
try:
|
||||
self.server.flush()
|
||||
except BrokenPipeError:
|
||||
|
@ -1040,10 +1036,10 @@ class HttpProxyPlugin(HttpProtocolBasePlugin):
|
|||
|
||||
def read_from_descriptors(self, r: List[socket.socket]) -> bool:
|
||||
if self.request.has_upstream_server(
|
||||
) and self.server and not self.server.closed and self.server.conn in r:
|
||||
) and self.server and not self.server.closed and self.server.connection in r:
|
||||
logger.debug('Server is ready for reads, reading')
|
||||
raw = self.server.recv(self.config.server_recvbuf_size)
|
||||
# self.last_activity = HttpProtocolHandler.now()
|
||||
# self.last_activity = ProtocolHandler.now()
|
||||
if not raw:
|
||||
logger.debug('Server closed connection, tearing down...')
|
||||
return True
|
||||
|
@ -1148,21 +1144,22 @@ class HttpProxyPlugin(HttpProtocolBasePlugin):
|
|||
# connection to the client. Below we handle the scenario
|
||||
# when client is communicating to proxy.py using http.
|
||||
if not (self.config.keyfile and self.config.certfile) and \
|
||||
self.server and isinstance(self.server.conn, socket.socket):
|
||||
self.client.conn = ssl.wrap_socket(self.client.conn,
|
||||
server_side=True,
|
||||
keyfile=self.config.ca_signing_key_file,
|
||||
certfile=generated_cert)
|
||||
self.server and isinstance(self.server.connection, socket.socket):
|
||||
self.client._conn = ssl.wrap_socket(
|
||||
self.client.connection,
|
||||
server_side=True,
|
||||
keyfile=self.config.ca_signing_key_file,
|
||||
certfile=generated_cert)
|
||||
# Wrap our connection to upstream server connection
|
||||
ctx = ssl.create_default_context(
|
||||
ssl.Purpose.SERVER_AUTH)
|
||||
ctx.options |= ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1
|
||||
self.server.conn = ctx.wrap_socket(
|
||||
self.server.conn, server_hostname=text_(
|
||||
self.server._conn = ctx.wrap_socket(
|
||||
self.server.connection, server_hostname=text_(
|
||||
self.request.host))
|
||||
logger.info(
|
||||
'TLS interception using %s', generated_cert)
|
||||
return self.client.conn
|
||||
return self.client.connection
|
||||
# for general http requests, re-build request packet
|
||||
# and queue for the server with appropriate headers
|
||||
elif self.server:
|
||||
|
@ -1232,7 +1229,7 @@ class HttpProxyPlugin(HttpProtocolBasePlugin):
|
|||
raise ProxyConnectionFailed(text_(host), port, repr(e)) from e
|
||||
else:
|
||||
logger.exception('Both host and port must exist')
|
||||
raise HttpProtocolException()
|
||||
raise ProtocolException()
|
||||
|
||||
|
||||
class HttpWebServerRoutePlugin(ABC):
|
||||
|
@ -1240,7 +1237,7 @@ class HttpWebServerRoutePlugin(ABC):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
config: HttpProtocolConfig,
|
||||
config: ProtocolConfig,
|
||||
client: TcpClientConnection):
|
||||
self.config = config
|
||||
self.client = client
|
||||
|
@ -1279,7 +1276,7 @@ class HttpWebServerPacFilePlugin(HttpWebServerRoutePlugin):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
config: HttpProtocolConfig,
|
||||
config: ProtocolConfig,
|
||||
client: TcpClientConnection):
|
||||
super().__init__(config, client)
|
||||
self.pac_file_response: Optional[bytes] = None
|
||||
|
@ -1312,8 +1309,8 @@ class HttpWebServerPacFilePlugin(HttpWebServerRoutePlugin):
|
|||
)
|
||||
|
||||
|
||||
class HttpWebServerPlugin(HttpProtocolBasePlugin):
|
||||
"""HttpProtocolHandler plugin which handles incoming requests to local webserver."""
|
||||
class HttpWebServerPlugin(ProtocolHandlerPlugin):
|
||||
"""ProtocolHandler plugin which handles incoming requests to local webserver."""
|
||||
|
||||
DEFAULT_404_RESPONSE = HttpParser.build_response(
|
||||
404, reason=b'NOT FOUND',
|
||||
|
@ -1323,7 +1320,7 @@ class HttpWebServerPlugin(HttpProtocolBasePlugin):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
config: HttpProtocolConfig,
|
||||
config: ProtocolConfig,
|
||||
client: TcpClientConnection,
|
||||
request: HttpParser):
|
||||
super().__init__(config, client, request)
|
||||
|
@ -1392,28 +1389,62 @@ class HttpWebServerPlugin(HttpProtocolBasePlugin):
|
|||
return [], [], []
|
||||
|
||||
|
||||
class HttpProtocolHandler(threading.Thread):
|
||||
class ProtocolHandler(threading.Thread):
|
||||
"""HTTP, HTTPS, HTTP2, WebSockets protocol handler.
|
||||
|
||||
Accepts `Client` connection object and manages HttpProtocolBasePlugin invocations.
|
||||
Accepts `Client` connection object and manages ProtocolHandlerPlugin invocations.
|
||||
"""
|
||||
|
||||
def __init__(self, client: TcpClientConnection,
|
||||
config: Optional[HttpProtocolConfig] = None):
|
||||
def __init__(self, fileno: int, addr: Tuple[str, int],
|
||||
config: Optional[ProtocolConfig] = None):
|
||||
super().__init__()
|
||||
self.start_time: datetime.datetime = self.now()
|
||||
self.last_activity: datetime.datetime = self.start_time
|
||||
|
||||
self.client: TcpClientConnection = client
|
||||
self.config: HttpProtocolConfig = config if config else HttpProtocolConfig()
|
||||
self.config: ProtocolConfig = config if config else ProtocolConfig()
|
||||
self.request: HttpParser = HttpParser(httpParserTypes.REQUEST_PARSER)
|
||||
|
||||
self.plugins: Dict[str, HttpProtocolBasePlugin] = {}
|
||||
if b'HttpProtocolBasePlugin' in self.config.plugins:
|
||||
for klass in self.config.plugins[b'HttpProtocolBasePlugin']:
|
||||
conn = self.optionally_wrap_socket(self.fromfd(fileno))
|
||||
if conn is None:
|
||||
raise TcpConnectionUninitializedException()
|
||||
self.client: TcpClientConnection = TcpClientConnection(
|
||||
conn=conn,
|
||||
addr=addr)
|
||||
|
||||
self.plugins: Dict[str, ProtocolHandlerPlugin] = {}
|
||||
if b'ProtocolHandlerPlugin' in self.config.plugins:
|
||||
for klass in self.config.plugins[b'ProtocolHandlerPlugin']:
|
||||
instance = klass(self.config, self.client, self.request)
|
||||
self.plugins[instance.name()] = instance
|
||||
|
||||
def fromfd(self, fileno: int) -> socket.socket:
|
||||
return socket.fromfd(
|
||||
fileno, family=socket.AF_INET if self.config.hostname.version == 4 else socket.AF_INET6,
|
||||
type=socket.SOCK_STREAM)
|
||||
|
||||
def optionally_wrap_socket(self, conn: socket.socket) -> Optional[Union[ssl.SSLSocket, socket.socket]]:
|
||||
"""Attempts to wrap accepted client connection using provided certificates.
|
||||
|
||||
Shutdown and closes client connection upon error.
|
||||
"""
|
||||
if self.config.certfile and self.config.keyfile:
|
||||
try:
|
||||
ctx = ssl.create_default_context(
|
||||
ssl.Purpose.CLIENT_AUTH)
|
||||
ctx.options |= ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1
|
||||
ctx.verify_mode = ssl.CERT_NONE
|
||||
ctx.load_cert_chain(
|
||||
certfile=self.config.certfile,
|
||||
keyfile=self.config.keyfile)
|
||||
conn = ctx.wrap_socket(conn, server_side=True)
|
||||
return conn
|
||||
except Exception as e:
|
||||
logger.exception('Error encountered', exc_info=e)
|
||||
conn.shutdown(socket.SHUT_RDWR)
|
||||
conn.close()
|
||||
return None
|
||||
return conn
|
||||
|
||||
@staticmethod
|
||||
def now() -> datetime.datetime:
|
||||
return datetime.datetime.utcnow()
|
||||
|
@ -1426,8 +1457,8 @@ class HttpProtocolHandler(threading.Thread):
|
|||
return self.connection_inactive_for() > 30
|
||||
|
||||
def handle_writables(self, writables: List[socket.socket]) -> bool:
|
||||
if self.client.conn in writables:
|
||||
logger.debug('Client is ready for writes, flushing client buffer')
|
||||
if self.client.connection in writables:
|
||||
logger.debug('Client is write, flushing buffer')
|
||||
try:
|
||||
self.client.flush()
|
||||
except BrokenPipeError:
|
||||
|
@ -1437,7 +1468,7 @@ class HttpProtocolHandler(threading.Thread):
|
|||
return False
|
||||
|
||||
def handle_readables(self, readables: List[socket.socket]) -> bool:
|
||||
if self.client.conn in readables:
|
||||
if self.client.connection in readables:
|
||||
logger.debug('Client is ready for reads, reading')
|
||||
client_data = self.client.recv(self.config.client_recvbuf_size)
|
||||
self.last_activity = self.now()
|
||||
|
@ -1446,7 +1477,7 @@ class HttpProtocolHandler(threading.Thread):
|
|||
self.client.closed = True
|
||||
return True
|
||||
|
||||
# HttpProtocolBasePlugin.on_client_data
|
||||
# ProtocolHandlerPlugin.on_client_data
|
||||
plugin_index = 0
|
||||
plugins = list(self.plugins.values())
|
||||
while plugin_index < len(plugins) and client_data:
|
||||
|
@ -1464,19 +1495,19 @@ class HttpProtocolHandler(threading.Thread):
|
|||
if isinstance(upgraded_sock, ssl.SSLSocket):
|
||||
logger.debug(
|
||||
'Updated client conn to %s', upgraded_sock)
|
||||
self.client.conn = upgraded_sock
|
||||
self.client._conn = upgraded_sock
|
||||
# Update self.client.conn references for all
|
||||
# plugins
|
||||
for plugin_ in self.plugins.values():
|
||||
if plugin_ != plugin:
|
||||
plugin_.client.conn = upgraded_sock
|
||||
plugin_.client._conn = upgraded_sock
|
||||
logger.debug(
|
||||
'Upgraded client conn for plugin %s', str(plugin_))
|
||||
elif isinstance(upgraded_sock, bool) and upgraded_sock:
|
||||
return True
|
||||
except HttpProtocolException as e:
|
||||
except ProtocolException as e:
|
||||
logger.exception(
|
||||
'HttpProtocolException type raised', exc_info=e)
|
||||
'ProtocolException type raised', exc_info=e)
|
||||
response = e.response(self.request)
|
||||
if response:
|
||||
self.client.queue(response)
|
||||
|
@ -1490,13 +1521,13 @@ class HttpProtocolHandler(threading.Thread):
|
|||
def run_once(self) -> bool:
|
||||
"""Returns True if proxy must teardown."""
|
||||
# Prepare list of descriptors
|
||||
read_desc: List[socket.socket] = [self.client.conn]
|
||||
read_desc: List[socket.socket] = [self.client.connection]
|
||||
write_desc: List[socket.socket] = []
|
||||
err_desc: List[socket.socket] = []
|
||||
if self.client.has_buffer():
|
||||
write_desc.append(self.client.conn)
|
||||
write_desc.append(self.client.connection)
|
||||
|
||||
# HttpProtocolBasePlugin.get_descriptors
|
||||
# ProtocolHandlerPlugin.get_descriptors
|
||||
for plugin in self.plugins.values():
|
||||
plugin_read_desc, plugin_write_desc, plugin_err_desc = plugin.get_descriptors()
|
||||
read_desc += plugin_read_desc
|
||||
|
@ -1539,7 +1570,7 @@ class HttpProtocolHandler(threading.Thread):
|
|||
return False
|
||||
|
||||
def run(self) -> None:
|
||||
logger.debug('Proxying connection %r' % self.client.conn)
|
||||
logger.debug('Proxying connection %r' % self.client.connection)
|
||||
try:
|
||||
while True:
|
||||
teardown = self.run_once()
|
||||
|
@ -1550,7 +1581,7 @@ class HttpProtocolHandler(threading.Thread):
|
|||
except Exception as e:
|
||||
logger.exception(
|
||||
'Exception while handling connection %r with reason %r' %
|
||||
(self.client.conn, e))
|
||||
(self.client.connection, e))
|
||||
finally:
|
||||
# Invoke plugin.access_log
|
||||
for plugin in self.plugins.values():
|
||||
|
@ -1558,7 +1589,7 @@ class HttpProtocolHandler(threading.Thread):
|
|||
|
||||
if not self.client.closed:
|
||||
try:
|
||||
self.client.conn.shutdown(socket.SHUT_RDWR)
|
||||
self.client.connection.shutdown(socket.SHUT_RDWR)
|
||||
self.client.close()
|
||||
except OSError:
|
||||
pass
|
||||
|
@ -1570,7 +1601,7 @@ class HttpProtocolHandler(threading.Thread):
|
|||
logger.debug(
|
||||
'Closed proxy for connection %r '
|
||||
'at address %r with pending client buffer size %d bytes' %
|
||||
(self.client.conn, self.client.addr, self.client.buffer_size()))
|
||||
(self.client.connection, self.client.addr, self.client.buffer_size()))
|
||||
|
||||
|
||||
def is_py3() -> bool:
|
||||
|
@ -1595,7 +1626,7 @@ def load_plugins(plugins: bytes) -> Dict[bytes, List[type]]:
|
|||
"""Accepts a comma separated list of Python modules and returns
|
||||
a list of respective Python classes."""
|
||||
p: Dict[bytes, List[type]] = {
|
||||
b'HttpProtocolBasePlugin': [],
|
||||
b'ProtocolHandlerPlugin': [],
|
||||
b'HttpProxyBasePlugin': [],
|
||||
b'HttpWebServerRoutePlugin': [],
|
||||
}
|
||||
|
@ -1820,7 +1851,7 @@ def main(input_args: List[str]) -> None:
|
|||
if args.basic_auth:
|
||||
auth_code = b'Basic %s' % base64.b64encode(bytes_(args.basic_auth))
|
||||
|
||||
config = HttpProtocolConfig(
|
||||
config = ProtocolConfig(
|
||||
auth_code=auth_code,
|
||||
server_recvbuf_size=args.server_recvbuf_size,
|
||||
client_recvbuf_size=args.client_recvbuf_size,
|
||||
|
@ -1853,7 +1884,13 @@ def main(input_args: List[str]) -> None:
|
|||
'%s%s' %
|
||||
(default_plugins, args.plugins)))
|
||||
|
||||
server = MultiCoreRequestDispatcher(config=config)
|
||||
server = WorkerPool(
|
||||
hostname=config.hostname,
|
||||
port=config.port,
|
||||
backlog=config.backlog,
|
||||
num_workers=config.num_workers,
|
||||
work_klass=ProtocolHandler,
|
||||
config=config)
|
||||
if args.pid_file:
|
||||
with open(args.pid_file, 'wb') as pid_file:
|
||||
pid_file.write(bytes_(str(os.getpid())))
|
||||
|
|
506
tests.py
506
tests.py
|
@ -20,7 +20,7 @@ import unittest
|
|||
from contextlib import closing
|
||||
from http.server import HTTPServer, BaseHTTPRequestHandler
|
||||
from threading import Thread
|
||||
from typing import Dict, Optional, Tuple
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
from unittest import mock
|
||||
|
||||
import proxy
|
||||
|
@ -46,21 +46,32 @@ def get_available_port() -> int:
|
|||
|
||||
class TestTcpConnection(unittest.TestCase):
|
||||
|
||||
class TcpConnectionToTest(proxy.TcpConnection):
|
||||
|
||||
def __init__(self, conn: Optional[Union[ssl.SSLSocket, socket.socket]] = None,
|
||||
tag: int = proxy.tcpConnectionTypes.CLIENT) -> None:
|
||||
super().__init__(tag)
|
||||
self._conn = conn
|
||||
|
||||
@property
|
||||
def connection(self) -> Union[ssl.SSLSocket, socket.socket]:
|
||||
if self._conn is None:
|
||||
raise proxy.TcpConnectionUninitializedException()
|
||||
return self._conn
|
||||
|
||||
def testThrowsKeyErrorIfNoConn(self) -> None:
|
||||
self.conn = proxy.TcpConnection(proxy.TcpConnectionTypes.CLIENT)
|
||||
self.conn.conn = None
|
||||
with self.assertRaises(KeyError):
|
||||
self.conn = TestTcpConnection.TcpConnectionToTest()
|
||||
with self.assertRaises(proxy.TcpConnectionUninitializedException):
|
||||
self.conn.send(b'dummy')
|
||||
with self.assertRaises(KeyError):
|
||||
with self.assertRaises(proxy.TcpConnectionUninitializedException):
|
||||
self.conn.recv()
|
||||
with self.assertRaises(KeyError):
|
||||
with self.assertRaises(proxy.TcpConnectionUninitializedException):
|
||||
self.conn.close()
|
||||
|
||||
def testHandlesIOError(self) -> None:
|
||||
self.conn = proxy.TcpConnection(proxy.TcpConnectionTypes.CLIENT)
|
||||
_conn = mock.MagicMock()
|
||||
_conn.recv.side_effect = IOError()
|
||||
self.conn.conn = _conn
|
||||
self.conn = TestTcpConnection.TcpConnectionToTest(_conn)
|
||||
with mock.patch('proxy.logger') as mock_logger:
|
||||
self.conn.recv()
|
||||
mock_logger.exception.assert_called()
|
||||
|
@ -68,12 +79,11 @@ class TestTcpConnection(unittest.TestCase):
|
|||
'Exception while receiving from connection'))
|
||||
|
||||
def testHandlesConnReset(self) -> None:
|
||||
self.conn = proxy.TcpConnection(proxy.TcpConnectionTypes.CLIENT)
|
||||
_conn = mock.MagicMock()
|
||||
e = IOError()
|
||||
e.errno = errno.ECONNRESET
|
||||
_conn.recv.side_effect = e
|
||||
self.conn.conn = _conn
|
||||
self.conn = TestTcpConnection.TcpConnectionToTest(_conn)
|
||||
with mock.patch('proxy.logger') as mock_logger:
|
||||
self.conn.recv()
|
||||
mock_logger.exception.assert_not_called()
|
||||
|
@ -81,58 +91,46 @@ class TestTcpConnection(unittest.TestCase):
|
|||
self.assertEqual(mock_logger.debug.call_args[0][0], '%r' % e)
|
||||
|
||||
def testClosesIfNotClosed(self) -> None:
|
||||
self.conn = proxy.TcpConnection(proxy.TcpConnectionTypes.CLIENT)
|
||||
_conn = mock.MagicMock()
|
||||
self.conn.conn = _conn
|
||||
self.conn = TestTcpConnection.TcpConnectionToTest(_conn)
|
||||
self.conn.close()
|
||||
_conn.close.assert_called()
|
||||
self.assertTrue(self.conn.closed)
|
||||
|
||||
def testNoOpIfAlreadyClosed(self) -> None:
|
||||
self.conn = proxy.TcpConnection(proxy.TcpConnectionTypes.CLIENT)
|
||||
_conn = mock.MagicMock()
|
||||
self.conn.conn = _conn
|
||||
self.conn = TestTcpConnection.TcpConnectionToTest(_conn)
|
||||
self.conn.closed = True
|
||||
self.conn.close()
|
||||
_conn.close.assert_not_called()
|
||||
self.assertTrue(self.conn.closed)
|
||||
|
||||
@mock.patch('socket.socket')
|
||||
def testTcpServerClosesConnOnGC(self, mock_socket: mock.Mock) -> None:
|
||||
conn = mock.MagicMock()
|
||||
mock_socket.return_value = conn
|
||||
self.conn = proxy.TcpServerConnection(
|
||||
str(proxy.DEFAULT_IPV4_HOSTNAME), proxy.DEFAULT_PORT)
|
||||
self.conn.connect()
|
||||
del self.conn
|
||||
conn.close.assert_called()
|
||||
|
||||
@mock.patch('socket.socket')
|
||||
def testTcpServerEstablishesIPv6Connection(self, mock_socket: mock.Mock) -> None:
|
||||
self.conn = proxy.TcpServerConnection(
|
||||
conn = proxy.TcpServerConnection(
|
||||
str(proxy.DEFAULT_IPV6_HOSTNAME), proxy.DEFAULT_PORT)
|
||||
self.conn.connect()
|
||||
conn.connect()
|
||||
mock_socket.assert_called()
|
||||
mock_socket.return_value.connect.assert_called_with(
|
||||
(str(proxy.DEFAULT_IPV6_HOSTNAME), proxy.DEFAULT_PORT, 0, 0))
|
||||
|
||||
@mock.patch('socket.socket')
|
||||
def testTcpServerEstablishesIPv4Connection(self, mock_socket: mock.Mock) -> None:
|
||||
self.conn = proxy.TcpServerConnection(
|
||||
conn = proxy.TcpServerConnection(
|
||||
str(proxy.DEFAULT_IPV4_HOSTNAME), proxy.DEFAULT_PORT)
|
||||
self.conn.connect()
|
||||
conn.connect()
|
||||
mock_socket.assert_called()
|
||||
mock_socket.return_value.connect.assert_called_with(
|
||||
(str(proxy.DEFAULT_IPV4_HOSTNAME), proxy.DEFAULT_PORT))
|
||||
|
||||
|
||||
class BasicTcpServer(proxy.TcpServer):
|
||||
class TcpServerUnderTest(proxy.TcpServer):
|
||||
|
||||
def handle(self, client: proxy.TcpClientConnection) -> None:
|
||||
data = client.recv(proxy.DEFAULT_BUFFER_SIZE)
|
||||
if data != b'HELLO':
|
||||
raise ValueError('Expected HELLO')
|
||||
client.conn.sendall(b'WORLD')
|
||||
client.connection.sendall(b'WORLD')
|
||||
client.close()
|
||||
|
||||
def setup(self) -> None:
|
||||
|
@ -148,8 +146,8 @@ class TestTcpServerIntegration(unittest.TestCase):
|
|||
|
||||
ipv4_port: Optional[int] = None
|
||||
ipv6_port: Optional[int] = None
|
||||
ipv4_server: Optional[BasicTcpServer] = None
|
||||
ipv6_server: Optional[BasicTcpServer] = None
|
||||
ipv4_server: Optional[TcpServerUnderTest] = None
|
||||
ipv6_server: Optional[TcpServerUnderTest] = None
|
||||
ipv4_thread: Optional[Thread] = None
|
||||
ipv6_thread: Optional[Thread] = None
|
||||
|
||||
|
@ -157,14 +155,12 @@ class TestTcpServerIntegration(unittest.TestCase):
|
|||
def setUpClass(cls) -> None:
|
||||
cls.ipv4_port = get_available_port()
|
||||
cls.ipv6_port = get_available_port()
|
||||
cls.ipv4_server = BasicTcpServer(
|
||||
cls.ipv4_server = TcpServerUnderTest(
|
||||
hostname=proxy.DEFAULT_IPV4_HOSTNAME,
|
||||
port=cls.ipv4_port,
|
||||
family=socket.AF_INET)
|
||||
cls.ipv6_server = BasicTcpServer(
|
||||
port=cls.ipv4_port)
|
||||
cls.ipv6_server = TcpServerUnderTest(
|
||||
hostname=proxy.DEFAULT_IPV6_HOSTNAME,
|
||||
port=cls.ipv6_port,
|
||||
family=socket.AF_INET6)
|
||||
port=cls.ipv6_port)
|
||||
cls.ipv4_thread = Thread(target=cls.ipv4_server.run)
|
||||
cls.ipv6_thread = Thread(target=cls.ipv6_server.run)
|
||||
cls.ipv4_thread.setDaemon(True)
|
||||
|
@ -221,7 +217,7 @@ class TestTcpServer(unittest.TestCase):
|
|||
def testAcceptSSLErrorsSilentlyIgnored(self, mock_socket: mock.Mock, mock_select: mock.Mock) -> None:
|
||||
mock_socket.accept.side_effect = ssl.SSLError()
|
||||
mock_select.return_value = ([mock_socket], [], [])
|
||||
server = BasicTcpServer(hostname=proxy.DEFAULT_IPV6_HOSTNAME, port=1234)
|
||||
server = TcpServerUnderTest(hostname=proxy.DEFAULT_IPV6_HOSTNAME, port=1234)
|
||||
server.socket = mock_socket
|
||||
with mock.patch('proxy.logger') as mock_logger:
|
||||
server.run_once()
|
||||
|
@ -229,71 +225,6 @@ class TestTcpServer(unittest.TestCase):
|
|||
self.assertTrue(mock_logger.exception.call_args[0][0], 'SSLError encountered')
|
||||
|
||||
|
||||
class MockHttpProxy:
|
||||
|
||||
def __init__(self, client: proxy.TcpClientConnection, **kwargs) -> None: # type: ignore
|
||||
self.client = client
|
||||
self.kwargs = kwargs
|
||||
|
||||
def setDaemon(self, _val: bool) -> None:
|
||||
pass
|
||||
|
||||
def start(self) -> None:
|
||||
self.client.conn.sendall(proxy.CRLF.join(
|
||||
[b'HTTP/1.1 200 OK', proxy.CRLF]))
|
||||
self.client.conn.shutdown(socket.SHUT_RDWR)
|
||||
self.client.conn.close()
|
||||
|
||||
|
||||
def mock_tcp_proxy_side_effect(client: proxy.TcpClientConnection, **kwargs) -> MockHttpProxy: # type: ignore
|
||||
return MockHttpProxy(client, **kwargs)
|
||||
|
||||
|
||||
@unittest.skipIf(os.getenv('TESTING_ON_TRAVIS', 0),
|
||||
'Opening sockets not allowed on Travis')
|
||||
@unittest.skipIf(os.getenv('GITHUB_WORKFLOW', 0),
|
||||
'This test fails on GitHub Windows environment')
|
||||
class TestMultiCoreRequestDispatcherIntegration(unittest.TestCase):
|
||||
tcp_port = None
|
||||
tcp_server = None
|
||||
tcp_thread = None
|
||||
|
||||
@mock.patch.object(
|
||||
proxy, # type: ignore
|
||||
'HttpProtocolHandler',
|
||||
side_effect=mock_tcp_proxy_side_effect)
|
||||
def testHttpProxyConnection(self, _mock_tcp_proxy):
|
||||
try:
|
||||
self.tcp_port = get_available_port()
|
||||
self.tcp_server = proxy.MultiCoreRequestDispatcher(
|
||||
proxy.HttpProtocolConfig(
|
||||
hostname=proxy.DEFAULT_IPV4_HOSTNAME,
|
||||
port=self.tcp_port,
|
||||
num_workers=1))
|
||||
self.tcp_thread = Thread(target=self.tcp_server.run)
|
||||
self.tcp_thread.setDaemon(True)
|
||||
self.tcp_thread.start()
|
||||
|
||||
while True:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) as sock:
|
||||
try:
|
||||
sock.connect(
|
||||
(str(proxy.DEFAULT_IPV4_HOSTNAME), self.tcp_port))
|
||||
sock.send(proxy.HttpParser.build_request(
|
||||
b'GET', b'http://httpbin.org/get', b'HTTP/1.1',
|
||||
headers={b'Host': b'httpbin.org'})
|
||||
)
|
||||
data = sock.recv(proxy.DEFAULT_BUFFER_SIZE)
|
||||
self.assertEqual(data, proxy.CRLF.join(
|
||||
[b'HTTP/1.1 200 OK', proxy.CRLF]))
|
||||
break
|
||||
except ConnectionRefusedError:
|
||||
time.sleep(0.1)
|
||||
finally:
|
||||
self.tcp_server.stop()
|
||||
self.tcp_thread.join()
|
||||
|
||||
|
||||
class TestChunkParser(unittest.TestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
|
@ -770,61 +701,31 @@ class TestHttpParser(unittest.TestCase):
|
|||
self.assertTrue(k in dictionary)
|
||||
|
||||
|
||||
# TODO: Replace MockTcpSocket with mock.MagicMock instances.
|
||||
class MockTcpSocket(socket.socket):
|
||||
|
||||
def __init__(self, buf: bytes = b'') -> None:
|
||||
self.buffer = buf
|
||||
self.received = b''
|
||||
self.closed = False
|
||||
super().__init__()
|
||||
|
||||
def recv(self, b: int = 8192, flags: Optional[int] = None) -> bytes:
|
||||
data = self.buffer[:b]
|
||||
self.buffer = self.buffer[b:]
|
||||
return data
|
||||
|
||||
def send(self, data: bytes, flags: Optional[int] = None) -> int:
|
||||
self.received += data
|
||||
return len(data)
|
||||
|
||||
def queue(self, data: bytes) -> None:
|
||||
self.buffer += data
|
||||
|
||||
def close(self) -> None:
|
||||
self.closed = True
|
||||
super().close()
|
||||
|
||||
def shutdown(self, _how: int) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class HTTPRequestHandler(BaseHTTPRequestHandler):
|
||||
|
||||
def do_GET(self) -> None:
|
||||
self.send_response(200)
|
||||
# TODO(abhinavsingh): Proxy should work just fine even without
|
||||
# content-length header
|
||||
self.send_header('content-length', '2')
|
||||
self.end_headers()
|
||||
self.wfile.write(b'OK')
|
||||
|
||||
|
||||
class TestHttpProtocolHandler(unittest.TestCase):
|
||||
http_server = None
|
||||
http_server_port = None
|
||||
http_server_thread = None
|
||||
config = None
|
||||
|
||||
class HTTPRequestHandler(BaseHTTPRequestHandler):
|
||||
|
||||
def do_GET(self) -> None:
|
||||
self.send_response(200)
|
||||
# TODO(abhinavsingh): Proxy should work just fine even without
|
||||
# content-length header
|
||||
self.send_header('content-length', '2')
|
||||
self.end_headers()
|
||||
self.wfile.write(b'OK')
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
cls.http_server_port = get_available_port()
|
||||
cls.http_server = HTTPServer(
|
||||
('127.0.0.1', cls.http_server_port), HTTPRequestHandler)
|
||||
('127.0.0.1', cls.http_server_port), TestHttpProtocolHandler.HTTPRequestHandler)
|
||||
cls.http_server_thread = Thread(target=cls.http_server.serve_forever)
|
||||
cls.http_server_thread.setDaemon(True)
|
||||
cls.http_server_thread.start()
|
||||
cls.config = proxy.HttpProtocolConfig()
|
||||
cls.config = proxy.ProtocolConfig()
|
||||
cls.config.plugins = proxy.load_plugins(
|
||||
b'proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin')
|
||||
|
||||
|
@ -836,15 +737,12 @@ class TestHttpProtocolHandler(unittest.TestCase):
|
|||
if cls.http_server_thread:
|
||||
cls.http_server_thread.join()
|
||||
|
||||
def setUp(self) -> None:
|
||||
self._conn = MockTcpSocket()
|
||||
@mock.patch('socket.fromfd')
|
||||
def setUp(self, mock_fromfd: mock.Mock) -> None:
|
||||
self.fileno = 10
|
||||
self._addr = ('127.0.0.1', 54382)
|
||||
self.proxy = proxy.HttpProtocolHandler(
|
||||
proxy.TcpClientConnection(
|
||||
self._conn, self._addr), config=self.config)
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self._conn.close()
|
||||
self._conn = mock_fromfd.return_value
|
||||
self.proxy = proxy.ProtocolHandler(self.fileno, self._addr, config=self.config)
|
||||
|
||||
@mock.patch('select.select')
|
||||
@mock.patch('proxy.TcpServerConnection')
|
||||
|
@ -852,13 +750,14 @@ class TestHttpProtocolHandler(unittest.TestCase):
|
|||
server = mock_server_connection.return_value
|
||||
server.connect.return_value = True
|
||||
mock_select.side_effect = [
|
||||
([self._conn], [], []), ([self._conn], [], []), ([], [server.conn], [])]
|
||||
([self._conn], [], []),
|
||||
([self._conn], [], []),
|
||||
([], [server.connection], [])]
|
||||
|
||||
# Send request line
|
||||
assert self.http_server_port is not None
|
||||
self.proxy.client.conn.queue( # type: ignore
|
||||
(b'GET http://localhost:%d HTTP/1.1' %
|
||||
self.http_server_port) + proxy.CRLF)
|
||||
self._conn.recv.return_value = (b'GET http://localhost:%d HTTP/1.1' %
|
||||
self.http_server_port) + proxy.CRLF
|
||||
self.proxy.run_once()
|
||||
self.assertEqual(
|
||||
self.proxy.request.state,
|
||||
|
@ -869,20 +768,20 @@ class TestHttpProtocolHandler(unittest.TestCase):
|
|||
|
||||
# Send headers and blank line, thus completing HTTP request
|
||||
assert self.http_server_port is not None
|
||||
self.proxy.client.conn.queue(proxy.CRLF.join([ # type: ignore
|
||||
self._conn.recv.return_value = proxy.CRLF.join([
|
||||
b'User-Agent: proxy.py/%s' % proxy.version,
|
||||
b'Host: localhost:%d' % self.http_server_port,
|
||||
b'Accept: */*',
|
||||
b'Proxy-Connection: Keep-Alive',
|
||||
proxy.CRLF
|
||||
]))
|
||||
])
|
||||
self.assert_data_queued(mock_server_connection, server)
|
||||
self.proxy.run_once()
|
||||
server.flush.assert_called_once()
|
||||
|
||||
def assert_tunnel_response(self, mock_server_connection: mock.Mock, server: mock.Mock) -> None:
|
||||
self.proxy.run_once()
|
||||
self.assertFalse(self.proxy.plugins['HttpProxyPlugin'].server is None) # type: ignore
|
||||
self.assertTrue(self.proxy.plugins['HttpProxyPlugin'].server is not None) # type: ignore
|
||||
self.assertEqual(
|
||||
self.proxy.client.buffer,
|
||||
proxy.HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT)
|
||||
|
@ -903,17 +802,21 @@ class TestHttpProtocolHandler(unittest.TestCase):
|
|||
server = mock_server_connection.return_value
|
||||
server.connect.return_value = True
|
||||
server.has_buffer.side_effect = [False, False, False, True]
|
||||
mock_select.side_effect = [([self._conn], [], []), ([], [self._conn], []),
|
||||
([self._conn], [], []), ([], [server.conn], [])]
|
||||
mock_select.side_effect = [
|
||||
([self._conn], [], []), # client read ready
|
||||
([], [self._conn], []), # client write ready
|
||||
([self._conn], [], []), # client read ready
|
||||
([], [server.connection], []) # server write ready
|
||||
]
|
||||
|
||||
assert self.http_server_port is not None
|
||||
self.proxy.client.conn.queue(proxy.CRLF.join([ # type: ignore
|
||||
self._conn.recv.return_value = proxy.CRLF.join([
|
||||
b'CONNECT localhost:%d HTTP/1.1' % self.http_server_port,
|
||||
b'Host: localhost:%d' % self.http_server_port,
|
||||
b'User-Agent: proxy.py/%s' % proxy.version,
|
||||
b'Proxy-Connection: Keep-Alive',
|
||||
proxy.CRLF
|
||||
]))
|
||||
])
|
||||
self.assert_tunnel_response(mock_server_connection, server)
|
||||
|
||||
# Dispatch tunnel established response to client
|
||||
|
@ -927,105 +830,111 @@ class TestHttpProtocolHandler(unittest.TestCase):
|
|||
@mock.patch('select.select')
|
||||
def test_proxy_connection_failed(self, mock_select: mock.Mock) -> None:
|
||||
mock_select.return_value = ([self._conn], [], [])
|
||||
self.proxy.client.conn.queue(proxy.CRLF.join([ # type: ignore
|
||||
self._conn.recv.return_value = proxy.CRLF.join([
|
||||
b'GET http://unknown.domain HTTP/1.1',
|
||||
b'Host: unknown.domain',
|
||||
proxy.CRLF
|
||||
]))
|
||||
])
|
||||
self.proxy.run_once()
|
||||
self.assertEqual(
|
||||
self.proxy.client.conn.received, # type: ignore
|
||||
proxy.ProxyConnectionFailed.RESPONSE_PKT)
|
||||
received = self._conn.send.call_args[0][0]
|
||||
self.assertEqual(received, proxy.ProxyConnectionFailed.RESPONSE_PKT)
|
||||
|
||||
@mock.patch('socket.fromfd')
|
||||
@mock.patch('select.select')
|
||||
def test_proxy_authentication_failed(self, mock_select: mock.Mock) -> None:
|
||||
def test_proxy_authentication_failed(self, mock_select: mock.Mock, mock_fromfd: mock.Mock) -> None:
|
||||
self._conn = mock_fromfd.return_value
|
||||
mock_select.return_value = ([self._conn], [], [])
|
||||
config = proxy.HttpProtocolConfig(
|
||||
config = proxy.ProtocolConfig(
|
||||
auth_code=b'Basic %s' %
|
||||
base64.b64encode(b'user:pass'))
|
||||
config.plugins = proxy.load_plugins(
|
||||
b'proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin')
|
||||
self.proxy = proxy.HttpProtocolHandler(
|
||||
proxy.TcpClientConnection(
|
||||
self._conn, self._addr), config=config)
|
||||
self.proxy.client.conn.queue(proxy.CRLF.join([ # type: ignore
|
||||
self.proxy = proxy.ProtocolHandler(self.fileno, self._addr, config=config)
|
||||
self._conn.recv.return_value = proxy.CRLF.join([
|
||||
b'GET http://abhinavsingh.com HTTP/1.1',
|
||||
b'Host: abhinavsingh.com',
|
||||
proxy.CRLF
|
||||
]))
|
||||
])
|
||||
self.proxy.run_once()
|
||||
self.assertEqual(
|
||||
self.proxy.client.conn.received, # type: ignore
|
||||
self._conn.send.call_args[0][0],
|
||||
proxy.ProxyAuthenticationFailed.RESPONSE_PKT)
|
||||
|
||||
@mock.patch('socket.fromfd')
|
||||
@mock.patch('select.select')
|
||||
@mock.patch('proxy.TcpServerConnection')
|
||||
def test_authenticated_proxy_http_get(
|
||||
self, mock_server_connection: mock.Mock, mock_select: mock.Mock) -> None:
|
||||
self, mock_server_connection: mock.Mock,
|
||||
mock_select: mock.Mock,
|
||||
mock_fromfd: mock.Mock) -> None:
|
||||
self._conn = mock_fromfd.return_value
|
||||
mock_select.return_value = ([self._conn], [], [])
|
||||
server = mock_server_connection.return_value
|
||||
server.connect.return_value = True
|
||||
|
||||
client = proxy.TcpClientConnection(self._conn, self._addr)
|
||||
config = proxy.HttpProtocolConfig(
|
||||
config = proxy.ProtocolConfig(
|
||||
auth_code=b'Basic %s' %
|
||||
base64.b64encode(b'user:pass'))
|
||||
config.plugins = proxy.load_plugins(
|
||||
b'proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin')
|
||||
|
||||
self.proxy = proxy.HttpProtocolHandler(client, config=config)
|
||||
self.proxy = proxy.ProtocolHandler(self.fileno, addr=self._addr, config=config)
|
||||
assert self.http_server_port is not None
|
||||
self.proxy.client.conn.queue( # type: ignore
|
||||
b'GET http://localhost:%d HTTP/1.1' %
|
||||
self.http_server_port)
|
||||
|
||||
self._conn.recv.return_value = b'GET http://localhost:%d HTTP/1.1' % self.http_server_port
|
||||
self.proxy.run_once()
|
||||
self.assertEqual(
|
||||
self.proxy.request.state,
|
||||
proxy.httpParserStates.INITIALIZED)
|
||||
|
||||
self.proxy.client.conn.queue(proxy.CRLF) # type: ignore
|
||||
self._conn.recv.return_value = proxy.CRLF
|
||||
self.proxy.run_once()
|
||||
self.assertEqual(
|
||||
self.proxy.request.state,
|
||||
proxy.httpParserStates.LINE_RCVD)
|
||||
|
||||
assert self.http_server_port is not None
|
||||
self.proxy.client.conn.queue(proxy.CRLF.join([ # type: ignore
|
||||
self._conn.recv.return_value = proxy.CRLF.join([
|
||||
b'User-Agent: proxy.py/%s' % proxy.version,
|
||||
b'Host: localhost:%d' % self.http_server_port,
|
||||
b'Accept: */*',
|
||||
b'Proxy-Connection: Keep-Alive',
|
||||
b'Proxy-Authorization: Basic dXNlcjpwYXNz',
|
||||
proxy.CRLF
|
||||
]))
|
||||
])
|
||||
self.assert_data_queued(mock_server_connection, server)
|
||||
|
||||
@mock.patch('socket.fromfd')
|
||||
@mock.patch('select.select')
|
||||
@mock.patch('proxy.TcpServerConnection')
|
||||
def test_authenticated_proxy_http_tunnel(
|
||||
self, mock_server_connection: mock.Mock, mock_select: mock.Mock) -> None:
|
||||
self, mock_server_connection: mock.Mock,
|
||||
mock_select: mock.Mock,
|
||||
mock_fromfd: mock.Mock) -> None:
|
||||
server = mock_server_connection.return_value
|
||||
server.connect.return_value = True
|
||||
mock_select.side_effect = [
|
||||
([self._conn], [], []), ([self._conn], [], []), ([], [server.conn], [])]
|
||||
|
||||
config = proxy.HttpProtocolConfig(
|
||||
self._conn = mock_fromfd.return_value
|
||||
mock_select.side_effect = [
|
||||
([self._conn], [], []), ([self._conn], [], []), ([], [server.connection], [])]
|
||||
|
||||
config = proxy.ProtocolConfig(
|
||||
auth_code=b'Basic %s' %
|
||||
base64.b64encode(b'user:pass'))
|
||||
config.plugins = proxy.load_plugins(
|
||||
b'proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin')
|
||||
self.proxy = proxy.HttpProtocolHandler(
|
||||
proxy.TcpClientConnection(
|
||||
self._conn, self._addr), config=config)
|
||||
|
||||
self.proxy = proxy.ProtocolHandler(self.fileno, self._addr, config=config)
|
||||
|
||||
assert self.http_server_port is not None
|
||||
self.proxy.client.conn.queue(proxy.CRLF.join([ # type: ignore
|
||||
self._conn.recv.return_value = proxy.CRLF.join([
|
||||
b'CONNECT localhost:%d HTTP/1.1' % self.http_server_port,
|
||||
b'Host: localhost:%d' % self.http_server_port,
|
||||
b'User-Agent: proxy.py/%s' % proxy.version,
|
||||
b'Proxy-Connection: Keep-Alive',
|
||||
b'Proxy-Authorization: Basic dXNlcjpwYXNz',
|
||||
proxy.CRLF
|
||||
]))
|
||||
])
|
||||
self.assert_tunnel_response(mock_server_connection, server)
|
||||
self.proxy.client.flush()
|
||||
self.assert_data_queued_to_server(server)
|
||||
|
@ -1033,73 +942,72 @@ class TestHttpProtocolHandler(unittest.TestCase):
|
|||
self.proxy.run_once()
|
||||
server.flush.assert_called_once()
|
||||
|
||||
@mock.patch('socket.fromfd')
|
||||
@mock.patch('select.select')
|
||||
def test_pac_file_served_from_disk(self, mock_select: mock.Mock) -> None:
|
||||
def test_pac_file_served_from_disk(self, mock_select: mock.Mock, mock_fromfd: mock.Mock) -> None:
|
||||
pac_file = 'proxy.pac'
|
||||
self._conn = mock_fromfd.return_value
|
||||
mock_select.return_value = [self._conn], [], []
|
||||
config = proxy.HttpProtocolConfig(pac_file='proxy.pac')
|
||||
self.init_and_make_pac_file_request(config)
|
||||
self.init_and_make_pac_file_request(pac_file)
|
||||
self.proxy.run_once()
|
||||
self.assertEqual(
|
||||
self.proxy.request.state,
|
||||
proxy.httpParserStates.COMPLETE)
|
||||
with open('proxy.pac', 'rb') as pac_file:
|
||||
self.assertEqual(
|
||||
self._conn.received,
|
||||
proxy.HttpParser.build_response(
|
||||
200, reason=b'OK', headers={
|
||||
b'Content-Type': b'application/x-ns-proxy-autoconfig',
|
||||
b'Connection': b'close'
|
||||
}, body=pac_file.read()
|
||||
))
|
||||
|
||||
@mock.patch('select.select')
|
||||
def test_pac_file_served_from_buffer(self, mock_select: mock.Mock) -> None:
|
||||
pac_file_content = b'function FindProxyForURL(url, host) { return "PROXY localhost:8899; DIRECT"; }'
|
||||
mock_select.return_value = [self._conn], [], []
|
||||
config = proxy.HttpProtocolConfig(pac_file=proxy.text_(pac_file_content))
|
||||
self.init_and_make_pac_file_request(config)
|
||||
self.proxy.run_once()
|
||||
self.assertEqual(
|
||||
self.proxy.request.state,
|
||||
proxy.httpParserStates.COMPLETE)
|
||||
self.assertEqual(
|
||||
self._conn.received,
|
||||
proxy.HttpParser.build_response(
|
||||
with open('proxy.pac', 'rb') as f:
|
||||
self._conn.send.called_once_with(proxy.HttpParser.build_response(
|
||||
200, reason=b'OK', headers={
|
||||
b'Content-Type': b'application/x-ns-proxy-autoconfig',
|
||||
b'Connection': b'close'
|
||||
}, body=pac_file_content
|
||||
}, body=f.read()
|
||||
))
|
||||
|
||||
@mock.patch('socket.fromfd')
|
||||
@mock.patch('select.select')
|
||||
def test_default_web_server_returns_404(self, mock_select: mock.Mock) -> None:
|
||||
def test_pac_file_served_from_buffer(self, mock_select: mock.Mock, mock_fromfd: mock.Mock) -> None:
|
||||
self._conn = mock_fromfd.return_value
|
||||
pac_file_content = b'function FindProxyForURL(url, host) { return "PROXY localhost:8899; DIRECT"; }'
|
||||
mock_select.return_value = [self._conn], [], []
|
||||
config = proxy.HttpProtocolConfig()
|
||||
self.init_and_make_pac_file_request(proxy.text_(pac_file_content))
|
||||
self.proxy.run_once()
|
||||
self.assertEqual(
|
||||
self.proxy.request.state,
|
||||
proxy.httpParserStates.COMPLETE)
|
||||
self._conn.send.called_once_with(proxy.HttpParser.build_response(
|
||||
200, reason=b'OK', headers={
|
||||
b'Content-Type': b'application/x-ns-proxy-autoconfig',
|
||||
b'Connection': b'close'
|
||||
}, body=pac_file_content
|
||||
))
|
||||
|
||||
@mock.patch('socket.fromfd')
|
||||
@mock.patch('select.select')
|
||||
def test_default_web_server_returns_404(self, mock_select: mock.Mock, mock_fromfd: mock.Mock) -> None:
|
||||
self._conn = mock_fromfd.return_value
|
||||
mock_select.return_value = [self._conn], [], []
|
||||
config = proxy.ProtocolConfig()
|
||||
config.plugins = proxy.load_plugins(
|
||||
b'proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin')
|
||||
self.proxy = proxy.HttpProtocolHandler(
|
||||
proxy.TcpClientConnection(
|
||||
self._conn, self._addr), config=config)
|
||||
self.proxy.client.conn.queue(proxy.CRLF.join([ # type: ignore
|
||||
self.proxy = proxy.ProtocolHandler(self.fileno, self._addr, config=config)
|
||||
self._conn.recv.return_value = proxy.CRLF.join([
|
||||
b'GET /hello HTTP/1.1',
|
||||
proxy.CRLF,
|
||||
proxy.CRLF
|
||||
]))
|
||||
])
|
||||
self.proxy.run_once()
|
||||
self.assertEqual(
|
||||
self.proxy.request.state,
|
||||
proxy.httpParserStates.COMPLETE)
|
||||
self.assertEqual(
|
||||
self._conn.received,
|
||||
self._conn.send.call_args[0][0],
|
||||
proxy.HttpWebServerPlugin.DEFAULT_404_RESPONSE)
|
||||
|
||||
def test_on_client_connection_called_on_teardown(self) -> None:
|
||||
config = proxy.HttpProtocolConfig()
|
||||
@mock.patch('socket.fromfd')
|
||||
def test_on_client_connection_called_on_teardown(self, mock_fromfd: mock.Mock) -> None:
|
||||
config = proxy.ProtocolConfig()
|
||||
plugin = mock.MagicMock()
|
||||
config.plugins = {b'HttpProtocolBasePlugin': [plugin]}
|
||||
self.proxy = proxy.HttpProtocolHandler(
|
||||
proxy.TcpClientConnection(
|
||||
self._conn, self._addr), config=config)
|
||||
config.plugins = {b'ProtocolHandlerPlugin': [plugin]}
|
||||
self._conn = mock_fromfd.return_value
|
||||
self.proxy = proxy.ProtocolHandler(self.fileno, self._addr, config=config)
|
||||
plugin.assert_called()
|
||||
with mock.patch.object(self.proxy, 'run_once') as mock_run_once:
|
||||
mock_run_once.return_value = True
|
||||
|
@ -1108,17 +1016,15 @@ class TestHttpProtocolHandler(unittest.TestCase):
|
|||
plugin.return_value.access_log.assert_called()
|
||||
plugin.return_value.on_client_connection_close.assert_called()
|
||||
|
||||
def init_and_make_pac_file_request(self, config: proxy.HttpProtocolConfig) -> None:
|
||||
def init_and_make_pac_file_request(self, pac_file: str) -> None:
|
||||
config = proxy.ProtocolConfig(pac_file=pac_file)
|
||||
config.plugins = proxy.load_plugins(
|
||||
b'proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin,proxy.HttpWebServerPacFilePlugin')
|
||||
self.proxy = proxy.HttpProtocolHandler(
|
||||
proxy.TcpClientConnection(
|
||||
self._conn, self._addr), config=config)
|
||||
self.proxy.client.conn.queue(proxy.CRLF.join([ # type: ignore
|
||||
self.proxy = proxy.ProtocolHandler(self.fileno, self._addr, config=config)
|
||||
self._conn.recv.return_value = proxy.CRLF.join([
|
||||
b'GET / HTTP/1.1',
|
||||
proxy.CRLF,
|
||||
proxy.CRLF
|
||||
]))
|
||||
])
|
||||
|
||||
def assert_data_queued(self, mock_server_connection: mock.Mock, server: mock.Mock) -> None:
|
||||
self.proxy.run_once()
|
||||
|
@ -1140,14 +1046,14 @@ class TestHttpProtocolHandler(unittest.TestCase):
|
|||
|
||||
def assert_data_queued_to_server(self, server: mock.Mock) -> None:
|
||||
assert self.http_server_port is not None
|
||||
self.assertEqual(self.proxy.client.buffer_size(), 0)
|
||||
self.assertEqual(self._conn.send.call_args[0][0], proxy.HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT)
|
||||
|
||||
self.proxy.client.conn.queue(proxy.CRLF.join([ # type: ignore
|
||||
self._conn.recv.return_value = proxy.CRLF.join([
|
||||
b'GET / HTTP/1.1',
|
||||
b'Host: localhost:%d' % self.http_server_port,
|
||||
b'User-Agent: proxy.py/%s' % proxy.version,
|
||||
proxy.CRLF
|
||||
]))
|
||||
])
|
||||
self.proxy.run_once()
|
||||
server.queue.assert_called_once_with(proxy.CRLF.join([
|
||||
b'GET / HTTP/1.1',
|
||||
|
@ -1160,33 +1066,26 @@ class TestHttpProtocolHandler(unittest.TestCase):
|
|||
|
||||
class TestWorker(unittest.TestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
@mock.patch('proxy.ProtocolHandler')
|
||||
def setUp(self, mock_protocol_handler: mock.Mock) -> None:
|
||||
self.pipe = multiprocessing.Pipe()
|
||||
self.worker = proxy.Worker(self.pipe[1], proxy.HttpProtocolConfig())
|
||||
self.worker = proxy.Worker(self.pipe[1], mock_protocol_handler, config=proxy.ProtocolConfig())
|
||||
self.mock_protocol_handler = mock_protocol_handler
|
||||
|
||||
@mock.patch('proxy.HttpProtocolHandler')
|
||||
@mock.patch('proxy.ProtocolHandler')
|
||||
def test_shutdown_op(self, mock_http_proxy: mock.Mock) -> None:
|
||||
self.pipe[0].send((proxy.workerOperations.SHUTDOWN, None))
|
||||
self.worker.run()
|
||||
self.assertFalse(mock_http_proxy.called)
|
||||
|
||||
@mock.patch('socket.fromfd')
|
||||
@mock.patch('proxy.recv_handle')
|
||||
@mock.patch('proxy.HttpProtocolHandler')
|
||||
def test_spawns_http_proxy_threads(
|
||||
self,
|
||||
mock_http_proxy: mock.Mock,
|
||||
mock_recv_handle: mock.Mock,
|
||||
mock_fromfd: mock.Mock) -> None:
|
||||
def test_spawns_http_proxy_threads(self, mock_recv_handle: mock.Mock) -> None:
|
||||
fileno = 10
|
||||
mock_recv_handle.return_value = fileno
|
||||
self.pipe[0].send((proxy.workerOperations.HTTP_PROTOCOL, None))
|
||||
self.pipe[0].send((proxy.workerOperations.SHUTDOWN, None))
|
||||
self.worker.run()
|
||||
self.assertTrue(mock_fromfd.called)
|
||||
mock_fromfd.assert_called_with(
|
||||
fileno, family=socket.AF_INET6, type=socket.SOCK_STREAM)
|
||||
self.assertTrue(mock_http_proxy.called)
|
||||
self.assertTrue(self.mock_protocol_handler.called)
|
||||
|
||||
def test_handles_work_queue_recv_connection_refused(self) -> None:
|
||||
with mock.patch.object(self.worker.work_queue, 'recv') as mock_recv:
|
||||
|
@ -1194,61 +1093,6 @@ class TestWorker(unittest.TestCase):
|
|||
self.assertFalse(self.worker.run_once()) # doesn't teardown
|
||||
|
||||
|
||||
class TestWorkerSSLWrap(unittest.TestCase):
|
||||
|
||||
CERTFILE = 'my-https-cert.pem'
|
||||
KEYFILE = 'my-https-key.pem'
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.pipe = multiprocessing.Pipe()
|
||||
self.worker = proxy.Worker(self.pipe[1],
|
||||
proxy.HttpProtocolConfig(certfile=self.CERTFILE, keyfile=self.KEYFILE))
|
||||
|
||||
@mock.patch('ssl.create_default_context')
|
||||
@mock.patch('socket.fromfd')
|
||||
@mock.patch('proxy.recv_handle')
|
||||
@mock.patch('proxy.HttpProtocolHandler')
|
||||
def test_worker_performs_ssl_wrap(
|
||||
self,
|
||||
mock_http_proxy: mock.Mock,
|
||||
mock_recv_handle: mock.Mock,
|
||||
mock_fromfd: mock.Mock,
|
||||
mock_context: mock.Mock) -> None:
|
||||
fileno = 10
|
||||
mock_recv_handle.return_value = fileno
|
||||
self.pipe[0].send((proxy.workerOperations.HTTP_PROTOCOL, None))
|
||||
self.pipe[0].send((proxy.workerOperations.SHUTDOWN, None))
|
||||
self.worker.run()
|
||||
mock_fromfd.assert_called()
|
||||
mock_context.assert_called_with(ssl.Purpose.CLIENT_AUTH)
|
||||
mock_context.return_value.load_cert_chain.assert_called_with(certfile=self.CERTFILE, keyfile=self.KEYFILE)
|
||||
mock_http_proxy.assert_called()
|
||||
|
||||
@mock.patch('ssl.create_default_context')
|
||||
@mock.patch('socket.fromfd')
|
||||
@mock.patch('proxy.recv_handle')
|
||||
@mock.patch('proxy.HttpProtocolHandler')
|
||||
def test_client_conn_closed_on_os_error(
|
||||
self,
|
||||
mock_http_proxy: mock.Mock,
|
||||
mock_recv_handle: mock.Mock,
|
||||
mock_fromfd: mock.Mock,
|
||||
mock_context: mock.Mock) -> None:
|
||||
fileno = 10
|
||||
mock_recv_handle.return_value = fileno
|
||||
mock_context.return_value.wrap_socket.side_effect = OSError()
|
||||
self.pipe[0].send((proxy.workerOperations.HTTP_PROTOCOL, None))
|
||||
self.pipe[0].send((proxy.workerOperations.SHUTDOWN, None))
|
||||
with mock.patch('proxy.logger') as mock_logger:
|
||||
self.worker.run()
|
||||
mock_logger.exception.assert_called()
|
||||
self.assertEqual(mock_logger.exception.call_args[0][0],
|
||||
'OSError encountered while ssl wrapping the client socket')
|
||||
mock_fromfd.assert_called()
|
||||
mock_fromfd.return_value.close.assert_called()
|
||||
mock_http_proxy.assert_not_called()
|
||||
|
||||
|
||||
class TestHttpRequestRejected(unittest.TestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
|
@ -1282,7 +1126,7 @@ class TestHttpRequestRejected(unittest.TestCase):
|
|||
class TestMain(unittest.TestCase):
|
||||
|
||||
@mock.patch('proxy.set_open_file_limit')
|
||||
@mock.patch('proxy.MultiCoreRequestDispatcher')
|
||||
@mock.patch('proxy.WorkerPool')
|
||||
@mock.patch('proxy.logging.basicConfig')
|
||||
def test_log_file_setup(
|
||||
self,
|
||||
|
@ -1305,7 +1149,7 @@ class TestMain(unittest.TestCase):
|
|||
@mock.patch('os.path.exists')
|
||||
@mock.patch('builtins.open')
|
||||
@mock.patch('proxy.set_open_file_limit')
|
||||
@mock.patch('proxy.MultiCoreRequestDispatcher')
|
||||
@mock.patch('proxy.WorkerPool')
|
||||
@unittest.skipIf(
|
||||
True, # type: ignore
|
||||
'This test passes while development on Intellij but fails via CLI :(')
|
||||
|
@ -1327,9 +1171,9 @@ class TestMain(unittest.TestCase):
|
|||
mock_exists.assert_called_with(pid_file)
|
||||
mock_remove.assert_called_with(pid_file)
|
||||
|
||||
@mock.patch('proxy.HttpProtocolConfig')
|
||||
@mock.patch('proxy.ProtocolConfig')
|
||||
@mock.patch('proxy.set_open_file_limit')
|
||||
@mock.patch('proxy.MultiCoreRequestDispatcher')
|
||||
@mock.patch('proxy.WorkerPool')
|
||||
def test_main(
|
||||
self,
|
||||
mock_multicore_dispatcher: mock.Mock,
|
||||
|
@ -1337,8 +1181,14 @@ class TestMain(unittest.TestCase):
|
|||
mock_config: mock.Mock) -> None:
|
||||
proxy.main(['--basic-auth', 'user:pass'])
|
||||
self.assertTrue(mock_set_open_file_limit.called)
|
||||
config = mock_config.return_value
|
||||
mock_multicore_dispatcher.assert_called_with(
|
||||
config=mock_config.return_value)
|
||||
hostname=config.hostname,
|
||||
port=config.port,
|
||||
backlog=config.backlog,
|
||||
num_workers=config.num_workers,
|
||||
work_klass=proxy.ProtocolHandler,
|
||||
config=config)
|
||||
mock_config.assert_called_with(
|
||||
auth_code=b'Basic dXNlcjpwYXNz',
|
||||
client_recvbuf_size=proxy.DEFAULT_CLIENT_RECVBUF_SIZE,
|
||||
|
@ -1359,9 +1209,9 @@ class TestMain(unittest.TestCase):
|
|||
)
|
||||
|
||||
@mock.patch('builtins.print')
|
||||
@mock.patch('proxy.HttpProtocolConfig')
|
||||
@mock.patch('proxy.ProtocolConfig')
|
||||
@mock.patch('proxy.set_open_file_limit')
|
||||
@mock.patch('proxy.MultiCoreRequestDispatcher')
|
||||
@mock.patch('proxy.WorkerPool')
|
||||
def test_main_version(
|
||||
self,
|
||||
mock_multicore_dispatcher: mock.Mock,
|
||||
|
@ -1376,9 +1226,9 @@ class TestMain(unittest.TestCase):
|
|||
mock_config.assert_not_called()
|
||||
|
||||
@mock.patch('builtins.print')
|
||||
@mock.patch('proxy.HttpProtocolConfig')
|
||||
@mock.patch('proxy.ProtocolConfig')
|
||||
@mock.patch('proxy.set_open_file_limit')
|
||||
@mock.patch('proxy.MultiCoreRequestDispatcher')
|
||||
@mock.patch('proxy.WorkerPool')
|
||||
@mock.patch('proxy.is_py3')
|
||||
def test_main_py3_runs(
|
||||
self,
|
||||
|
@ -1396,9 +1246,9 @@ class TestMain(unittest.TestCase):
|
|||
mock_config.assert_called()
|
||||
|
||||
@mock.patch('builtins.print')
|
||||
@mock.patch('proxy.HttpProtocolConfig')
|
||||
@mock.patch('proxy.ProtocolConfig')
|
||||
@mock.patch('proxy.set_open_file_limit')
|
||||
@mock.patch('proxy.MultiCoreRequestDispatcher')
|
||||
@mock.patch('proxy.WorkerPool')
|
||||
@mock.patch('proxy.is_py3')
|
||||
@unittest.skipIf(
|
||||
True, # type: ignore
|
||||
|
|
Loading…
Reference in New Issue