443 lines
17 KiB
Python
443 lines
17 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""
|
|
proxy.py
|
|
~~~~~~~~
|
|
⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
|
|
Network monitoring, controls & Application development, testing, debugging.
|
|
|
|
:copyright: (c) 2013-present by Abhinav Singh and contributors.
|
|
:license: BSD, see LICENSE for more details.
|
|
"""
|
|
import ssl
|
|
import time
|
|
import errno
|
|
import socket
|
|
import asyncio
|
|
import logging
|
|
import selectors
|
|
|
|
from typing import Tuple, List, Type, Union, Optional, Any
|
|
|
|
from ..common.flag import flags
|
|
from ..common.utils import wrap_socket
|
|
from ..core.base import BaseTcpServerHandler
|
|
from ..core.connection import TcpClientConnection
|
|
from ..common.types import Readables, SelectableEvents, Writables
|
|
from ..common.constants import DEFAULT_CLIENT_RECVBUF_SIZE, DEFAULT_KEY_FILE
|
|
from ..common.constants import DEFAULT_SELECTOR_SELECT_TIMEOUT, DEFAULT_TIMEOUT
|
|
|
|
from .exception import HttpProtocolException
|
|
from .plugin import HttpProtocolHandlerPlugin
|
|
from .responses import BAD_REQUEST_RESPONSE_PKT
|
|
from .parser import HttpParser, httpParserStates, httpParserTypes
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
flags.add_argument(
|
|
'--client-recvbuf-size',
|
|
type=int,
|
|
default=DEFAULT_CLIENT_RECVBUF_SIZE,
|
|
help='Default: ' + str(int(DEFAULT_CLIENT_RECVBUF_SIZE / 1024)) +
|
|
' KB. Maximum amount of data received from the '
|
|
'client in a single recv() operation.',
|
|
)
|
|
flags.add_argument(
|
|
'--key-file',
|
|
type=str,
|
|
default=DEFAULT_KEY_FILE,
|
|
help='Default: None. Server key file to enable end-to-end TLS encryption with clients. '
|
|
'If used, must also pass --cert-file.',
|
|
)
|
|
flags.add_argument(
|
|
'--timeout',
|
|
type=int,
|
|
default=DEFAULT_TIMEOUT,
|
|
help='Default: ' + str(DEFAULT_TIMEOUT) +
|
|
'. Number of seconds after which '
|
|
'an inactive connection must be dropped. Inactivity is defined by no '
|
|
'data sent or received by the client.',
|
|
)
|
|
|
|
|
|
class HttpProtocolHandler(BaseTcpServerHandler):
|
|
"""HTTP, HTTPS, HTTP2, WebSockets protocol handler.
|
|
|
|
Accepts `Client` connection and delegates to HttpProtocolHandlerPlugin.
|
|
"""
|
|
|
|
def __init__(self, *args: Any, **kwargs: Any):
|
|
super().__init__(*args, **kwargs)
|
|
self.start_time: float = time.time()
|
|
self.last_activity: float = self.start_time
|
|
self.request: HttpParser = HttpParser(
|
|
httpParserTypes.REQUEST_PARSER,
|
|
enable_proxy_protocol=self.flags.enable_proxy_protocol,
|
|
)
|
|
self.selector: Optional[selectors.DefaultSelector] = None
|
|
if not self.flags.threadless:
|
|
self.selector = selectors.DefaultSelector()
|
|
self.plugin: Optional[HttpProtocolHandlerPlugin] = None
|
|
|
|
##
|
|
# initialize, is_inactive, shutdown, get_events, handle_events
|
|
# overrides Work class definitions.
|
|
##
|
|
|
|
def initialize(self) -> None:
|
|
"""Optionally upgrades connection to HTTPS,
|
|
sets ``conn`` in non-blocking mode and initializes
|
|
HTTP protocol plugins.
|
|
"""
|
|
conn = self._optionally_wrap_socket(self.work.connection)
|
|
conn.setblocking(False)
|
|
# Update client connection reference if connection was wrapped
|
|
if self._encryption_enabled():
|
|
self.work = TcpClientConnection(conn=conn, addr=self.work.addr)
|
|
# self._initialize_plugins()
|
|
logger.debug('Handling connection %s' % self.work.address)
|
|
|
|
def is_inactive(self) -> bool:
|
|
if not self.work.has_buffer() and \
|
|
self._connection_inactive_for() > self.flags.timeout:
|
|
return True
|
|
return False
|
|
|
|
def shutdown(self) -> None:
|
|
try:
|
|
# Flush pending buffer in threaded mode only.
|
|
#
|
|
# For threadless mode, BaseTcpServerHandler implements
|
|
# the must_flush_before_shutdown logic automagically.
|
|
if self.selector and self.work.has_buffer():
|
|
self._flush()
|
|
# Invoke plugin.on_client_connection_close
|
|
if self.plugin:
|
|
self.plugin.on_client_connection_close()
|
|
logger.debug(
|
|
'Closing client connection %s has buffer %s' %
|
|
(self.work.address, self.work.has_buffer()),
|
|
)
|
|
conn = self.work.connection
|
|
# Unwrap if wrapped before shutdown.
|
|
if self._encryption_enabled() and \
|
|
isinstance(self.work.connection, ssl.SSLSocket):
|
|
conn = self.work.connection.unwrap()
|
|
conn.shutdown(socket.SHUT_WR)
|
|
logger.debug('Client connection shutdown successful')
|
|
except OSError:
|
|
pass
|
|
finally:
|
|
# Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending readable data
|
|
# could lead to an immediate reset being sent.
|
|
#
|
|
# "A host MAY implement a 'half-duplex' TCP close sequence, so that an application
|
|
# that has called CLOSE cannot continue to read data from the connection.
|
|
# If such a host issues a CLOSE call while received data is still pending in TCP,
|
|
# or if new data is received after CLOSE is called, its TCP SHOULD send a RST to
|
|
# show that data was lost."
|
|
#
|
|
self.work.connection.close()
|
|
logger.debug('Client connection closed')
|
|
super().shutdown()
|
|
|
|
async def get_events(self) -> SelectableEvents:
|
|
# Get default client events
|
|
events: SelectableEvents = await super().get_events()
|
|
# HttpProtocolHandlerPlugin.get_descriptors
|
|
if self.plugin:
|
|
plugin_read_desc, plugin_write_desc = await self.plugin.get_descriptors()
|
|
for rfileno in plugin_read_desc:
|
|
if rfileno not in events:
|
|
events[rfileno] = selectors.EVENT_READ
|
|
else:
|
|
events[rfileno] |= selectors.EVENT_READ
|
|
for wfileno in plugin_write_desc:
|
|
if wfileno not in events:
|
|
events[wfileno] = selectors.EVENT_WRITE
|
|
else:
|
|
events[wfileno] |= selectors.EVENT_WRITE
|
|
return events
|
|
|
|
# We override super().handle_events and never call it
|
|
async def handle_events(
|
|
self,
|
|
readables: Readables,
|
|
writables: Writables,
|
|
) -> bool:
|
|
"""Returns True if proxy must tear down."""
|
|
# Flush buffer for ready to write sockets
|
|
teardown = await self.handle_writables(writables)
|
|
if teardown:
|
|
return True
|
|
# Invoke plugin.write_to_descriptors
|
|
if self.plugin:
|
|
teardown = await self.plugin.write_to_descriptors(writables)
|
|
if teardown:
|
|
return True
|
|
# Read from ready to read sockets
|
|
teardown = await self.handle_readables(readables)
|
|
if teardown:
|
|
return True
|
|
# Invoke plugin.read_from_descriptors
|
|
if self.plugin:
|
|
teardown = await self.plugin.read_from_descriptors(readables)
|
|
if teardown:
|
|
return True
|
|
return False
|
|
|
|
def handle_data(self, data: memoryview) -> Optional[bool]:
|
|
"""Handles incoming data from client."""
|
|
if data is None:
|
|
logger.debug('Client closed connection, tearing down...')
|
|
self.work.closed = True
|
|
return True
|
|
try:
|
|
# Don't parse incoming data any further after 1st request has completed.
|
|
#
|
|
# This specially does happen for pipeline requests.
|
|
#
|
|
# Plugins can utilize on_client_data for such cases and
|
|
# apply custom logic to handle request data sent after 1st
|
|
# valid request.
|
|
if self.request.state != httpParserStates.COMPLETE:
|
|
if self._parse_first_request(data):
|
|
return True
|
|
else:
|
|
# HttpProtocolHandlerPlugin.on_client_data
|
|
# Can raise HttpProtocolException to tear down the connection
|
|
if self.plugin:
|
|
data = self.plugin.on_client_data(data) or data
|
|
except HttpProtocolException as e:
|
|
logger.info('HttpProtocolException: %s' % e)
|
|
response: Optional[memoryview] = e.response(self.request)
|
|
if response:
|
|
self.work.queue(response)
|
|
return True
|
|
return False
|
|
|
|
async def handle_writables(self, writables: Writables) -> bool:
|
|
if self.work.connection.fileno() in writables and self.work.has_buffer():
|
|
logger.debug('Client is write ready, flushing...')
|
|
self.last_activity = time.time()
|
|
# TODO(abhinavsingh): This hook could just reside within server recv block
|
|
# instead of invoking when flushed to client.
|
|
#
|
|
# Invoke plugin.on_response_chunk
|
|
chunk = self.work.buffer
|
|
if self.plugin:
|
|
chunk = self.plugin.on_response_chunk(chunk)
|
|
try:
|
|
# Call super() for client flush
|
|
teardown = await super().handle_writables(writables)
|
|
if teardown:
|
|
return True
|
|
except BrokenPipeError:
|
|
logger.error(
|
|
'BrokenPipeError when flushing buffer for client',
|
|
)
|
|
return True
|
|
except OSError:
|
|
logger.error('OSError when flushing buffer to client')
|
|
return True
|
|
return False
|
|
|
|
async def handle_readables(self, readables: Readables) -> bool:
|
|
if self.work.connection.fileno() in readables:
|
|
logger.debug('Client is read ready, receiving...')
|
|
self.last_activity = time.time()
|
|
try:
|
|
teardown = await super().handle_readables(readables)
|
|
if teardown:
|
|
return teardown
|
|
except ssl.SSLWantReadError: # Try again later
|
|
logger.warning(
|
|
'SSLWantReadError encountered while reading from client, will retry ...',
|
|
)
|
|
return False
|
|
except socket.error as e:
|
|
if e.errno == errno.ECONNRESET:
|
|
# Most requests for mobile devices will end up
|
|
# with client closed connection. Using `debug`
|
|
# here to avoid flooding the logs.
|
|
logger.debug('%r' % e)
|
|
else:
|
|
logger.warning(
|
|
'Exception when receiving from %s connection#%d with reason %r' %
|
|
(self.work.tag, self.work.connection.fileno(), e),
|
|
)
|
|
return True
|
|
return False
|
|
|
|
##
|
|
# Internal methods
|
|
##
|
|
|
|
def _initialize_plugin(
|
|
self,
|
|
klass: Type['HttpProtocolHandlerPlugin'],
|
|
) -> HttpProtocolHandlerPlugin:
|
|
"""Initializes passed HTTP protocol handler plugin class."""
|
|
return klass(
|
|
self.uid,
|
|
self.flags,
|
|
self.work,
|
|
self.request,
|
|
self.event_queue,
|
|
self.upstream_conn_pool,
|
|
)
|
|
|
|
def _discover_plugin_klass(self, protocol: int) -> Optional[Type['HttpProtocolHandlerPlugin']]:
|
|
"""Discovers and return matching HTTP handler plugin matching protocol."""
|
|
if b'HttpProtocolHandlerPlugin' in self.flags.plugins:
|
|
for klass in self.flags.plugins[b'HttpProtocolHandlerPlugin']:
|
|
k: Type['HttpProtocolHandlerPlugin'] = klass
|
|
if protocol in k.protocols():
|
|
return k
|
|
return None
|
|
|
|
def _parse_first_request(self, data: memoryview) -> bool:
|
|
# Parse http request
|
|
#
|
|
# TODO(abhinavsingh): Remove .tobytes after parser is
|
|
# memoryview compliant
|
|
try:
|
|
self.request.parse(data.tobytes())
|
|
except Exception:
|
|
raise HttpProtocolException(
|
|
'Error when parsing request: %r' % data.tobytes(),
|
|
)
|
|
if not self.request.is_complete:
|
|
return False
|
|
# Discover which HTTP handler plugin is capable of
|
|
# handling the current incoming request
|
|
klass = self._discover_plugin_klass(
|
|
self.request.http_handler_protocol,
|
|
)
|
|
if klass is None:
|
|
# No matching protocol class found.
|
|
# Return bad request response and
|
|
# close the connection.
|
|
self.work.queue(BAD_REQUEST_RESPONSE_PKT)
|
|
return True
|
|
assert klass is not None
|
|
self.plugin = self._initialize_plugin(klass)
|
|
# Invoke plugin.on_request_complete
|
|
output = self.plugin.on_request_complete()
|
|
if isinstance(output, bool):
|
|
return output
|
|
assert isinstance(output, ssl.SSLSocket)
|
|
logger.debug(
|
|
'Updated client conn to %s', output,
|
|
)
|
|
self.work._conn = output
|
|
return False
|
|
|
|
def _encryption_enabled(self) -> bool:
|
|
return self.flags.keyfile is not None and \
|
|
self.flags.certfile is not None
|
|
|
|
def _optionally_wrap_socket(
|
|
self, conn: socket.socket,
|
|
) -> Union[ssl.SSLSocket, socket.socket]:
|
|
"""Attempts to wrap accepted client connection using provided certificates.
|
|
|
|
Shutdown and closes client connection upon error.
|
|
"""
|
|
if self._encryption_enabled():
|
|
assert self.flags.keyfile and self.flags.certfile
|
|
# TODO(abhinavsingh): Insecure TLS versions must not be accepted by default
|
|
conn = wrap_socket(conn, self.flags.keyfile, self.flags.certfile)
|
|
return conn
|
|
|
|
def _connection_inactive_for(self) -> float:
|
|
return time.time() - self.last_activity
|
|
|
|
##
|
|
# run() and _run_once() are here to maintain backward compatibility
|
|
# with threaded mode. These methods are only called when running
|
|
# in threaded mode.
|
|
##
|
|
|
|
def run(self) -> None:
|
|
"""run() method is not used when in --threadless mode.
|
|
|
|
This is here just to maintain backward compatibility with threaded mode.
|
|
"""
|
|
loop = asyncio.new_event_loop()
|
|
try:
|
|
self.initialize()
|
|
while True:
|
|
# Tear down if client buffer is empty and connection is inactive
|
|
if self.is_inactive():
|
|
logger.debug(
|
|
'Client buffer is empty and maximum inactivity has reached '
|
|
'between client and server connection, tearing down...',
|
|
)
|
|
break
|
|
if loop.run_until_complete(self._run_once()):
|
|
break
|
|
except KeyboardInterrupt: # pragma: no cover
|
|
pass
|
|
except ssl.SSLError as e:
|
|
logger.exception('ssl.SSLError', exc_info=e)
|
|
except Exception as e:
|
|
logger.exception(
|
|
'Exception while handling connection %r' %
|
|
self.work.connection, exc_info=e,
|
|
)
|
|
finally:
|
|
self.shutdown()
|
|
loop.close()
|
|
|
|
async def _run_once(self) -> bool:
|
|
events, readables, writables = await self._selected_events()
|
|
try:
|
|
return await self.handle_events(readables, writables)
|
|
finally:
|
|
assert self.selector
|
|
# TODO: Like Threadless we should not unregister
|
|
# work fds repeatedly.
|
|
for fd in events:
|
|
self.selector.unregister(fd)
|
|
|
|
# FIXME: Returning events is only necessary because we cannot use async context manager
|
|
# for < Python 3.8. As a reason, this method is no longer a context manager and caller
|
|
# is responsible for unregistering the descriptors.
|
|
async def _selected_events(self) -> Tuple[SelectableEvents, Readables, Writables]:
|
|
assert self.selector
|
|
events = await self.get_events()
|
|
for fd in events:
|
|
self.selector.register(fd, events[fd])
|
|
ev = self.selector.select(timeout=DEFAULT_SELECTOR_SELECT_TIMEOUT)
|
|
readables = []
|
|
writables = []
|
|
for key, mask in ev:
|
|
if mask & selectors.EVENT_READ:
|
|
readables.append(key.fd)
|
|
if mask & selectors.EVENT_WRITE:
|
|
writables.append(key.fd)
|
|
return (events, readables, writables)
|
|
|
|
def _flush(self) -> None:
|
|
assert self.selector
|
|
logger.debug('Flushing pending data')
|
|
try:
|
|
self.selector.register(
|
|
self.work.connection,
|
|
selectors.EVENT_WRITE,
|
|
)
|
|
while self.work.has_buffer():
|
|
logging.debug('Waiting for client read ready')
|
|
ev: List[
|
|
Tuple[selectors.SelectorKey, int]
|
|
] = self.selector.select(timeout=DEFAULT_SELECTOR_SELECT_TIMEOUT)
|
|
if len(ev) == 0:
|
|
continue
|
|
self.work.flush()
|
|
except BrokenPipeError:
|
|
pass
|
|
finally:
|
|
self.selector.unregister(self.work.connection)
|