246 lines
8.3 KiB
Python
246 lines
8.3 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""
|
|
proxy.py
|
|
~~~~~~~~
|
|
⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
|
|
Network monitoring, controls & Application development, testing, debugging.
|
|
|
|
:copyright: (c) 2013-present by Abhinav Singh and contributors.
|
|
:license: BSD, see LICENSE for more details.
|
|
"""
|
|
import os
|
|
import uuid
|
|
import socket
|
|
import logging
|
|
import asyncio
|
|
import selectors
|
|
import contextlib
|
|
import multiprocessing
|
|
from multiprocessing import connection
|
|
from multiprocessing.reduction import recv_handle
|
|
|
|
from abc import ABC, abstractmethod
|
|
from typing import Dict, Optional, Tuple, List, Union, Generator, Any, Type
|
|
|
|
from .event import EventQueue, eventNames
|
|
|
|
from ..common.flags import Flags
|
|
from ..common.types import HasFileno
|
|
from ..common.constants import DEFAULT_TIMEOUT
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ThreadlessWork(ABC):
    """Contract for a unit of work driven by the Threadless event loop.

    Subclasses supply the abstract hooks below; the Threadless process
    calls them to learn which sockets to watch, to dispatch readiness
    notifications, and to tear the work down.
    """

    @abstractmethod
    def __init__(
            self,
            fileno: int,
            addr: Tuple[str, int],
            flags: Optional[Flags],
            event_queue: Optional[EventQueue] = None,
            uid: Optional[str] = None) -> None:
        """Store the connection details and event publishing context.

        :param fileno: File descriptor this work is bound to.
        :param addr: Address tuple associated with the descriptor.
        :param flags: Runtime flags; a default ``Flags()`` is created
            when a falsy value is passed.
        :param event_queue: Queue used by :meth:`publish_event`.
        :param uid: Identifier for this work; a random UUID4 hex string
            is generated when omitted.
        """
        self.fileno = fileno
        self.addr = addr
        self.flags = flags or Flags()

        self.event_queue = event_queue
        self.uid: str = uuid.uuid4().hex if uid is None else uid

    @abstractmethod
    def initialize(self) -> None:
        """Prepare the work before it starts receiving events."""
        pass    # pragma: no cover

    @abstractmethod
    def is_inactive(self) -> bool:
        """Return True when the work is considered inactive."""
        return False    # pragma: no cover

    @abstractmethod
    def get_events(self) -> Dict[socket.socket, int]:
        """Return a mapping of socket to the selector event mask to watch."""
        return {}   # pragma: no cover

    @abstractmethod
    def handle_events(
            self,
            readables: List[Union[int, HasFileno]],
            writables: List[Union[int, HasFileno]]) -> bool:
        """Return True to shutdown work."""
        return False    # pragma: no cover

    @abstractmethod
    def run(self) -> None:
        """Run the work; the behaviour is defined by the subclass."""
        pass

    def publish_event(
            self,
            event_name: int,
            event_payload: Dict[str, Any],
            publisher_id: Optional[str] = None) -> None:
        """Publish an event on the event queue, tagged with this work's uid.

        Does nothing unless ``flags.enable_events`` is truthy. When events
        are enabled an event queue must have been provided.
        """
        if self.flags.enable_events:
            assert self.event_queue
            self.event_queue.publish(
                self.uid,
                event_name,
                event_payload,
                publisher_id,
            )

    def shutdown(self) -> None:
        """Must close any opened resources and call super().shutdown()."""
        self.publish_event(
            event_name=eventNames.WORK_FINISHED,
            event_payload={},
            publisher_id=self.__class__.__name__,
        )
|
|
|
|
|
|
class Threadless(multiprocessing.Process):
    """Threadless provides an event loop. Use it by implementing ThreadlessWork class.

    When --threadless option is enabled, each Acceptor process also
    spawns one Threadless process. And instead of spawning new thread
    for each accepted client connection, Acceptor process sends
    accepted client connection to Threadless process over a pipe.

    HttpProtocolHandler implements ThreadlessWork class and hooks into the
    event loop provided by Threadless.
    """

    def __init__(
            self,
            client_queue: connection.Connection,
            flags: Flags,
            work_klass: Type[ThreadlessWork],
            event_queue: Optional[EventQueue] = None) -> None:
        """Store collaborators; selector and loop are created in :meth:`run`.

        :param client_queue: Connection over which accepted clients arrive
            (an address followed by a file descriptor handle).
        :param flags: Flags handed to every work instance.
        :param work_klass: ThreadlessWork implementation instantiated for
            each accepted client.
        :param event_queue: Optional queue handed to every work instance.
        """
        super().__init__()
        self.client_queue = client_queue
        self.flags = flags
        self.work_klass = work_klass
        self.event_queue = event_queue

        # Setting this event makes the loop in run() stop.
        self.running = multiprocessing.Event()
        # Active works keyed by the file descriptor received for them.
        self.works: Dict[int, ThreadlessWork] = {}
        self.selector: Optional[selectors.DefaultSelector] = None
        self.loop: Optional[asyncio.AbstractEventLoop] = None

    @contextlib.contextmanager
    def selected_events(self) -> Generator[Tuple[List[Union[int, HasFileno]],
                                                 List[Union[int, HasFileno]]],
                                           None, None]:
        """Register what every work wants to watch, select once, yield results.

        Yields a ``(readables, writables)`` tuple of file objects reported
        ready by a single ``select`` call with a 1 second timeout.

        Everything registered here is unregistered again when the ``with``
        block exits, including when the block (or registration itself)
        raises, so the next call can register the same objects again.
        """
        events: Dict[socket.socket, int] = {}
        for work in self.works.values():
            events.update(work.get_events())
        assert self.selector is not None
        # Track successful registrations so cleanup only touches those.
        registered: List[socket.socket] = []
        try:
            for fileobj, mask in events.items():
                self.selector.register(fileobj, mask)
                registered.append(fileobj)
            ready = self.selector.select(timeout=1)
            readables: List[Union[int, HasFileno]] = []
            writables: List[Union[int, HasFileno]] = []
            for key, mask in ready:
                if mask & selectors.EVENT_READ:
                    readables.append(key.fileobj)
                if mask & selectors.EVENT_WRITE:
                    writables.append(key.fileobj)
            yield (readables, writables)
        finally:
            for fileobj in registered:
                self.selector.unregister(fileobj)

    async def handle_events(
            self, fileno: int,
            readables: List[Union[int, HasFileno]],
            writables: List[Union[int, HasFileno]]) -> bool:
        """Delegate to the work registered under ``fileno``.

        Returns whatever the work's ``handle_events`` returns; True means
        the work asked to be shut down.
        """
        return self.works[fileno].handle_events(readables, writables)

    # TODO: Use correct future typing annotations
    async def wait_for_tasks(
            self, tasks: Dict[int, Any]) -> None:
        """Await each task in turn and clean up works that are done.

        A work is cleaned up when its task returns a truthy value or does
        not finish within ``DEFAULT_TIMEOUT``.
        """
        for work_id, task in tasks.items():
            # TODO: Resolving one handle_events here can block resolution of
            # other tasks
            try:
                teardown = await asyncio.wait_for(task, DEFAULT_TIMEOUT)
                if teardown:
                    self.cleanup(work_id)
            except asyncio.TimeoutError:
                self.cleanup(work_id)

    def accept_client(self) -> None:
        """Receive one client from the queue and start a work for it.

        Reads the address and then the file descriptor from
        ``client_queue``, builds the work, publishes ``WORK_STARTED`` and
        initializes it. A work whose ``initialize`` raises is logged and
        cleaned up immediately.
        """
        addr = self.client_queue.recv()
        fileno = recv_handle(self.client_queue)
        work = self.work_klass(
            fileno=fileno,
            addr=addr,
            flags=self.flags,
            event_queue=self.event_queue
        )
        self.works[fileno] = work
        work.publish_event(
            event_name=eventNames.WORK_STARTED,
            event_payload={'fileno': fileno, 'addr': addr},
            publisher_id=self.__class__.__name__
        )
        try:
            work.initialize()
        except Exception as e:
            logger.exception(
                'Exception occurred during initialization',
                exc_info=e)
            self.cleanup(fileno)

    def cleanup_inactive(self) -> None:
        """Shut down every work that reports itself inactive."""
        # Collect ids first: cleanup() mutates self.works.
        inactive_works: List[int] = [
            work_id
            for work_id, work in self.works.items()
            if work.is_inactive()
        ]
        for work_id in inactive_works:
            self.cleanup(work_id)

    def cleanup(self, work_id: int) -> None:
        """Shut down the work stored under ``work_id`` and close its descriptor.

        The work is dropped from ``self.works`` and the descriptor is
        closed even when ``shutdown()`` raises; the exception still
        propagates to the caller.

        :raises KeyError: if no work is registered under ``work_id``.
        """
        work = self.works.pop(work_id)
        try:
            # TODO: HttpProtocolHandler.shutdown can call flush which may block
            work.shutdown()
        finally:
            os.close(work_id)

    def run_once(self) -> None:
        """Run a single iteration of the event loop."""
        assert self.loop is not None
        with self.selected_events() as (readables, writables):
            if not readables and not writables:
                # Remove and shutdown inactive connections
                self.cleanup_inactive()
                return
            # Note that selector from now on is idle,
            # until all the logic below completes.
            #
            # Invoke Threadless.handle_events
            # TODO: Only send readable / writables that client originally
            # registered.
            tasks = {}
            for fileno in self.works:
                tasks[fileno] = self.loop.create_task(
                    self.handle_events(fileno, readables, writables))
            # Accepted client connection from Acceptor
            if self.client_queue in readables:
                self.accept_client()
            # Wait for Threadless.handle_events to complete
            self.loop.run_until_complete(self.wait_for_tasks(tasks))
            # Remove and shutdown inactive connections
            self.cleanup_inactive()

    def run(self) -> None:
        """Process entry point: iterate ``run_once`` until ``running`` is set.

        ``KeyboardInterrupt`` ends the loop quietly. On exit the selector,
        the client queue and the event loop are released; resources that
        were never created are skipped so that a failure during setup is
        not masked by a secondary error.
        """
        try:
            self.selector = selectors.DefaultSelector()
            self.selector.register(self.client_queue, selectors.EVENT_READ)
            self.loop = asyncio.get_event_loop()
            while not self.running.is_set():
                self.run_once()
        except KeyboardInterrupt:
            pass
        finally:
            if self.selector is not None:
                # The queue is not registered if register() itself failed.
                with contextlib.suppress(KeyError):
                    self.selector.unregister(self.client_queue)
                self.selector.close()
            self.client_queue.close()
            if self.loop is not None:
                self.loop.close()
|