[Task] A generic payload based work abstraction (#1057)

* Refactor into an internal task submodule of work

* As context managers

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add missing license

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Abhinav Singh 2022-01-25 01:51:41 +05:30 committed by GitHub
parent d616fc87db
commit 4af0c2f601
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 213 additions and 99 deletions

View File

@ -710,6 +710,8 @@ Start `proxy.py` as:
--plugins proxy.plugin.CacheResponsesPlugin
```
You may also use the `--cache-requests` flag to enable request packet caching for inspection.
Verify using `curl -v -x localhost:8899 http://httpbin.org/get`:
```console

View File

@ -8,116 +8,64 @@
:copyright: (c) 2013-present by Abhinav Singh and contributors.
:license: BSD, see LICENSE for more details.
"""
import time
import sys
import argparse
import threading
import multiprocessing
from typing import Any
from proxy.core.work import (
Work, ThreadlessPool, BaseLocalExecutor, BaseRemoteExecutor,
)
from proxy.core.work import ThreadlessPool
from proxy.common.flag import FlagParser
from proxy.common.backports import NonBlockingQueue
from proxy.core.work.task import (
RemoteTaskExecutor, ThreadedTaskExecutor, SingleProcessTaskExecutor,
)
class Task:
"""This will be our work object."""
def __init__(self, payload: bytes) -> None:
self.payload = payload
print(payload)
class TaskWork(Work[Task]):
"""This will be our handler class, created for each received work."""
@staticmethod
def create(*args: Any) -> Task:
"""Work core doesn't know how to create work objects for us, so
we must provide an implementation of create method here."""
return Task(*args)
class LocalTaskExecutor(BaseLocalExecutor):
"""We'll define a local executor which is capable of receiving
log lines over a non blocking queue."""
def work(self, *args: Any) -> None:
task_id = int(time.time())
uid = '%s-%s' % (self.iid, task_id)
self.works[task_id] = self.create(uid, *args)
class RemoteTaskExecutor(BaseRemoteExecutor):
def work(self, *args: Any) -> None:
task_id = int(time.time())
uid = '%s-%s' % (self.iid, task_id)
self.works[task_id] = self.create(uid, *args)
def start_local(flags: argparse.Namespace) -> None:
work_queue = NonBlockingQueue()
executor = LocalTaskExecutor(iid=1, work_queue=work_queue, flags=flags)
t = threading.Thread(target=executor.run)
t.daemon = True
t.start()
try:
def start_local_thread(flags: argparse.Namespace) -> None:
with ThreadedTaskExecutor(flags=flags) as thread:
i = 0
while True:
work_queue.put(('%d' % i).encode('utf-8'))
thread.executor.work_queue.put(('%d' % i).encode('utf-8'))
i += 1
except KeyboardInterrupt:
pass
finally:
executor.running.set()
t.join()
def start_remote(flags: argparse.Namespace) -> None:
pipe = multiprocessing.Pipe()
work_queue = pipe[0]
executor = RemoteTaskExecutor(iid=1, work_queue=pipe[1], flags=flags)
p = multiprocessing.Process(target=executor.run)
p.daemon = True
p.start()
try:
def start_remote_process(flags: argparse.Namespace) -> None:
with SingleProcessTaskExecutor(flags=flags) as process:
i = 0
while True:
work_queue.send(('%d' % i).encode('utf-8'))
process.work_queue.send(('%d' % i).encode('utf-8'))
i += 1
except KeyboardInterrupt:
pass
finally:
executor.running.set()
p.join()
def start_remote_pool(flags: argparse.Namespace) -> None:
with ThreadlessPool(flags=flags, executor_klass=RemoteTaskExecutor) as pool:
try:
i = 0
while True:
work_queue = pool.work_queues[i % flags.num_workers]
work_queue.send(('%d' % i).encode('utf-8'))
i += 1
except KeyboardInterrupt:
pass
i = 0
while True:
work_queue = pool.work_queues[i % flags.num_workers]
work_queue.send(('%d' % i).encode('utf-8'))
i += 1
def main() -> None:
try:
flags = FlagParser.initialize(
sys.argv[2:] + ['--disable-http-proxy'],
work_klass='proxy.core.work.task.TaskHandler',
)
globals()['start_%s' % sys.argv[1]](flags)
except KeyboardInterrupt:
pass
# TODO: TaskWork, LocalTaskExecutor, RemoteTaskExecutor
# should not be needed, abstract those pieces out in the core
# for stateless tasks.
if __name__ == '__main__':
flags = FlagParser.initialize(
['--disable-http-proxy'],
work_klass=TaskWork,
)
start_remote_pool(flags)
# start_remote(flags)
# start_local(flags)
if len(sys.argv) < 2:
print(
'\n'.join([
'Usage:',
' %s <execution-mode>' % sys.argv[0],
' execution-mode can be one of the following:',
' "remote_pool", "remote_process", "local_thread"',
]),
)
sys.exit(1)
main()

View File

@ -75,19 +75,18 @@ class TcpUpstreamConnectionHandler(ABC):
self.upstream.connection.fileno() in r:
try:
raw = self.upstream.recv(self.server_recvbuf_size)
if raw is not None:
self.total_size += len(raw)
self.handle_upstream_data(raw)
else:
if raw is None: # pragma: no cover
# Tear down because upstream proxy closed the connection
return True
except TimeoutError:
self.total_size += len(raw)
self.handle_upstream_data(raw)
except TimeoutError: # pragma: no cover
logger.info('Upstream recv timeout error')
return True
except ssl.SSLWantReadError:
except ssl.SSLWantReadError: # pragma: no cover
logger.info('Upstream SSLWantReadError, will retry')
return False
except ConnectionResetError:
except ConnectionResetError: # pragma: no cover
logger.debug('Connection reset by upstream')
return True
return False
@ -98,10 +97,10 @@ class TcpUpstreamConnectionHandler(ABC):
self.upstream.has_buffer():
try:
self.upstream.flush()
except ssl.SSLWantWriteError:
except ssl.SSLWantWriteError: # pragma: no cover
logger.info('Upstream SSLWantWriteError, will retry')
return False
except BrokenPipeError:
except BrokenPipeError: # pragma: no cover
logger.debug('BrokenPipeError when flushing to upstream')
return True
return False

View File

@ -0,0 +1,24 @@
# -*- coding: utf-8 -*-
"""
proxy.py
~~~~~~~~
Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
Network monitoring, controls & Application development, testing, debugging.
:copyright: (c) 2013-present by Abhinav Singh and contributors.
:license: BSD, see LICENSE for more details.
"""
from .task import Task
from .local import LocalTaskExecutor, ThreadedTaskExecutor
from .remote import RemoteTaskExecutor, SingleProcessTaskExecutor
from .handler import TaskHandler
__all__ = [
'Task',
'TaskHandler',
'LocalTaskExecutor',
'ThreadedTaskExecutor',
'RemoteTaskExecutor',
'SingleProcessTaskExecutor',
]

View File

@ -0,0 +1,25 @@
# -*- coding: utf-8 -*-
"""
proxy.py
~~~~~~~~
Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
Network monitoring, controls & Application development, testing, debugging.
:copyright: (c) 2013-present by Abhinav Singh and contributors.
:license: BSD, see LICENSE for more details.
"""
from typing import Any
from .task import Task
from ..work import Work
class TaskHandler(Work[Task]):
"""Task handler."""
@staticmethod
def create(*args: Any) -> Task:
"""Work core doesn't know how to create work objects for us.
Example, for task module scenario, it doesn't know how to create
Task objects for us."""
return Task(*args)

View File

@ -0,0 +1,50 @@
# -*- coding: utf-8 -*-
"""
proxy.py
~~~~~~~~
Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
Network monitoring, controls & Application development, testing, debugging.
:copyright: (c) 2013-present by Abhinav Singh and contributors.
:license: BSD, see LICENSE for more details.
"""
import time
import uuid
import threading
from typing import Any
from ..local import BaseLocalExecutor
from ....common.backports import NonBlockingQueue
class LocalTaskExecutor(BaseLocalExecutor):
"""We'll define a local executor which is capable of receiving
log lines over a non blocking queue."""
def work(self, *args: Any) -> None:
task_id = int(time.time())
uid = '%s-%s' % (self.iid, task_id)
self.works[task_id] = self.create(uid, *args)
class ThreadedTaskExecutor(threading.Thread):
def __init__(self, **kwargs: Any) -> None:
super().__init__()
self.daemon = True
self.executor = LocalTaskExecutor(
iid=uuid.uuid4().hex,
work_queue=NonBlockingQueue(),
**kwargs,
)
def __enter__(self) -> 'ThreadedTaskExecutor':
self.start()
return self
def __exit__(self, *args: Any) -> None:
self.executor.running.set()
self.join()
def run(self) -> None:
self.executor.run()

View File

@ -0,0 +1,48 @@
# -*- coding: utf-8 -*-
"""
proxy.py
~~~~~~~~
Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
Network monitoring, controls & Application development, testing, debugging.
:copyright: (c) 2013-present by Abhinav Singh and contributors.
:license: BSD, see LICENSE for more details.
"""
import time
import uuid
import multiprocessing
from typing import Any
from ..remote import BaseRemoteExecutor
class RemoteTaskExecutor(BaseRemoteExecutor):
def work(self, *args: Any) -> None:
task_id = int(time.time())
uid = '%s-%s' % (self.iid, task_id)
self.works[task_id] = self.create(uid, *args)
class SingleProcessTaskExecutor(multiprocessing.Process):
def __init__(self, **kwargs: Any) -> None:
super().__init__()
self.daemon = True
self.work_queue, remote = multiprocessing.Pipe()
self.executor = RemoteTaskExecutor(
iid=uuid.uuid4().hex,
work_queue=remote,
**kwargs,
)
def __enter__(self) -> 'SingleProcessTaskExecutor':
self.start()
return self
def __exit__(self, *args: Any) -> None:
self.executor.running.set()
self.join()
def run(self) -> None:
self.executor.run()

View File

@ -0,0 +1,18 @@
# -*- coding: utf-8 -*-
"""
proxy.py
~~~~~~~~
Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
Network monitoring, controls & Application development, testing, debugging.
:copyright: (c) 2013-present by Abhinav Singh and contributors.
:license: BSD, see LICENSE for more details.
"""
class Task:
"""Task object which known how to process the payload."""
def __init__(self, payload: bytes) -> None:
self.payload = payload
print(payload)