From 4af0c2f601e1bd5b312976bd47c37f60b64e0e3c Mon Sep 17 00:00:00 2001 From: Abhinav Singh <126065+abhinavsingh@users.noreply.github.com> Date: Tue, 25 Jan 2022 01:51:41 +0530 Subject: [PATCH] [Task] A generic payload based work abstraction (#1057) * Refactor into an internal task submodule of work * As context managers * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add missing license Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- README.md | 2 + examples/task.py | 128 +++++++++---------------------- proxy/core/base/tcp_upstream.py | 17 ++-- proxy/core/work/task/__init__.py | 24 ++++++ proxy/core/work/task/handler.py | 25 ++++++ proxy/core/work/task/local.py | 50 ++++++++++++ proxy/core/work/task/remote.py | 48 ++++++++++++ proxy/core/work/task/task.py | 18 +++++ 8 files changed, 213 insertions(+), 99 deletions(-) create mode 100644 proxy/core/work/task/__init__.py create mode 100644 proxy/core/work/task/handler.py create mode 100644 proxy/core/work/task/local.py create mode 100644 proxy/core/work/task/remote.py create mode 100644 proxy/core/work/task/task.py diff --git a/README.md b/README.md index fcd25b18..bcebb0b8 100644 --- a/README.md +++ b/README.md @@ -710,6 +710,8 @@ Start `proxy.py` as: --plugins proxy.plugin.CacheResponsesPlugin ``` +You may also use the `--cache-requests` flag to enable request packet caching for inspection. + Verify using `curl -v -x localhost:8899 http://httpbin.org/get`: ```console diff --git a/examples/task.py b/examples/task.py index 67441d82..91572255 100644 --- a/examples/task.py +++ b/examples/task.py @@ -8,116 +8,64 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. """ -import time +import sys import argparse -import threading -import multiprocessing -from typing import Any -from proxy.core.work import ( - Work, ThreadlessPool, BaseLocalExecutor, BaseRemoteExecutor, -) +from proxy.core.work import ThreadlessPool from proxy.common.flag import FlagParser -from proxy.common.backports import NonBlockingQueue +from proxy.core.work.task import ( + RemoteTaskExecutor, ThreadedTaskExecutor, SingleProcessTaskExecutor, +) -class Task: - """This will be our work object.""" - - def __init__(self, payload: bytes) -> None: - self.payload = payload - print(payload) - - -class TaskWork(Work[Task]): - """This will be our handler class, created for each received work.""" - - @staticmethod - def create(*args: Any) -> Task: - """Work core doesn't know how to create work objects for us, so - we must provide an implementation of create method here.""" - return Task(*args) - - -class LocalTaskExecutor(BaseLocalExecutor): - """We'll define a local executor which is capable of receiving - log lines over a non blocking queue.""" - - def work(self, *args: Any) -> None: - task_id = int(time.time()) - uid = '%s-%s' % (self.iid, task_id) - self.works[task_id] = self.create(uid, *args) - - -class RemoteTaskExecutor(BaseRemoteExecutor): - - def work(self, *args: Any) -> None: - task_id = int(time.time()) - uid = '%s-%s' % (self.iid, task_id) - self.works[task_id] = self.create(uid, *args) - - -def start_local(flags: argparse.Namespace) -> None: - work_queue = NonBlockingQueue() - executor = LocalTaskExecutor(iid=1, work_queue=work_queue, flags=flags) - - t = threading.Thread(target=executor.run) - t.daemon = True - t.start() - - try: +def start_local_thread(flags: argparse.Namespace) -> None: + with ThreadedTaskExecutor(flags=flags) as thread: i = 0 while True: - work_queue.put(('%d' % i).encode('utf-8')) + thread.executor.work_queue.put(('%d' % i).encode('utf-8')) i += 1 - except KeyboardInterrupt: - pass - finally: - executor.running.set() - t.join() -def start_remote(flags: argparse.Namespace) -> None: - pipe = multiprocessing.Pipe() - work_queue = pipe[0] - executor = RemoteTaskExecutor(iid=1, work_queue=pipe[1], flags=flags) - - p = multiprocessing.Process(target=executor.run) - p.daemon = True - p.start() - - try: +def start_remote_process(flags: argparse.Namespace) -> None: + with SingleProcessTaskExecutor(flags=flags) as process: i = 0 while True: - work_queue.send(('%d' % i).encode('utf-8')) + process.work_queue.send(('%d' % i).encode('utf-8')) i += 1 - except KeyboardInterrupt: - pass - finally: - executor.running.set() - p.join() def start_remote_pool(flags: argparse.Namespace) -> None: with ThreadlessPool(flags=flags, executor_klass=RemoteTaskExecutor) as pool: - try: - i = 0 - while True: - work_queue = pool.work_queues[i % flags.num_workers] - work_queue.send(('%d' % i).encode('utf-8')) - i += 1 - except KeyboardInterrupt: - pass + i = 0 + while True: + work_queue = pool.work_queues[i % flags.num_workers] + work_queue.send(('%d' % i).encode('utf-8')) + i += 1 + + +def main() -> None: + try: + flags = FlagParser.initialize( + sys.argv[2:] + ['--disable-http-proxy'], + work_klass='proxy.core.work.task.TaskHandler', + ) + globals()['start_%s' % sys.argv[1]](flags) + except KeyboardInterrupt: + pass # TODO: TaskWork, LocalTaskExecutor, RemoteTaskExecutor # should not be needed, abstract those pieces out in the core # for stateless tasks. if __name__ == '__main__': - flags = FlagParser.initialize( - ['--disable-http-proxy'], - work_klass=TaskWork, - ) - start_remote_pool(flags) - # start_remote(flags) - # start_local(flags) + if len(sys.argv) < 2: + print( + '\n'.join([ + 'Usage:', + ' %s ' % sys.argv[0], + ' execution-mode can be one of the following:', + ' "remote_pool", "remote_process", "local_thread"', + ]), + ) + sys.exit(1) + main() diff --git a/proxy/core/base/tcp_upstream.py b/proxy/core/base/tcp_upstream.py index f045f1f9..605cd75e 100644 --- a/proxy/core/base/tcp_upstream.py +++ b/proxy/core/base/tcp_upstream.py @@ -75,19 +75,18 @@ class TcpUpstreamConnectionHandler(ABC): self.upstream.connection.fileno() in r: try: raw = self.upstream.recv(self.server_recvbuf_size) - if raw is not None: - self.total_size += len(raw) - self.handle_upstream_data(raw) - else: + if raw is None: # pragma: no cover # Tear down because upstream proxy closed the connection return True - except TimeoutError: + self.total_size += len(raw) + self.handle_upstream_data(raw) + except TimeoutError: # pragma: no cover logger.info('Upstream recv timeout error') return True - except ssl.SSLWantReadError: + except ssl.SSLWantReadError: # pragma: no cover logger.info('Upstream SSLWantReadError, will retry') return False - except ConnectionResetError: + except ConnectionResetError: # pragma: no cover logger.debug('Connection reset by upstream') return True return False @@ -98,10 +97,10 @@ class TcpUpstreamConnectionHandler(ABC): self.upstream.has_buffer(): try: self.upstream.flush() - except ssl.SSLWantWriteError: + except ssl.SSLWantWriteError: # pragma: no cover logger.info('Upstream SSLWantWriteError, will retry') return False - except BrokenPipeError: + except BrokenPipeError: # pragma: no cover logger.debug('BrokenPipeError when flushing to upstream') return True return False diff --git a/proxy/core/work/task/__init__.py b/proxy/core/work/task/__init__.py new file mode 100644 index 00000000..157ae566 --- /dev/null +++ b/proxy/core/work/task/__init__.py @@ -0,0 +1,24 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. +""" +from .task import Task +from .local import LocalTaskExecutor, ThreadedTaskExecutor +from .remote import RemoteTaskExecutor, SingleProcessTaskExecutor +from .handler import TaskHandler + + +__all__ = [ + 'Task', + 'TaskHandler', + 'LocalTaskExecutor', + 'ThreadedTaskExecutor', + 'RemoteTaskExecutor', + 'SingleProcessTaskExecutor', +] diff --git a/proxy/core/work/task/handler.py b/proxy/core/work/task/handler.py new file mode 100644 index 00000000..5fd78e38 --- /dev/null +++ b/proxy/core/work/task/handler.py @@ -0,0 +1,25 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. +""" +from typing import Any + +from .task import Task +from ..work import Work + + +class TaskHandler(Work[Task]): + """Task handler.""" + + @staticmethod + def create(*args: Any) -> Task: + """Work core doesn't know how to create work objects for us. + Example, for task module scenario, it doesn't know how to create + Task objects for us.""" + return Task(*args) diff --git a/proxy/core/work/task/local.py b/proxy/core/work/task/local.py new file mode 100644 index 00000000..a2642b23 --- /dev/null +++ b/proxy/core/work/task/local.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. +""" +import time +import uuid +import threading +from typing import Any + +from ..local import BaseLocalExecutor +from ....common.backports import NonBlockingQueue + + +class LocalTaskExecutor(BaseLocalExecutor): + """We'll define a local executor which is capable of receiving + log lines over a non blocking queue.""" + + def work(self, *args: Any) -> None: + task_id = int(time.time()) + uid = '%s-%s' % (self.iid, task_id) + self.works[task_id] = self.create(uid, *args) + + +class ThreadedTaskExecutor(threading.Thread): + + def __init__(self, **kwargs: Any) -> None: + super().__init__() + self.daemon = True + self.executor = LocalTaskExecutor( + iid=uuid.uuid4().hex, + work_queue=NonBlockingQueue(), + **kwargs, + ) + + def __enter__(self) -> 'ThreadedTaskExecutor': + self.start() + return self + + def __exit__(self, *args: Any) -> None: + self.executor.running.set() + self.join() + + def run(self) -> None: + self.executor.run() diff --git a/proxy/core/work/task/remote.py b/proxy/core/work/task/remote.py new file mode 100644 index 00000000..ce4b0009 --- /dev/null +++ b/proxy/core/work/task/remote.py @@ -0,0 +1,48 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. +""" +import time +import uuid +import multiprocessing +from typing import Any + +from ..remote import BaseRemoteExecutor + + +class RemoteTaskExecutor(BaseRemoteExecutor): + + def work(self, *args: Any) -> None: + task_id = int(time.time()) + uid = '%s-%s' % (self.iid, task_id) + self.works[task_id] = self.create(uid, *args) + + +class SingleProcessTaskExecutor(multiprocessing.Process): + + def __init__(self, **kwargs: Any) -> None: + super().__init__() + self.daemon = True + self.work_queue, remote = multiprocessing.Pipe() + self.executor = RemoteTaskExecutor( + iid=uuid.uuid4().hex, + work_queue=remote, + **kwargs, + ) + + def __enter__(self) -> 'SingleProcessTaskExecutor': + self.start() + return self + + def __exit__(self, *args: Any) -> None: + self.executor.running.set() + self.join() + + def run(self) -> None: + self.executor.run() diff --git a/proxy/core/work/task/task.py b/proxy/core/work/task/task.py new file mode 100644 index 00000000..f4467ef2 --- /dev/null +++ b/proxy/core/work/task/task.py @@ -0,0 +1,18 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. +""" + + +class Task: + """Task object which known how to process the payload.""" + + def __init__(self, payload: bytes) -> None: + self.payload = payload + print(payload)