408 lines
16 KiB
Python
408 lines
16 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""
|
|
proxy.py
|
|
~~~~~~~~
|
|
⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
|
|
Network monitoring, controls & Application development, testing, debugging.
|
|
|
|
:copyright: (c) 2013-present by Abhinav Singh and contributors.
|
|
:license: BSD, see LICENSE for more details.
|
|
|
|
.. spelling::
|
|
|
|
http
|
|
"""
|
|
import ssl
|
|
import time
|
|
import errno
|
|
import socket
|
|
import asyncio
|
|
import logging
|
|
import selectors
|
|
|
|
from typing import Tuple, List, Union, Optional, Dict, Any
|
|
|
|
from .plugin import HttpProtocolHandlerPlugin
|
|
from .parser import HttpParser, httpParserStates, httpParserTypes
|
|
from .exception import HttpProtocolException
|
|
|
|
from ..common.types import Readables, Writables
|
|
from ..common.utils import wrap_socket, is_threadless
|
|
from ..core.base import BaseTcpServerHandler
|
|
from ..core.connection import TcpClientConnection
|
|
from ..common.flag import flags
|
|
from ..common.constants import DEFAULT_CLIENT_RECVBUF_SIZE, DEFAULT_KEY_FILE
|
|
from ..common.constants import DEFAULT_SELECTOR_SELECT_TIMEOUT, DEFAULT_TIMEOUT
|
|
|
|
|
|
# Module-level logger; all handler diagnostics are emitted under this
# module's dotted name.
logger = logging.getLogger(__name__)


# Register the CLI flags consumed by HttpProtocolHandler via self.flags.
flags.add_argument(
    '--client-recvbuf-size',
    type=int,
    default=DEFAULT_CLIENT_RECVBUF_SIZE,
    help='Default: 1 MB. Maximum amount of data received from the '
    'client in a single recv() operation. Bump this '
    'value for faster uploads at the expense of '
    'increased RAM.',
)
flags.add_argument(
    '--key-file',
    type=str,
    default=DEFAULT_KEY_FILE,
    help='Default: None. Server key file to enable end-to-end TLS encryption with clients. '
    'If used, must also pass --cert-file.',
)
flags.add_argument(
    '--timeout',
    type=int,
    default=DEFAULT_TIMEOUT,
    help='Default: ' + str(DEFAULT_TIMEOUT) +
    '. Number of seconds after which '
    'an inactive connection must be dropped. Inactivity is defined by no '
    'data sent or received by the client.',
)
|
|
|
|
|
|
class HttpProtocolHandler(BaseTcpServerHandler):
    """HTTP, HTTPS, HTTP2, WebSockets protocol handler.

    Accepts `Client` connection and delegates to HttpProtocolHandlerPlugin.
    """

    def __init__(self, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)
        # Wall-clock time at which this handler was created.
        self.start_time: float = time.time()
        # Timestamp of the most recent client read/write; compared against
        # --timeout by is_inactive().
        self.last_activity: float = self.start_time
        # Parser for the client's (first) HTTP request.
        self.request: HttpParser = HttpParser(
            httpParserTypes.REQUEST_PARSER,
            enable_proxy_protocol=self.flags.enable_proxy_protocol,
        )
        # A selector is only allocated for threaded mode; in threadless mode
        # event multiplexing happens externally via get_events()/handle_events().
        self.selector: Optional[selectors.DefaultSelector] = None
        if not is_threadless(self.flags.threadless, self.flags.threaded):
            self.selector = selectors.DefaultSelector()
        # Plugin name -> instance, populated by initialize().
        self.plugins: Dict[str, HttpProtocolHandlerPlugin] = {}
|
|
|
|
    ##
    # initialize, is_inactive, shutdown, get_events, handle_events
    # overrides Work class definitions.
    ##

    def initialize(self) -> None:
        """Optionally upgrades connection to HTTPS, set ``conn`` in non-blocking mode and initializes plugins."""
        conn = self._optionally_wrap_socket(self.work.connection)
        conn.setblocking(False)
        # Update client connection reference if connection was wrapped
        if self._encryption_enabled():
            self.work = TcpClientConnection(conn=conn, addr=self.work.addr)
        # Instantiate every registered HttpProtocolHandlerPlugin, sharing the
        # client connection, request parser and event queue with each plugin.
        if b'HttpProtocolHandlerPlugin' in self.flags.plugins:
            for klass in self.flags.plugins[b'HttpProtocolHandlerPlugin']:
                instance: HttpProtocolHandlerPlugin = klass(
                    self.uid,
                    self.flags,
                    self.work,
                    self.request,
                    self.event_queue,
                )
                self.plugins[instance.name()] = instance
        logger.debug('Handling connection %r' % self.work.connection)
|
|
|
|
def is_inactive(self) -> bool:
|
|
if not self.work.has_buffer() and \
|
|
self._connection_inactive_for() > self.flags.timeout:
|
|
return True
|
|
return False
|
|
|
|
    def shutdown(self) -> None:
        """Gracefully tear down the client connection.

        Ordering: flush pending data (threaded mode), notify plugins,
        TLS-unwrap if wrapped, half-close the socket, then always close
        and delegate to ``super().shutdown()``.
        """
        try:
            # Flush pending buffer in threaded mode only.
            # For threadless mode, BaseTcpServerHandler implements
            # the must_flush_before_shutdown logic automagically.
            if self.selector and self.work.has_buffer():
                self._flush()
            # Invoke plugin.on_client_connection_close
            for plugin in self.plugins.values():
                plugin.on_client_connection_close()
            logger.debug(
                'Closing client connection %r '
                'at address %s has buffer %s' %
                (self.work.connection, self.work.address, self.work.has_buffer()),
            )
            conn = self.work.connection
            # Unwrap if wrapped before shutdown.
            if self._encryption_enabled() and \
                    isinstance(self.work.connection, ssl.SSLSocket):
                conn = self.work.connection.unwrap()
            # Half-close: we are done sending; the peer may still read.
            conn.shutdown(socket.SHUT_WR)
            logger.debug('Client connection shutdown successful')
        except OSError:
            # Best effort — the socket may already be closed/reset by the peer.
            pass
        finally:
            self.work.connection.close()
            logger.debug('Client connection closed')
            super().shutdown()
|
|
|
|
async def get_events(self) -> Dict[int, int]:
|
|
# Get default client events
|
|
events: Dict[int, int] = await super().get_events()
|
|
# HttpProtocolHandlerPlugin.get_descriptors
|
|
for plugin in self.plugins.values():
|
|
plugin_read_desc, plugin_write_desc = plugin.get_descriptors()
|
|
for rfileno in plugin_read_desc:
|
|
if rfileno not in events:
|
|
events[rfileno] = selectors.EVENT_READ
|
|
else:
|
|
events[rfileno] |= selectors.EVENT_READ
|
|
for wfileno in plugin_write_desc:
|
|
if wfileno not in events:
|
|
events[wfileno] = selectors.EVENT_WRITE
|
|
else:
|
|
events[wfileno] |= selectors.EVENT_WRITE
|
|
return events
|
|
|
|
    # We override super().handle_events and never call it
    async def handle_events(
        self,
        readables: Readables,
        writables: Writables,
    ) -> bool:
        """Returns True if proxy must tear down.

        Invocation order is part of the contract: client writes first, then
        plugin writes, then client reads, then plugin reads.  Each stage can
        short-circuit the cycle by requesting a teardown.
        """
        # Flush buffer for ready to write sockets
        teardown = await self.handle_writables(writables)
        if teardown:
            return True
        # Invoke plugin.write_to_descriptors
        for plugin in self.plugins.values():
            teardown = await plugin.write_to_descriptors(writables)
            if teardown:
                return True
        # Read from ready to read sockets
        teardown = await self.handle_readables(readables)
        if teardown:
            return True
        # Invoke plugin.read_from_descriptors
        for plugin in self.plugins.values():
            teardown = await plugin.read_from_descriptors(readables)
            if teardown:
                return True
        return False
|
|
|
|
    def handle_data(self, data: Optional[memoryview]) -> Optional[bool]:
        """Handles incoming data from client.

        ``data`` is ``None`` when the client closed the connection, hence the
        ``Optional[memoryview]`` annotation (the body explicitly tests for
        ``None``).  Returns ``True`` when the connection must be torn down.
        """
        if data is None:
            logger.debug('Client closed connection, tearing down...')
            self.work.closed = True
            return True
        try:
            # Don't parse incoming data any further after 1st request has completed.
            #
            # This specially does happen for pipeline requests.
            #
            # Plugins can utilize on_client_data for such cases and
            # apply custom logic to handle request data sent after 1st
            # valid request.
            if self.request.state != httpParserStates.COMPLETE:
                # Parse http request
                #
                # TODO(abhinavsingh): Remove .tobytes after parser is
                # memoryview compliant
                self.request.parse(data.tobytes())
                if self.request.state == httpParserStates.COMPLETE:
                    # Invoke plugin.on_request_complete
                    for plugin in self.plugins.values():
                        upgraded_sock = plugin.on_request_complete()
                        if isinstance(upgraded_sock, ssl.SSLSocket):
                            logger.debug(
                                'Updated client conn to %s', upgraded_sock,
                            )
                            # Propagate the TLS-upgraded socket to the client
                            # connection and to every *other* plugin's view of it.
                            self.work._conn = upgraded_sock
                            for plugin_ in self.plugins.values():
                                if plugin_ != plugin:
                                    plugin_.client._conn = upgraded_sock
                        elif isinstance(upgraded_sock, bool) and upgraded_sock is True:
                            # Plugin requested an immediate teardown.
                            return True
            else:
                # HttpProtocolHandlerPlugin.on_client_data
                # Can raise HttpProtocolException to tear down the connection
                for plugin in self.plugins.values():
                    optional_data = plugin.on_client_data(data)
                    if optional_data is None:
                        # Plugin consumed the data; stop the chain.
                        break
                    data = optional_data
        except HttpProtocolException as e:
            # Queue the exception-provided error response (if any) for the
            # client, then signal teardown.
            logger.debug('HttpProtocolException raised')
            response: Optional[memoryview] = e.response(self.request)
            if response:
                self.work.queue(response)
            return True
        return False
|
|
|
|
    async def handle_writables(self, writables: Writables) -> bool:
        """Flush buffered data to the client when its socket is write-ready.

        Returns ``True`` when the connection must be torn down.
        """
        if self.work.connection.fileno() in writables and self.work.has_buffer():
            logger.debug('Client is ready for writes, flushing buffer')
            self.last_activity = time.time()

            # TODO(abhinavsingh): This hook could just reside within server recv block
            # instead of invoking when flushed to client.
            #
            # Invoke plugin.on_response_chunk
            chunk = self.work.buffer
            for plugin in self.plugins.values():
                chunk = plugin.on_response_chunk(chunk)
                if chunk is None:
                    break
            # NOTE(review): the final ``chunk`` is never assigned back to
            # self.work.buffer — presumably plugins mutate the buffer list
            # in place; confirm against plugin implementations.

            try:
                # Call super() for client flush
                teardown = await super().handle_writables(writables)
                if teardown:
                    return True
            except BrokenPipeError:
                logger.error(
                    'BrokenPipeError when flushing buffer for client',
                )
                return True
            except OSError:
                logger.error('OSError when flushing buffer to client')
                return True
        return False
|
|
|
|
    async def handle_readables(self, readables: Readables) -> bool:
        """Read from the client when its socket is read-ready.

        Returns ``True`` when the connection must be torn down.
        """
        if self.work.connection.fileno() in readables:
            logger.debug('Client is ready for reads, reading')
            self.last_activity = time.time()
            try:
                teardown = await super().handle_readables(readables)
                if teardown:
                    return teardown
            except ssl.SSLWantReadError:  # Try again later
                logger.warning(
                    'SSLWantReadError encountered while reading from client, will retry ...',
                )
                return False
            except socket.error as e:
                if e.errno == errno.ECONNRESET:
                    # Most requests for mobile devices will end up
                    # with client closed connection. Using `debug`
                    # here to avoid flooding the logs.
                    logger.debug('%r' % e)
                else:
                    logger.warning(
                        'Exception when receiving from %s connection#%d with reason %r' %
                        (self.work.tag, self.work.connection.fileno(), e),
                    )
                return True
        return False
|
|
|
|
##
|
|
# Internal methods
|
|
##
|
|
|
|
def _encryption_enabled(self) -> bool:
|
|
return self.flags.keyfile is not None and \
|
|
self.flags.certfile is not None
|
|
|
|
    def _optionally_wrap_socket(
        self, conn: socket.socket,
    ) -> Union[ssl.SSLSocket, socket.socket]:
        """Attempts to wrap accepted client connection using provided certificates.

        Shutdown and closes client connection upon error.

        Returns the wrapped ``ssl.SSLSocket`` when TLS interception is
        enabled, otherwise the original socket unchanged.
        """
        if self._encryption_enabled():
            assert self.flags.keyfile and self.flags.certfile
            # TODO(abhinavsingh): Insecure TLS versions must not be accepted by default
            conn = wrap_socket(conn, self.flags.keyfile, self.flags.certfile)
        return conn
|
|
|
|
def _connection_inactive_for(self) -> float:
|
|
return time.time() - self.last_activity
|
|
|
|
    ##
    # run() and _run_once() are here to maintain backward compatibility
    # with threaded mode. These methods are only called when running
    # in threaded mode.
    ##

    def run(self) -> None:
        """run() method is not used when in --threadless mode.

        This is here just to maintain backward compatibility with threaded mode.

        Drives select/handle cycles on a private event loop until the
        connection becomes inactive, a cycle requests teardown, or an
        unhandled error occurs; always shuts down the connection and
        closes the loop on exit.
        """
        loop = asyncio.new_event_loop()
        try:
            self.initialize()
            while True:
                # Tear down if client buffer is empty and connection is inactive
                if self.is_inactive():
                    logger.debug(
                        'Client buffer is empty and maximum inactivity has reached '
                        'between client and server connection, tearing down...',
                    )
                    break
                # _run_once() returns True when the cycle requested teardown.
                if loop.run_until_complete(self._run_once()):
                    break
        except KeyboardInterrupt:  # pragma: no cover
            pass
        except ssl.SSLError as e:
            logger.exception('ssl.SSLError', exc_info=e)
        except Exception as e:
            # Top-level boundary: log anything unexpected; cleanup happens below.
            logger.exception(
                'Exception while handling connection %r' %
                self.work.connection, exc_info=e,
            )
        finally:
            self.shutdown()
            loop.close()
|
|
|
|
    async def _run_once(self) -> bool:
        """Run a single select-and-handle cycle (threaded mode only).

        Returns ``True`` when the connection must be torn down.  Registered
        descriptors are always unregistered before returning.
        """
        events, readables, writables = await self._selected_events()
        try:
            return await self.handle_events(readables, writables)
        finally:
            assert self.selector
            # TODO: Like Threadless we should not unregister
            # work fds repeatedly.
            for fd in events:
                self.selector.unregister(fd)
|
|
|
|
    # FIXME: Returning events is only necessary because we cannot use async context manager
    # for < Python 3.8. As a reason, this method is no longer a context manager and caller
    # is responsible for unregistering the descriptors.
    async def _selected_events(self) -> Tuple[Dict[int, int], Readables, Writables]:
        """Register interesting descriptors, select once, and return
        ``(registered events, read-ready fileobjs, write-ready fileobjs)``.

        The caller (``_run_once``) must unregister every fd in the returned
        events mapping.
        """
        assert self.selector
        events = await self.get_events()
        for fd in events:
            self.selector.register(fd, events[fd])
        ev = self.selector.select(timeout=DEFAULT_SELECTOR_SELECT_TIMEOUT)
        readables = []
        writables = []
        for key, mask in ev:
            if mask & selectors.EVENT_READ:
                readables.append(key.fileobj)
            if mask & selectors.EVENT_WRITE:
                writables.append(key.fileobj)
        return (events, readables, writables)
|
|
|
|
def _flush(self) -> None:
|
|
assert self.selector
|
|
logger.debug('Flushing pending data')
|
|
try:
|
|
self.selector.register(
|
|
self.work.connection,
|
|
selectors.EVENT_WRITE,
|
|
)
|
|
while self.work.has_buffer():
|
|
logging.debug('Waiting for client read ready')
|
|
ev: List[
|
|
Tuple[selectors.SelectorKey, int]
|
|
] = self.selector.select(timeout=DEFAULT_SELECTOR_SELECT_TIMEOUT)
|
|
if len(ev) == 0:
|
|
continue
|
|
self.work.flush()
|
|
except BrokenPipeError:
|
|
pass
|
|
finally:
|
|
self.selector.unregister(self.work.connection)
|