# Source: proxy.py/proxy/http/handler.py (418 lines, 16 KiB, Python)

# -*- coding: utf-8 -*-
"""
proxy.py
~~~~~~~~
⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
Network monitoring, controls & Application development, testing, debugging.
:copyright: (c) 2013-present by Abhinav Singh and contributors.
:license: BSD, see LICENSE for more details.
"""
import socket
import selectors
import ssl
import time
import contextlib
import errno
import logging
from abc import ABC, abstractmethod
from typing import Tuple, List, Union, Optional, Generator, Dict
from uuid import UUID
from .parser import HttpParser, httpParserStates, httpParserTypes
from .exception import HttpProtocolException
from ..common.flags import Flags
from ..common.types import HasFileno
from ..core.threadless import ThreadlessWork
from ..core.event import EventQueue
from ..core.connection import TcpClientConnection
logger = logging.getLogger(__name__)
class HttpProtocolHandlerPlugin(ABC):
    """Base class for plugins driven by `HttpProtocolHandler`.

    NOTE: This is an internal plugin and in most cases only useful for core
    contributors. If you are looking for proxy server plugins see
    `<proxy.HttpProxyBasePlugin>`.

    A plugin instance is created for every accepted client connection and
    receives the lifecycle events of that connection:

    1. Client Connection Accepted
       The constructor runs once per connection; put per connection setup
       logic there.
    2. Client Request Chunk Received
       `on_client_data` is called for every chunk of data sent by the client.
    3. Client Request Complete
       `on_request_complete` is called once client request has completed.
    4. Server Response Chunk Received
       `on_response_chunk` is called for every chunk received from the server.
    5. Client Connection Closed
       `on_client_connection_close` is the place for per connection teardown.
    """

    def __init__(
            self,
            uid: UUID,
            flags: Flags,
            client: TcpClientConnection,
            request: HttpParser,
            event_queue: EventQueue):
        """Keep references to the connection context handed over by the handler."""
        self.uid: UUID = uid
        self.flags: Flags = flags
        self.client: TcpClientConnection = client
        self.request: HttpParser = request
        self.event_queue = event_queue
        super().__init__()

    def name(self) -> str:
        """Return the name under which the handler stores this plugin.

        The class name is used by default, which lets plugin developers look
        up a specific plugin instance directly by its name."""
        return type(self).__name__

    @abstractmethod
    def get_descriptors(
            self) -> Tuple[List[socket.socket], List[socket.socket]]:
        """Return `(readables, writables)` sockets this plugin wants monitored.

        The handler adds the first list with a read interest and the second
        list with a write interest to the events it selects on."""
        return [], []  # pragma: no cover

    @abstractmethod
    def write_to_descriptors(self, w: List[Union[int, HasFileno]]) -> bool:
        """Handle descriptors in `w` that are ready for writes.

        Return True to make the handler tear down the connection."""
        return False  # pragma: no cover

    @abstractmethod
    def read_from_descriptors(self, r: List[Union[int, HasFileno]]) -> bool:
        """Handle descriptors in `r` that are ready for reads.

        Return True to make the handler tear down the connection."""
        return False  # pragma: no cover

    @abstractmethod
    def on_client_data(self, raw: memoryview) -> Optional[memoryview]:
        """Handle a chunk of data received from the client.

        The returned value is passed on to the next plugin and finally to the
        request parser; returning None (or empty data) stops any further
        processing of this chunk by the handler."""
        return raw  # pragma: no cover

    @abstractmethod
    def on_request_complete(self) -> Union[socket.socket, bool]:
        """Called right after client request parser has reached COMPLETE state.

        Return an `ssl.SSLSocket` to have the handler switch the client
        connection over to it, True to request a teardown, False otherwise."""
        return False  # pragma: no cover

    @abstractmethod
    def on_response_chunk(self, chunk: List[memoryview]) -> List[memoryview]:
        """Handle data chunks as received from the server.

        Return optionally modified chunk to return back to client."""
        return chunk  # pragma: no cover

    @abstractmethod
    def on_client_connection_close(self) -> None:
        """Called by the handler while it shuts the client connection down."""
        pass  # pragma: no cover
class HttpProtocolHandler(ThreadlessWork):
    """HTTP, HTTPS, HTTP2, WebSockets protocol handler.

    Accepts `Client` connection object and manages HttpProtocolHandlerPlugin
    invocations.
    """

    def __init__(self, client: TcpClientConnection,
                 flags: Optional[Flags] = None,
                 event_queue: Optional[EventQueue] = None,
                 uid: Optional[UUID] = None):
        """Set up parsers, selector and bookkeeping for one client connection."""
        super().__init__(client, flags, event_queue, uid)
        self.start_time: float = time.time()
        # Updated whenever the client socket is read from or written to.
        self.last_activity: float = self.start_time
        self.request: HttpParser = HttpParser(httpParserTypes.REQUEST_PARSER)
        self.response: HttpParser = HttpParser(httpParserTypes.RESPONSE_PARSER)
        # Private selector used by flush() and selected_events();
        # released in shutdown().
        self.selector = selectors.DefaultSelector()
        self.client: TcpClientConnection = client
        # Plugin instances keyed by HttpProtocolHandlerPlugin.name()
        self.plugins: Dict[str, HttpProtocolHandlerPlugin] = {}

    def initialize(self) -> None:
        """Optionally upgrades connection to HTTPS, set conn in non-blocking
        mode and initializes plugins."""
        conn = self.optionally_wrap_socket(self.client.connection)
        conn.setblocking(False)
        if self.flags.encryption_enabled():
            # Replace the client object so that it refers to the wrapped socket
            self.client = TcpClientConnection(conn=conn, addr=self.client.addr)
        if b'HttpProtocolHandlerPlugin' in self.flags.plugins:
            for klass in self.flags.plugins[b'HttpProtocolHandlerPlugin']:
                instance = klass(
                    self.uid,
                    self.flags,
                    self.client,
                    self.request,
                    self.event_queue)
                self.plugins[instance.name()] = instance
        logger.debug('Handling connection %r', self.client.connection)

    def is_inactive(self) -> bool:
        """Returns True when nothing is buffered for the client and the
        connection has been idle for longer than `flags.timeout`."""
        return not self.client.has_buffer() and \
            self.connection_inactive_for() > self.flags.timeout

    def get_events(self) -> Dict[socket.socket, int]:
        """Returns sockets to monitor mapped to their selector event mask.

        The client connection is always monitored for reads and, when data is
        queued for it, for writes. Descriptors returned by plugins are merged in.
        """
        events: Dict[socket.socket, int] = {
            self.client.connection: selectors.EVENT_READ
        }
        if self.client.has_buffer():
            events[self.client.connection] |= selectors.EVENT_WRITE

        # HttpProtocolHandlerPlugin.get_descriptors
        for plugin in self.plugins.values():
            plugin_read_desc, plugin_write_desc = plugin.get_descriptors()
            for r in plugin_read_desc:
                events[r] = events.get(r, 0) | selectors.EVENT_READ
            for w in plugin_write_desc:
                events[w] = events.get(w, 0) | selectors.EVENT_WRITE
        return events

    def handle_events(
            self,
            readables: List[Union[int, HasFileno]],
            writables: List[Union[int, HasFileno]]) -> bool:
        """Returns True if proxy must teardown."""
        # Flush buffer for ready to write sockets
        if self.handle_writables(writables):
            return True

        # Invoke plugin.write_to_descriptors
        for plugin in self.plugins.values():
            if plugin.write_to_descriptors(writables):
                return True

        # Read from ready to read sockets
        if self.handle_readables(readables):
            return True

        # Invoke plugin.read_from_descriptors
        for plugin in self.plugins.values():
            if plugin.read_from_descriptors(readables):
                return True

        return False

    def shutdown(self) -> None:
        """Flushes pending data, notifies plugins and closes the client
        connection.

        Socket level errors raised on the way are ignored since the connection
        is being discarded anyway. The client connection and the selector are
        always closed.
        """
        try:
            # Flush pending buffer if any
            self.flush()

            # Invoke plugin.on_client_connection_close
            for plugin in self.plugins.values():
                plugin.on_client_connection_close()

            logger.debug(
                'Closing client connection %r '
                'at address %r has buffer %s',
                self.client.connection, self.client.addr, self.client.has_buffer())

            conn = self.client.connection
            # Unwrap if wrapped before shutdown.
            if self.flags.encryption_enabled() and \
                    isinstance(self.client.connection, ssl.SSLSocket):
                conn = self.client.connection.unwrap()
            conn.shutdown(socket.SHUT_WR)
            logger.debug('Client connection shutdown successful')
        except OSError:
            # Best effort: the peer may already be gone.
            pass
        finally:
            self.client.connection.close()
            # Release the selector created in __init__; it is not used past
            # this point.
            self.selector.close()
            logger.debug('Client connection closed')
            super().shutdown()

    def optionally_wrap_socket(
            self, conn: socket.socket) -> Union[ssl.SSLSocket, socket.socket]:
        """Wraps accepted client connection using the configured certificate
        when encryption is enabled, otherwise returns `conn` unchanged.

        Errors raised while loading the certificate or wrapping the socket
        propagate to the caller.
        """
        if self.flags.encryption_enabled():
            ctx = ssl.create_default_context(
                ssl.Purpose.CLIENT_AUTH)
            # Only allow TLSv1.2 and above
            ctx.options |= ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | \
                ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1
            # Client certificates are not requested
            ctx.verify_mode = ssl.CERT_NONE
            assert self.flags.keyfile and self.flags.certfile
            ctx.load_cert_chain(
                certfile=self.flags.certfile,
                keyfile=self.flags.keyfile)
            conn = ctx.wrap_socket(
                conn,
                server_side=True,
            )
        return conn

    def connection_inactive_for(self) -> float:
        """Returns seconds elapsed since the last client read or write."""
        return time.time() - self.last_activity

    def flush(self) -> None:
        """Writes out everything queued for the client.

        Waits on the selector (1 second per poll) until the client socket is
        writable and keeps flushing until the buffer is empty or a
        BrokenPipeError is raised.
        """
        if not self.client.has_buffer():
            return
        try:
            self.selector.register(
                self.client.connection,
                selectors.EVENT_WRITE)
            while self.client.has_buffer():
                ev: List[Tuple[selectors.SelectorKey, int]
                         ] = self.selector.select(timeout=1)
                if not ev:
                    # Not writable yet, poll again
                    continue
                self.client.flush()
        except BrokenPipeError:
            pass
        finally:
            self.selector.unregister(self.client.connection)

    def handle_writables(self, writables: List[Union[int, HasFileno]]) -> bool:
        """Flushes the client buffer when the client socket is writable.

        Returns True if the connection must be torn down.
        """
        if self.client.has_buffer() and self.client.connection in writables:
            logger.debug('Client is ready for writes, flushing buffer')
            self.last_activity = time.time()

            # TODO(abhinavsingh): This hook could just reside within server recv block
            # instead of invoking when flushed to client.
            # Invoke plugin.on_response_chunk
            #
            # NOTE: the value returned by the plugin chain is not assigned
            # back to the client buffer here.
            chunk = self.client.buffer
            for plugin in self.plugins.values():
                chunk = plugin.on_response_chunk(chunk)
                if chunk is None:
                    break

            try:
                self.client.flush()
            except BrokenPipeError:
                # Must precede OSError: BrokenPipeError is an OSError subclass
                # and would otherwise never reach its own handler.
                logger.error(
                    'BrokenPipeError when flushing buffer for client')
                return True
            except OSError:
                logger.error('OSError when flushing buffer to client')
                return True
        return False

    def handle_readables(self, readables: List[Union[int, HasFileno]]) -> bool:
        """Reads from the client socket when it is readable, runs the data
        through the plugins and feeds it to the request parser.

        Returns True if the connection must be torn down.
        """
        if self.client.connection in readables:
            logger.debug('Client is ready for reads, reading')
            self.last_activity = time.time()
            try:
                client_data = self.client.recv(self.flags.client_recvbuf_size)
            except ssl.SSLWantReadError:    # Try again later
                logger.warning(
                    'SSLWantReadError encountered while reading from client, will retry ...')
                return False
            except socket.error as e:
                if e.errno == errno.ECONNRESET:
                    logger.warning('%r', e)
                else:
                    logger.exception(
                        'Exception while receiving from %s connection %r with reason %r',
                        self.client.tag, self.client.connection, e)
                return True

            if client_data is None:
                logger.debug('Client closed connection, tearing down...')
                self.client.closed = True
                return True

            try:
                # HttpProtocolHandlerPlugin.on_client_data
                # Can raise HttpProtocolException to teardown the connection
                #
                # Iterate over a snapshot; stop as soon as a plugin has
                # consumed the data (returned None or empty data).
                for plugin in list(self.plugins.values()):
                    if not client_data:
                        break
                    client_data = plugin.on_client_data(client_data)

                # Don't parse request any further after 1st request has completed.
                # This specially does happen for pipeline requests.
                # Plugins can utilize on_client_data for such cases and
                # apply custom logic to handle request data sent after 1st
                # valid request.
                if client_data and self.request.state != httpParserStates.COMPLETE:
                    # Parse http request
                    # TODO(abhinavsingh): Remove .tobytes after parser is
                    # memoryview compliant
                    self.request.parse(client_data.tobytes())
                    if self.request.state == httpParserStates.COMPLETE:
                        # Invoke plugin.on_request_complete
                        for plugin in self.plugins.values():
                            upgraded_sock = plugin.on_request_complete()
                            if isinstance(upgraded_sock, ssl.SSLSocket):
                                logger.debug(
                                    'Updated client conn to %s', upgraded_sock)
                                self.client._conn = upgraded_sock
                                # Point every other plugin at the new socket
                                for plugin_ in self.plugins.values():
                                    if plugin_ != plugin:
                                        plugin_.client._conn = upgraded_sock
                            elif isinstance(upgraded_sock, bool) and upgraded_sock is True:
                                return True
            except HttpProtocolException as e:
                logger.debug(
                    'HttpProtocolException type raised')
                response = e.response(self.request)
                if response:
                    # Queued data is written out by flush() during shutdown
                    self.client.queue(response)
                return True
        return False

    @contextlib.contextmanager
    def selected_events(self) -> \
            Generator[Tuple[List[Union[int, HasFileno]],
                            List[Union[int, HasFileno]]],
                      None, None]:
        """Context manager yielding `(readables, writables)` for one select.

        Descriptors from get_events() are registered with the selector, polled
        once with a 1 second timeout and always unregistered again on exit,
        including when the `with` body raises.
        """
        events = self.get_events()
        registered: List[socket.socket] = []
        try:
            for fd, interest in events.items():
                self.selector.register(fd, interest)
                registered.append(fd)
            ev = self.selector.select(timeout=1)
            readables: List[Union[int, HasFileno]] = []
            writables: List[Union[int, HasFileno]] = []
            for key, mask in ev:
                if mask & selectors.EVENT_READ:
                    readables.append(key.fileobj)
                if mask & selectors.EVENT_WRITE:
                    writables.append(key.fileobj)
            yield (readables, writables)
        finally:
            # Without this cleanup an exception in the caller would leave the
            # descriptors registered and break the next register() call.
            for fd in registered:
                self.selector.unregister(fd)

    def run_once(self) -> bool:
        """Runs a single select / handle cycle.

        Returns True if the connection must be torn down."""
        with self.selected_events() as (readables, writables):
            return self.handle_events(readables, writables)

    def run(self) -> None:
        """Initializes the handler and serves the connection until teardown
        is requested, the connection becomes inactive or an error occurs.
        shutdown() is always invoked at the end."""
        try:
            self.initialize()
            while True:
                # Teardown if client buffer is empty and connection is inactive
                if self.is_inactive():
                    logger.debug(
                        'Client buffer is empty and maximum inactivity has reached '
                        'between client and server connection, tearing down...')
                    break
                teardown = self.run_once()
                if teardown:
                    break
        except KeyboardInterrupt:  # pragma: no cover
            pass
        except ssl.SSLError as e:
            logger.exception('ssl.SSLError', exc_info=e)
        except Exception as e:
            logger.exception(
                'Exception while handling connection %r',
                self.client.connection, exc_info=e)
        finally:
            self.shutdown()