403 lines
15 KiB
Python
403 lines
15 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""
|
|
proxy.py
|
|
~~~~~~~~
|
|
⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
|
|
Network monitoring, controls & Application development, testing, debugging.
|
|
|
|
:copyright: (c) 2013-present by Abhinav Singh and contributors.
|
|
:license: BSD, see LICENSE for more details.
|
|
"""
|
|
import socket
|
|
import selectors
|
|
import ssl
|
|
import time
|
|
import contextlib
|
|
import errno
|
|
import logging
|
|
|
|
from typing import Tuple, List, Union, Optional, Generator, Dict, Any
|
|
|
|
from .plugin import HttpProtocolHandlerPlugin
|
|
from .parser import HttpParser, httpParserStates, httpParserTypes
|
|
from .exception import HttpProtocolException
|
|
|
|
from ..common.types import Readables, Writables
|
|
from ..common.utils import wrap_socket
|
|
from ..core.base import BaseTcpServerHandler
|
|
from ..core.connection import TcpClientConnection
|
|
from ..common.flag import flags
|
|
from ..common.constants import DEFAULT_CLIENT_RECVBUF_SIZE, DEFAULT_KEY_FILE, DEFAULT_TIMEOUT
|
|
|
|
|
|
# Module-level logger shared by everything defined in this file.
logger = logging.getLogger(__name__)


# Command line flags consumed by the HTTP protocol handler.

# Upper bound on bytes read from the client per recv() call.
flags.add_argument(
    '--client-recvbuf-size',
    type=int,
    default=DEFAULT_CLIENT_RECVBUF_SIZE,
    help=(
        'Default: 1 MB. Maximum amount of data received from the client '
        'in a single recv() operation. Bump this value for faster uploads '
        'at the expense of increased RAM.'
    ),
)

# Server private key; only meaningful together with --cert-file.
flags.add_argument(
    '--key-file',
    type=str,
    default=DEFAULT_KEY_FILE,
    help=(
        'Default: None. Server key file to enable end-to-end TLS encryption '
        'with clients. If used, must also pass --cert-file.'
    ),
)

# Inactivity threshold (seconds) after which a connection is dropped.
flags.add_argument(
    '--timeout',
    type=int,
    default=DEFAULT_TIMEOUT,
    help=(
        f'Default: {DEFAULT_TIMEOUT!s}. Number of seconds after which an inactive '
        'connection must be dropped. Inactivity is defined by no data sent or '
        'received by the client.'
    ),
)
|
|
|
|
# Type of the generator behind the selector context manager: it yields a
# single (readables, writables) pair, accepts nothing via send() and
# returns nothing.
SelectedEventsGeneratorType = Generator[Tuple[Readables, Writables], None, None]
|
|
|
|
|
|
class HttpProtocolHandler(BaseTcpServerHandler):
    """HTTP, HTTPS, HTTP2, WebSockets protocol handler.

    Accepts `Client` connection and delegates to HttpProtocolHandlerPlugin.

    In threaded mode (``flags.threadless`` is false) the handler owns a
    selector and drives its own event loop through :meth:`run`.  In
    threadless mode no selector is created here and the handler is driven
    from outside through :meth:`initialize`, :meth:`is_inactive`,
    :meth:`get_events`, :meth:`handle_events` and :meth:`shutdown`.
    """

    def __init__(self, *args: Any, **kwargs: Any):
        """Set up parsers, activity timestamps and the optional selector.

        All arguments are forwarded unchanged to the base class.
        """
        super().__init__(*args, **kwargs)
        # Creation time; also the initial value of the activity clock.
        self.start_time: float = time.time()
        # Refreshed every time the client socket is read from / written to.
        self.last_activity: float = self.start_time
        self.request: HttpParser = HttpParser(httpParserTypes.REQUEST_PARSER)
        self.response: HttpParser = HttpParser(httpParserTypes.RESPONSE_PARSER)
        # A selector is only needed when this handler runs its own loop.
        self.selector: Optional[selectors.DefaultSelector] = None
        if not self.flags.threadless:
            self.selector = selectors.DefaultSelector()
        # Plugin instances keyed by the value of their name() method.
        self.plugins: Dict[str, HttpProtocolHandlerPlugin] = {}

    ##
    # initialize, is_inactive, shutdown, get_events, handle_events
    # overrides Work class definitions.
    ##

    def initialize(self) -> None:
        """Optionally upgrades connection to HTTPS, set conn in non-blocking mode and initializes plugins."""
        conn = self._optionally_wrap_socket(self.client.connection)
        conn.setblocking(False)
        # Update client connection reference if connection was wrapped
        if self._encryption_enabled():
            self.client = TcpClientConnection(conn=conn, addr=self.client.addr)
        if b'HttpProtocolHandlerPlugin' in self.flags.plugins:
            for klass in self.flags.plugins[b'HttpProtocolHandlerPlugin']:
                instance: HttpProtocolHandlerPlugin = klass(
                    self.uid,
                    self.flags,
                    self.client,
                    self.request,
                    self.event_queue,
                )
                self.plugins[instance.name()] = instance
        logger.debug('Handling connection %r', self.client.connection)

    def is_inactive(self) -> bool:
        """Return True when the connection may be dropped for inactivity.

        That is the case when nothing is queued for the client and the
        time since the last recorded client read/write exceeds
        ``flags.timeout`` seconds.
        """
        return (
            not self.client.has_buffer()
            and self._connection_inactive_for() > self.flags.timeout
        )

    def shutdown(self) -> None:
        """Flush, notify plugins and close the client connection.

        Any ``OSError`` raised while flushing, notifying plugins or
        shutting down the socket is swallowed.  The connection is always
        closed and the base class ``shutdown()`` is always invoked.
        """
        try:
            # Flush pending buffer in threaded mode only.
            # For threadless mode, BaseTcpServerHandler implements
            # the must_flush_before_shutdown logic automagically.
            if self.selector:
                self._flush()

            # Invoke plugin.on_client_connection_close
            for plugin in self.plugins.values():
                plugin.on_client_connection_close()

            logger.debug(
                'Closing client connection %r '
                'at address %r has buffer %s',
                self.client.connection, self.client.addr, self.client.has_buffer(),
            )

            conn = self.client.connection
            # Unwrap if wrapped before shutdown.
            if self._encryption_enabled() and \
                    isinstance(self.client.connection, ssl.SSLSocket):
                conn = self.client.connection.unwrap()
            conn.shutdown(socket.SHUT_WR)
            logger.debug('Client connection shutdown successful')
        except OSError:
            # Best effort only: the connection is closed below regardless.
            pass
        finally:
            self.client.connection.close()
            logger.debug('Client connection closed')
            super().shutdown()

    def get_events(self) -> Dict[socket.socket, int]:
        """Return the descriptors to watch, mapped to their event mask.

        Starts from the events reported by the base class and merges in
        the read/write descriptors requested by every plugin.
        """
        # Get default client events
        events: Dict[socket.socket, int] = super().get_events()
        # HttpProtocolHandlerPlugin.get_descriptors
        for plugin in self.plugins.values():
            plugin_read_desc, plugin_write_desc = plugin.get_descriptors()
            # OR the new interest into whatever mask is already present.
            for r in plugin_read_desc:
                events[r] = events.get(r, 0) | selectors.EVENT_READ
            for w in plugin_write_desc:
                events[w] = events.get(w, 0) | selectors.EVENT_WRITE
        return events

    # We override super().handle_events and never call it
    def handle_events(
            self,
            readables: Readables,
            writables: Writables,
    ) -> bool:
        """Returns True if proxy must teardown.

        Writes are processed before reads: first the client buffer is
        flushed, then plugins write, then the client is read from and
        finally plugins read.  Processing stops at the first step that
        asks for a teardown.
        """
        # Flush buffer for ready to write sockets
        teardown = self.handle_writables(writables)
        if teardown:
            return True

        # Invoke plugin.write_to_descriptors
        for plugin in self.plugins.values():
            teardown = plugin.write_to_descriptors(writables)
            if teardown:
                return True

        # Read from ready to read sockets
        teardown = self.handle_readables(readables)
        if teardown:
            return True

        # Invoke plugin.read_from_descriptors
        for plugin in self.plugins.values():
            teardown = plugin.read_from_descriptors(readables)
            if teardown:
                return True

        return False

    def handle_data(self, data: Optional[memoryview]) -> Optional[bool]:
        """Feed data received from the client to plugins and the parser.

        ``data`` is ``None`` when the client closed the connection.

        Returns True when the connection must be torn down: the client
        closed the connection, a plugin's ``on_request_complete`` returned
        True, or a plugin raised ``HttpProtocolException`` (in which case
        the exception's response, if any, is queued for the client first).
        Returns False otherwise.
        """
        if data is None:
            logger.debug('Client closed connection, tearing down...')
            self.client.closed = True
            return True

        try:
            # HttpProtocolHandlerPlugin.on_client_data
            # Can raise HttpProtocolException to teardown the connection
            for plugin in self.plugins.values():
                optional_data = plugin.on_client_data(data)
                if optional_data is None:
                    break
                data = optional_data
            # Don't parse incoming data any further after 1st request has completed.
            #
            # This specially does happen for pipeline requests.
            #
            # Plugins can utilize on_client_data for such cases and
            # apply custom logic to handle request data sent after 1st
            # valid request.
            if data and self.request.state != httpParserStates.COMPLETE:
                # Parse http request
                # TODO(abhinavsingh): Remove .tobytes after parser is
                # memoryview compliant
                self.request.parse(data.tobytes())
                if self.request.state == httpParserStates.COMPLETE:
                    # Invoke plugin.on_request_complete
                    for plugin in self.plugins.values():
                        upgraded_sock = plugin.on_request_complete()
                        if isinstance(upgraded_sock, ssl.SSLSocket):
                            logger.debug(
                                'Updated client conn to %s', upgraded_sock,
                            )
                            # Point this handler and every other plugin at
                            # the upgraded socket.
                            self.client._conn = upgraded_sock
                            for plugin_ in self.plugins.values():
                                if plugin_ != plugin:
                                    plugin_.client._conn = upgraded_sock
                        elif upgraded_sock is True:
                            return True
        except HttpProtocolException as e:
            logger.debug(
                'HttpProtocolException type raised',
            )
            response: Optional[memoryview] = e.response(self.request)
            if response:
                self.client.queue(response)
            return True
        return False

    def handle_writables(self, writables: Writables) -> bool:
        """Flush the client buffer when the client socket is writable.

        Returns True when the connection must be torn down, either because
        the base class flush requested it or because the flush failed with
        ``BrokenPipeError`` / ``OSError``.
        """
        if self.client.connection in writables and self.client.has_buffer():
            logger.debug('Client is ready for writes, flushing buffer')
            self.last_activity = time.time()

            # TODO(abhinavsingh): This hook could just reside within server recv block
            # instead of invoking when flushed to client.
            #
            # Invoke plugin.on_response_chunk
            #
            # NOTE: the value returned by the plugin chain is only used to
            # feed the next plugin and to stop the chain on None; the
            # flush below does not use it.
            chunk = self.client.buffer
            for plugin in self.plugins.values():
                chunk = plugin.on_response_chunk(chunk)
                if chunk is None:
                    break

            try:
                # Call super() for client flush
                teardown = super().handle_writables(writables)
                if teardown:
                    return True
            except BrokenPipeError:
                logger.error(
                    'BrokenPipeError when flushing buffer for client',
                )
                return True
            except OSError:
                logger.error('OSError when flushing buffer to client')
                return True
        return False

    def handle_readables(self, readables: Readables) -> bool:
        """Read from the client socket when it is readable.

        Returns True when the connection must be torn down, either because
        the base class read requested it or because the read failed with a
        socket error.  ``ssl.SSLWantReadError`` is not treated as a
        failure; the read is simply retried on a later iteration.
        """
        if self.client.connection in readables:
            logger.debug('Client is ready for reads, reading')
            self.last_activity = time.time()
            try:
                teardown = super().handle_readables(readables)
                if teardown:
                    return teardown
            except ssl.SSLWantReadError:    # Try again later
                logger.warning(
                    'SSLWantReadError encountered while reading from client, will retry ...',
                )
                return False
            except socket.error as e:
                # A reset by the peer gets a warning without a traceback;
                # anything else is logged with a traceback.
                if e.errno == errno.ECONNRESET:
                    logger.warning('%r', e)
                else:
                    logger.exception(
                        'Exception while receiving from %s connection %r with reason %r',
                        self.client.tag, self.client.connection, e,
                    )
                return True
        return False

    ##
    # run() is here to maintain backward compatibility for threaded mode
    ##

    def run(self) -> None:
        """run() method is not used when in --threadless mode.

        This is here just to maintain backward compatibility with threaded mode.

        Initializes the handler, then loops until the connection becomes
        inactive or an iteration requests teardown.  ``shutdown()`` is
        always called on the way out.
        """
        try:
            self.initialize()
            while True:
                # Teardown if client buffer is empty and connection is inactive
                if self.is_inactive():
                    logger.debug(
                        'Client buffer is empty and maximum inactivity has reached '
                        'between client and server connection, tearing down...',
                    )
                    break
                teardown = self._run_once()
                if teardown:
                    break
        except KeyboardInterrupt:  # pragma: no cover
            pass
        except ssl.SSLError as e:
            logger.exception('ssl.SSLError', exc_info=e)
        except Exception as e:
            logger.exception(
                'Exception while handling connection %r',
                self.client.connection, exc_info=e,
            )
        finally:
            self.shutdown()

    ##
    # Internal methods
    ##

    def _encryption_enabled(self) -> bool:
        """Return True when both a key file and a cert file are configured."""
        return self.flags.keyfile is not None and \
            self.flags.certfile is not None

    def _optionally_wrap_socket(
            self, conn: socket.socket,
    ) -> Union[ssl.SSLSocket, socket.socket]:
        """Attempts to wrap accepted client connection using provided certificates.

        Returns ``conn`` unchanged when encryption is not enabled.  No
        error handling happens here: an exception raised while wrapping
        propagates to the caller.
        """
        if self._encryption_enabled():
            assert self.flags.keyfile and self.flags.certfile
            # TODO(abhinavsingh): Insecure TLS versions must not be accepted by default
            conn = wrap_socket(conn, self.flags.keyfile, self.flags.certfile)
        return conn

    @contextlib.contextmanager
    def _selected_events(self) -> SelectedEventsGeneratorType:
        """Register current events, select once and yield the ready descriptors.

        Yields a ``(readables, writables)`` pair for descriptors that
        became ready within one second.  Every descriptor registered here
        is unregistered again when the ``with`` block exits, including
        when the block (or registration / select itself) raises.  Without
        that guarantee a failing iteration would leave descriptors in the
        selector and a later attempt to register the same descriptor
        would fail.
        """
        assert self.selector
        events = self.get_events()
        # Only descriptors that were successfully registered get
        # unregistered during cleanup.
        registered = []
        try:
            for fd, mask in events.items():
                self.selector.register(fd, mask)
                registered.append(fd)
            ev = self.selector.select(timeout=1)
            readables = []
            writables = []
            for key, mask in ev:
                if mask & selectors.EVENT_READ:
                    readables.append(key.fileobj)
                if mask & selectors.EVENT_WRITE:
                    writables.append(key.fileobj)
            yield (readables, writables)
        finally:
            for fd in registered:
                self.selector.unregister(fd)

    def _run_once(self) -> bool:
        """Run one select + handle_events cycle; True means teardown."""
        with self._selected_events() as (readables, writables):
            return self.handle_events(readables, writables)

    def _flush(self) -> None:
        """Block until the pending client buffer has been written out.

        Only usable in threaded mode (requires the selector).  Returns
        immediately when nothing is buffered.  ``BrokenPipeError`` while
        flushing ends the flush silently.
        """
        assert self.selector
        if not self.client.has_buffer():
            return
        # Register outside the try block so that cleanup only runs when
        # registration actually succeeded.
        self.selector.register(
            self.client.connection,
            selectors.EVENT_WRITE,
        )
        try:
            while self.client.has_buffer():
                ev: List[
                    Tuple[selectors.SelectorKey, int]
                ] = self.selector.select(timeout=1)
                if not ev:
                    # Client not writable yet, wait again.
                    continue
                self.client.flush()
        except BrokenPipeError:
            pass
        finally:
            self.selector.unregister(self.client.connection)

    def _connection_inactive_for(self) -> float:
        """Return seconds elapsed since the last recorded client activity."""
        return time.time() - self.last_activity
|