# -*- coding: utf-8 -*-
"""
    proxy.py
    ~~~~~~~~
    ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
    Network monitoring, controls & Application development, testing, debugging.

    :copyright: (c) 2013-present by Abhinav Singh and contributors.
    :license: BSD, see LICENSE for more details.
"""
|
|
import logging
|
|
import multiprocessing
|
|
import selectors
|
|
import socket
|
|
import threading
|
|
# import time
|
|
from multiprocessing import connection
|
|
from multiprocessing.reduction import send_handle, recv_handle
|
|
from typing import List, Optional, Type, Tuple
|
|
|
|
from .threadless import ThreadlessWork, Threadless
|
|
from .event import EventQueue, EventDispatcher, eventNames
|
|
from ..common.flags import Flags
|
|
|
|
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class AcceptorPool:
    """AcceptorPool.

    Pre-spawns worker processes to utilize all cores available on the system.
    Server socket connection is dispatched over a pipe to workers. Each worker
    accepts incoming client request and spawns a separate thread to handle the
    client request.

    Lifecycle: ``setup()`` listens, optionally starts the event dispatcher,
    starts the acceptor processes and hands each of them the server socket;
    ``shutdown()`` stops and joins them again.
    """

    def __init__(self, flags: Flags, work_klass: Type[ThreadlessWork]) -> None:
        """Store configuration; no socket or process is created here.

        :param flags: proxy configuration (bind address, port, backlog,
            number of workers, whether events are enabled, ...).
        :param work_klass: class each acceptor instantiates to handle an
            accepted client connection.
        """
        self.flags = flags
        self.socket: Optional[socket.socket] = None
        self.acceptors: List[Acceptor] = []
        # Parent-side ends of the pipes used to ship the server socket to the
        # acceptors; index-aligned with ``self.acceptors``.
        self.work_queues: List[connection.Connection] = []
        self.work_klass = work_klass

        self.event_queue: Optional[EventQueue] = None
        self.event_dispatcher: Optional[EventDispatcher] = None
        self.event_dispatcher_thread: Optional[threading.Thread] = None
        self.event_dispatcher_shutdown: Optional[threading.Event] = None
        if self.flags.enable_events:
            self.event_queue = EventQueue()

    def listen(self) -> None:
        """Create, bind and start listening on a non-blocking server socket.

        If any step fails the half-initialized socket is closed and the
        exception is re-raised; ``self.socket`` is only assigned on success.
        """
        sock = socket.socket(self.flags.family, socket.SOCK_STREAM)
        try:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            sock.bind((str(self.flags.hostname), self.flags.port))
            sock.listen(self.flags.backlog)
            sock.setblocking(False)
        except BaseException:
            # Don't leak the descriptor (e.g. when the port is already in
            # use); the original exception is propagated unchanged.
            sock.close()
            raise
        self.socket = sock
        logger.info(
            'Listening on %s:%d',
            self.flags.hostname, self.flags.port)

    def start_workers(self) -> None:
        """Start ``flags.num_workers`` acceptor processes.

        Each acceptor gets the child end of a fresh pipe; the parent end is
        kept in ``self.work_queues`` so ``setup()`` can send the server socket.
        """
        for acceptor_id in range(self.flags.num_workers):
            parent_end, child_end = multiprocessing.Pipe()
            acceptor = Acceptor(
                idd=acceptor_id,
                work_queue=child_end,
                flags=self.flags,
                work_klass=self.work_klass,
                event_queue=self.event_queue
            )
            acceptor.start()
            # Once start() returns the child holds its own copy of child_end
            # (inherited or duplicated by multiprocessing), so the parent's
            # copy is unused; close it instead of leaking one fd per worker.
            child_end.close()
            logger.debug(
                'Started acceptor#%d process %d',
                acceptor_id,
                acceptor.pid)
            self.acceptors.append(acceptor)
            self.work_queues.append(parent_end)
        logger.info('Started %d workers', self.flags.num_workers)

    def start_event_dispatcher(self) -> None:
        """Run an ``EventDispatcher`` over ``self.event_queue`` in a thread.

        Requires ``self.event_queue``, i.e. ``flags.enable_events``.
        """
        self.event_dispatcher_shutdown = threading.Event()
        assert self.event_dispatcher_shutdown
        assert self.event_queue
        self.event_dispatcher = EventDispatcher(
            shutdown=self.event_dispatcher_shutdown,
            event_queue=self.event_queue
        )
        self.event_dispatcher_thread = threading.Thread(
            target=self.event_dispatcher.run
        )
        self.event_dispatcher_thread.start()
        logger.debug('Thread ID: %d', self.event_dispatcher_thread.ident)

    def shutdown(self) -> None:
        """Stop all acceptors and, if enabled, the event dispatcher thread.

        Every acceptor is signalled first so they wind down concurrently; the
        dispatcher thread is stopped and joined; finally all acceptors are
        joined.
        """
        logger.info('Shutting down %d workers', self.flags.num_workers)
        for acceptor in self.acceptors:
            # Setting this event tells Acceptor.run() to leave its loop.
            acceptor.running.set()
        if self.flags.enable_events:
            assert self.event_dispatcher_shutdown
            assert self.event_dispatcher_thread
            self.event_dispatcher_shutdown.set()
            self.event_dispatcher_thread.join()
            logger.debug(
                'Shutdown of global event dispatcher thread %d successful',
                self.event_dispatcher_thread.ident)
        for acceptor in self.acceptors:
            acceptor.join()
        logger.debug('Acceptors shutdown')

    def setup(self) -> None:
        """Listen on port, setup workers and pass server socket to workers."""
        self.listen()
        if self.flags.enable_events:
            logger.info('Core Event enabled')
            self.start_event_dispatcher()
        self.start_workers()

        # Send server socket to all acceptor processes.
        assert self.socket is not None
        for work_queue, acceptor in zip(self.work_queues, self.acceptors):
            send_handle(work_queue, self.socket.fileno(), acceptor.pid)
            work_queue.close()
        # Each acceptor now holds its own handle to the listening socket.
        self.socket.close()
|
|
|
|
|
|
class Acceptor(multiprocessing.Process):
    """Socket client acceptor.

    Accepts client connection over received server socket handle and
    starts a new work thread. In threadless mode, it instead hands the
    accepted connection to a ``Threadless`` instance it started.

    Setting the ``running`` event asks ``run()`` to leave its accept loop.
    """

    # Serializes select() + accept() across acceptor processes.
    # NOTE(review): the lock is created at import time, so it is shared only
    # with children that inherit it via the ``fork`` start method. Under
    # ``spawn`` each process re-imports the module and gets its own lock.
    # Confirm which start method is expected.
    lock = multiprocessing.Lock()

    def __init__(
            self,
            idd: int,
            work_queue: connection.Connection,
            flags: Flags,
            work_klass: Type[ThreadlessWork],
            event_queue: Optional[EventQueue] = None) -> None:
        """
        :param idd: acceptor index, used in log messages.
        :param work_queue: pipe end over which the server socket handle is
            received (see ``run()``).
        :param flags: proxy configuration.
        :param work_klass: class instantiated per accepted client.
        :param event_queue: optional queue forwarded to the work objects.
        """
        super().__init__()
        self.idd = idd
        self.work_queue: connection.Connection = work_queue
        self.flags = flags
        self.work_klass = work_klass
        self.event_queue = event_queue

        # Despite its name, *setting* this event requests shutdown.
        self.running = multiprocessing.Event()
        self.selector: Optional[selectors.DefaultSelector] = None
        self.sock: Optional[socket.socket] = None
        self.threadless_process: Optional[Threadless] = None
        self.threadless_client_queue: Optional[connection.Connection] = None

    def start_threadless_process(self) -> None:
        """Start the ``Threadless`` worker that receives accepted clients.

        ``pipe[0]`` is kept as ``self.threadless_client_queue``; ``pipe[1]``
        is given to ``Threadless``.
        """
        pipe = multiprocessing.Pipe()
        self.threadless_client_queue = pipe[0]
        self.threadless_process = Threadless(
            client_queue=pipe[1],
            flags=self.flags,
            work_klass=self.work_klass,
            event_queue=self.event_queue
        )
        self.threadless_process.start()
        logger.debug('Started process %d', self.threadless_process.pid)

    def shutdown_threadless_process(self) -> None:
        """Ask ``Threadless`` to stop, wait for it, then close our pipe end."""
        assert self.threadless_process and self.threadless_client_queue
        self.threadless_process.running.set()
        self.threadless_process.join()
        self.threadless_client_queue.close()
        # Log only once the process has actually been joined.
        logger.debug('Stopped process %d', self.threadless_process.pid)

    def start_work(self, conn: socket.socket, addr: Tuple[str, int]) -> None:
        """Dispatch an accepted client connection.

        In threadless mode (when the Threadless process is running), send
        ``addr`` followed by the connection's handle over the client queue and
        close the local socket. Otherwise construct ``work_klass`` with the
        connection's fileno and run it on a daemon thread.
        """
        if self.flags.threadless and \
                self.threadless_client_queue and \
                self.threadless_process:
            try:
                self.threadless_client_queue.send(addr)
                send_handle(
                    self.threadless_client_queue,
                    conn.fileno(),
                    self.threadless_process.pid
                )
            finally:
                # The receiver gets its own handle via send_handle. Close ours
                # even if the hand-off raised, so the descriptor can't leak.
                conn.close()
        else:
            work = self.work_klass(
                fileno=conn.fileno(),
                addr=addr,
                flags=self.flags,
                event_queue=self.event_queue
            )
            work_thread = threading.Thread(target=work.run)
            # Daemon threads don't keep the process alive on exit.
            work_thread.daemon = True
            work.publish_event(
                event_name=eventNames.WORK_STARTED,
                event_payload={'fileno': conn.fileno(), 'addr': addr},
                publisher_id=self.__class__.__name__
            )
            work_thread.start()

    def run_once(self) -> None:
        """Wait up to 1s for a pending connection and dispatch it if any."""
        assert self.selector and self.sock
        with self.lock:
            events = self.selector.select(timeout=1)
            if not events:
                return
            try:
                conn, addr = self.sock.accept()
            except (BlockingIOError, ConnectionAbortedError):
                # The listening socket is non-blocking, so the pending
                # connection can be gone by the time accept() runs (e.g.
                # taken by another process, or aborted by the client). This
                # is transient and must not terminate the acceptor; the
                # caller simply loops again.
                return
        self.start_work(conn, addr)

    def run(self) -> None:
        """Process entry point.

        Receive the server socket handle from the pool, then accept clients
        until ``running`` is set (or on KeyboardInterrupt). Finally, release
        the selector, the optional Threadless process and the socket.
        """
        self.selector = selectors.DefaultSelector()
        fileno = recv_handle(self.work_queue)
        self.work_queue.close()
        # Wrap the received descriptor directly. socket.fromfd() would dup()
        # it and leave the received descriptor open for the process lifetime.
        # proto=0 matches what socket.fromfd() passed.
        self.sock = socket.socket(
            self.flags.family,
            socket.SOCK_STREAM,
            0,
            fileno
        )
        try:
            self.selector.register(self.sock, selectors.EVENT_READ)
            if self.flags.threadless:
                self.start_threadless_process()
            while not self.running.is_set():
                self.run_once()
        except KeyboardInterrupt:
            pass
        finally:
            self.selector.unregister(self.sock)
            if self.flags.threadless:
                self.shutdown_threadless_process()
            self.sock.close()
            # The selector owns an OS resource too (e.g. an epoll fd).
            self.selector.close()
            logger.debug('Acceptor#%d shutdown', self.idd)
|