# -*- coding: utf-8 -*-
"""
    proxy.py
    ~~~~~~~~
    ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
    Network monitoring, controls & Application development, testing, debugging.

    :copyright: (c) 2013-present by Abhinav Singh and contributors.
    :license: BSD, see LICENSE for more details.
"""
import os
import ssl
import socket
import logging
import asyncio
import argparse
import selectors
import multiprocessing

from abc import abstractmethod, ABC
from typing import Dict, Optional, Tuple, List, Set, Generic, TypeVar, Union

from ...common.logger import Logger
from ...common.types import Readables, Writables
from ...common.constants import DEFAULT_INACTIVE_CONN_CLEANUP_TIMEOUT, DEFAULT_SELECTOR_SELECT_TIMEOUT
from ...common.constants import DEFAULT_WAIT_FOR_TASKS_TIMEOUT

from ..connection import TcpClientConnection
from ..event import eventNames, EventQueue

from .work import Work

T = TypeVar('T')

logger = logging.getLogger(__name__)


class Threadless(ABC, Generic[T]):
    """Work executor base class.

    Threadless provides an event loop which is shared across
    multiple :class:`~proxy.core.acceptor.work.Work` instances
    to handle work.

    Threadless takes a `work_klass` and an `event_queue` as input.
    `work_klass` must conform to the
    :class:`~proxy.core.acceptor.work.Work` protocol.  Work is
    received over the `event_queue`.

    When a work is accepted, threadless creates a new instance of
    `work_klass`.  Threadless will then invoke the necessary lifecycle
    methods of the :class:`~proxy.core.acceptor.work.Work` protocol,
    allowing the `work_klass` implementation to handle the assigned work.

    For example, :class:`~proxy.core.base.tcp_server.BaseTcpServerHandler`
    implements the :class:`~proxy.core.acceptor.work.Work` protocol.  It
    expects a client connection as work payload and hooks into the
    threadless event loop to handle the client connection.
    """

    def __init__(
            self,
            work_queue: T,
            flags: argparse.Namespace,
            event_queue: Optional[EventQueue] = None,
    ) -> None:
        super().__init__()
        self.work_queue = work_queue
        self.flags = flags
        self.event_queue = event_queue

        self.running = multiprocessing.Event()
        self.works: Dict[int, Work] = {}
        self.selector: Optional[selectors.DefaultSelector] = None
        # If we remove single quotes from the typing hint below,
        # runtime exceptions will occur for < Python 3.9.
        #
        # Ref https://github.com/abhinavsingh/proxy.py/runs/4279055360?check_suite_focus=true
        self.unfinished: Set['asyncio.Task[bool]'] = set()
        self.registered_events_by_work_ids: Dict[
            # work_id
            int,
            # fileno, mask
            Dict[int, int],
        ] = {}
        self.wait_timeout: float = DEFAULT_WAIT_FOR_TASKS_TIMEOUT
        self.cleanup_inactive_timeout: float = DEFAULT_INACTIVE_CONN_CLEANUP_TIMEOUT

    @property
    @abstractmethod
    def loop(self) -> Optional[asyncio.AbstractEventLoop]:
        raise NotImplementedError()

    @abstractmethod
    def receive_from_work_queue(self) -> bool:
        """Called when the work queue has new work ready to be received.

        Receive it and call ``work_on_tcp_conn``.

        Return True to tear down the loop."""
        raise NotImplementedError()

    @abstractmethod
    def work_queue_fileno(self) -> Optional[int]:
        """If the work queue must be selected before calling
        ``receive_from_work_queue``, then the implementation must
        return the work queue fd."""
        raise NotImplementedError()

    def close_work_queue(self) -> None:
        """Only called if ``work_queue_fileno`` returns an integer.
        If an fd is select-able for the work queue, make sure
        to close the work queue fd now."""
        pass    # pragma: no cover

    def work_on_tcp_conn(
            self,
            fileno: int,
            addr: Optional[Tuple[str, int]] = None,
            conn: Optional[Union[ssl.SSLSocket, socket.socket]] = None,
    ) -> None:
        conn = conn or socket.fromfd(
            fileno,
            family=socket.AF_INET if self.flags.hostname.version == 4
            else socket.AF_INET6,
            type=socket.SOCK_STREAM,
        )
        self.works[fileno] = self.flags.work_klass(
            TcpClientConnection(
                conn=conn,
                addr=addr,
            ),
            flags=self.flags,
            event_queue=self.event_queue,
            uid=fileno,
        )
        self.works[fileno].publish_event(
            event_name=eventNames.WORK_STARTED,
            event_payload={'fileno': fileno, 'addr': addr},
            publisher_id=self.__class__.__name__,
        )
        try:
            self.works[fileno].initialize()
        except Exception as e:
            logger.exception(
                'Exception occurred during initialization',
                exc_info=e,
            )
            self._cleanup(fileno)

    async def _update_work_events(self, work_id: int) -> None:
        assert self.selector is not None
        worker_events = await self.works[work_id].get_events()
        # NOTE: The current assumption is that multiple works will not
        # be interested in the same fd.  Descriptors of interest
        # returned by a work must be unique.
        #
        # TODO: Ideally we must diff and unregister fds that are no
        # longer of interest within the current call but still exist
        # in the registered_events_by_work_ids registry.
        for fileno in worker_events:
            if work_id not in self.registered_events_by_work_ids:
                self.registered_events_by_work_ids[work_id] = {}
            mask = worker_events[fileno]
            if fileno in self.registered_events_by_work_ids[work_id]:
                oldmask = self.registered_events_by_work_ids[work_id][fileno]
                if mask != oldmask:
                    self.selector.modify(
                        fileno, events=mask,
                        data=work_id,
                    )
                    self.registered_events_by_work_ids[work_id][fileno] = mask
                    # logger.debug(
                    #     'fd#{0} modified for mask#{1} by work#{2}'.format(
                    #         fileno, mask, work_id,
                    #     ),
                    # )
                # else:
                #     logger.info(
                #         'fd#{0} by work#{1} not modified'.format(fileno, work_id))
            else:
                # Can throw ValueError: Invalid file descriptor: -1
                #
                # A guard within Work classes may not help here due to the
                # asynchronous nature.  Hence, threadless will handle
                # ValueError exceptions raised by selector.register
                # for an invalid fd.
                self.selector.register(
                    fileno, events=mask,
                    data=work_id,
                )
                self.registered_events_by_work_ids[work_id][fileno] = mask
                # logger.debug(
                #     'fd#{0} registered for mask#{1} by work#{2}'.format(
                #         fileno, mask, work_id,
                #     ),
                # )

    async def _update_selector(self) -> None:
        assert self.selector is not None
        unfinished_work_ids = set()
        for task in self.unfinished:
            unfinished_work_ids.add(task._work_id)    # type: ignore
        for work_id in self.works:
            # We don't want to invoke work objects which haven't
            # yet finished their previous task
            if work_id in unfinished_work_ids:
                continue
            await self._update_work_events(work_id)

    async def _selected_events(self) -> Tuple[
            Dict[int, Tuple[Readables, Writables]],
            bool,
    ]:
        """For each work, collects the events it is interested in,
        then calls select for those events of interest.

        Returns a 2-tuple containing a dictionary and a boolean.

        Dictionary keys are work IDs and values are 2-tuples
        containing ready readables & writables.

        The returned boolean value indicates whether there is a newly
        accepted work waiting to be received and queued for processing.
        This is only applicable when
        :class:`~proxy.core.acceptor.threadless.Threadless.work_queue_fileno`
        returns a valid fd.
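
        Example with illustrative values: if work#5 has fd#10 ready for
        read and fd#11 ready for write, and no new work is pending on
        the work queue, the returned value would look like
        ``({5: ([10], [11])}, False)``.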
""" assert self.selector is not None await self._update_selector() events = self.selector.select( timeout=DEFAULT_SELECTOR_SELECT_TIMEOUT, ) # Keys are work_id and values are 2-tuple indicating # readables & writables that work_id is interested in # and are ready for IO. work_by_ids: Dict[int, Tuple[Readables, Writables]] = {} new_work_available = False wqfileno = self.work_queue_fileno() if wqfileno is None: # When ``work_queue_fileno`` returns None, # always return True for the boolean value. new_work_available = True for key, mask in events: if not new_work_available and wqfileno is not None and key.fileobj == wqfileno: assert mask & selectors.EVENT_READ new_work_available = True continue if key.data not in work_by_ids: work_by_ids[key.data] = ([], []) if mask & selectors.EVENT_READ: work_by_ids[key.data][0].append(key.fileobj) if mask & selectors.EVENT_WRITE: work_by_ids[key.data][1].append(key.fileobj) return (work_by_ids, new_work_available) async def _wait_for_tasks(self) -> Set['asyncio.Task[bool]']: finished, self.unfinished = await asyncio.wait( self.unfinished, timeout=self.wait_timeout, return_when=asyncio.FIRST_COMPLETED, ) return finished # noqa: WPS331 def _cleanup_inactive(self) -> None: inactive_works: List[int] = [] for work_id in self.works: if self.works[work_id].is_inactive(): inactive_works.append(work_id) for work_id in inactive_works: self._cleanup(work_id) # TODO: HttpProtocolHandler.shutdown can call flush which may block def _cleanup(self, work_id: int) -> None: if work_id in self.registered_events_by_work_ids: assert self.selector for fileno in self.registered_events_by_work_ids[work_id]: logger.debug( 'fd#{0} unregistered by work#{1}'.format( fileno, work_id, ), ) self.selector.unregister(fileno) self.registered_events_by_work_ids[work_id].clear() del self.registered_events_by_work_ids[work_id] self.works[work_id].shutdown() del self.works[work_id] if self.work_queue_fileno() is not None: os.close(work_id) def _create_tasks( self, work_by_ids: Dict[int, Tuple[Readables, Writables]], ) -> Set['asyncio.Task[bool]']: assert self.loop tasks: Set['asyncio.Task[bool]'] = set() for work_id in work_by_ids: task = self.loop.create_task( self.works[work_id].handle_events(*work_by_ids[work_id]), ) task._work_id = work_id # type: ignore[attr-defined] # task.set_name(work_id) tasks.add(task) return tasks async def _run_once(self) -> bool: assert self.loop is not None work_by_ids, new_work_available = await self._selected_events() # Accept new work if available # # TODO: We must use a work klass to handle # client_queue fd itself a.k.a. accept_client # will become handle_readables. if new_work_available: teardown = self.receive_from_work_queue() if teardown: return teardown if len(work_by_ids) == 0: return False # Invoke Threadless.handle_events self.unfinished.update(self._create_tasks(work_by_ids)) # logger.debug('Executing {0} works'.format(len(self.unfinished))) # Cleanup finished tasks for task in await self._wait_for_tasks(): # Checking for result can raise exception e.g. # CancelledError, InvalidStateError or an exception # from underlying task e.g. TimeoutError. 
            teardown = False
            work_id = task._work_id    # type: ignore
            try:
                teardown = task.result()
            finally:
                if teardown:
                    self._cleanup(work_id)
            # self.cleanup(int(task.get_name()))
        # logger.debug(
        #     'Done executing works, {0} pending, {1} registered'.format(
        #         len(self.unfinished), len(self.registered_events_by_work_ids),
        #     ),
        # )
        return False

    async def _run_forever(self) -> None:
        tick = 0
        try:
            while True:
                if await self._run_once():
                    break
                # Check for inactive and shutdown signal
                elapsed = tick * \
                    (DEFAULT_SELECTOR_SELECT_TIMEOUT + self.wait_timeout)
                if elapsed >= self.cleanup_inactive_timeout:
                    self._cleanup_inactive()
                    if self.running.is_set():
                        break
                    tick = 0
                tick += 1
        except KeyboardInterrupt:
            pass
        finally:
            if self.loop:
                self.loop.stop()

    def run(self) -> None:
        Logger.setup(
            self.flags.log_file, self.flags.log_level,
            self.flags.log_format,
        )
        wqfileno = self.work_queue_fileno()
        try:
            self.selector = selectors.DefaultSelector()
            if wqfileno is not None:
                self.selector.register(
                    wqfileno,
                    selectors.EVENT_READ,
                    data=wqfileno,
                )
            assert self.loop
            # logger.debug('Working on {0} works'.format(len(self.works)))
            self.loop.create_task(self._run_forever())
            self.loop.run_forever()
        except KeyboardInterrupt:
            pass
        finally:
            assert self.selector is not None
            if wqfileno is not None:
                self.selector.unregister(wqfileno)
                self.close_work_queue()
            assert self.loop is not None
            self.loop.run_until_complete(self.loop.shutdown_asyncgens())
            self.loop.close()
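

# ---------------------------------------------------------------------------
# Illustrative sketch only -- not part of proxy.py's API.  It shows one way a
# concrete executor could implement the abstract hooks above, assuming that
# ``work_queue`` is the receiving end of a ``multiprocessing.Pipe`` over which
# an acceptor sends ``(fileno, addr)`` tuples for accepted clients.  The class
# name ``_ExamplePipeExecutor`` and that wire format are hypothetical.
# ---------------------------------------------------------------------------
class _ExamplePipeExecutor(Threadless['multiprocessing.connection.Connection']):

    _loop: Optional[asyncio.AbstractEventLoop] = None

    @property
    def loop(self) -> Optional[asyncio.AbstractEventLoop]:
        # Lazily create a dedicated event loop for this executor.
        if self._loop is None:
            self._loop = asyncio.new_event_loop()
        return self._loop

    def work_queue_fileno(self) -> Optional[int]:
        # Expose the pipe fd so that the threadless loop can select on it.
        return self.work_queue.fileno()

    def receive_from_work_queue(self) -> bool:
        # Hypothetical wire format: the acceptor sends ``(fileno, addr)``.
        fileno, addr = self.work_queue.recv()
        self.work_on_tcp_conn(fileno=fileno, addr=addr)
        return False    # Never tear down the loop from here

    def close_work_queue(self) -> None:
        self.work_queue.close()


# Such an executor would typically be run in its own process, e.g. via
# ``multiprocessing.Process(target=_ExamplePipeExecutor(recv_end, flags).run)``,
# while the parent process keeps the sending end of the pipe.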