[Task] A generic payload based work abstraction (#1057)
* Refactor into an internal task submodule of work * As context managers * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add missing license Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
d616fc87db
commit
4af0c2f601
|
@ -710,6 +710,8 @@ Start `proxy.py` as:
|
|||
--plugins proxy.plugin.CacheResponsesPlugin
|
||||
```
|
||||
|
||||
You may also use the `--cache-requests` flag to enable request packet caching for inspection.
|
||||
|
||||
Verify using `curl -v -x localhost:8899 http://httpbin.org/get`:
|
||||
|
||||
```console
|
||||
|
|
128
examples/task.py
128
examples/task.py
|
@ -8,116 +8,64 @@
|
|||
:copyright: (c) 2013-present by Abhinav Singh and contributors.
|
||||
:license: BSD, see LICENSE for more details.
|
||||
"""
|
||||
import time
|
||||
import sys
|
||||
import argparse
|
||||
import threading
|
||||
import multiprocessing
|
||||
from typing import Any
|
||||
|
||||
from proxy.core.work import (
|
||||
Work, ThreadlessPool, BaseLocalExecutor, BaseRemoteExecutor,
|
||||
)
|
||||
from proxy.core.work import ThreadlessPool
|
||||
from proxy.common.flag import FlagParser
|
||||
from proxy.common.backports import NonBlockingQueue
|
||||
from proxy.core.work.task import (
|
||||
RemoteTaskExecutor, ThreadedTaskExecutor, SingleProcessTaskExecutor,
|
||||
)
|
||||
|
||||
|
||||
class Task:
|
||||
"""This will be our work object."""
|
||||
|
||||
def __init__(self, payload: bytes) -> None:
|
||||
self.payload = payload
|
||||
print(payload)
|
||||
|
||||
|
||||
class TaskWork(Work[Task]):
|
||||
"""This will be our handler class, created for each received work."""
|
||||
|
||||
@staticmethod
|
||||
def create(*args: Any) -> Task:
|
||||
"""Work core doesn't know how to create work objects for us, so
|
||||
we must provide an implementation of create method here."""
|
||||
return Task(*args)
|
||||
|
||||
|
||||
class LocalTaskExecutor(BaseLocalExecutor):
|
||||
"""We'll define a local executor which is capable of receiving
|
||||
log lines over a non blocking queue."""
|
||||
|
||||
def work(self, *args: Any) -> None:
|
||||
task_id = int(time.time())
|
||||
uid = '%s-%s' % (self.iid, task_id)
|
||||
self.works[task_id] = self.create(uid, *args)
|
||||
|
||||
|
||||
class RemoteTaskExecutor(BaseRemoteExecutor):
|
||||
|
||||
def work(self, *args: Any) -> None:
|
||||
task_id = int(time.time())
|
||||
uid = '%s-%s' % (self.iid, task_id)
|
||||
self.works[task_id] = self.create(uid, *args)
|
||||
|
||||
|
||||
def start_local(flags: argparse.Namespace) -> None:
|
||||
work_queue = NonBlockingQueue()
|
||||
executor = LocalTaskExecutor(iid=1, work_queue=work_queue, flags=flags)
|
||||
|
||||
t = threading.Thread(target=executor.run)
|
||||
t.daemon = True
|
||||
t.start()
|
||||
|
||||
try:
|
||||
def start_local_thread(flags: argparse.Namespace) -> None:
|
||||
with ThreadedTaskExecutor(flags=flags) as thread:
|
||||
i = 0
|
||||
while True:
|
||||
work_queue.put(('%d' % i).encode('utf-8'))
|
||||
thread.executor.work_queue.put(('%d' % i).encode('utf-8'))
|
||||
i += 1
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
executor.running.set()
|
||||
t.join()
|
||||
|
||||
|
||||
def start_remote(flags: argparse.Namespace) -> None:
|
||||
pipe = multiprocessing.Pipe()
|
||||
work_queue = pipe[0]
|
||||
executor = RemoteTaskExecutor(iid=1, work_queue=pipe[1], flags=flags)
|
||||
|
||||
p = multiprocessing.Process(target=executor.run)
|
||||
p.daemon = True
|
||||
p.start()
|
||||
|
||||
try:
|
||||
def start_remote_process(flags: argparse.Namespace) -> None:
|
||||
with SingleProcessTaskExecutor(flags=flags) as process:
|
||||
i = 0
|
||||
while True:
|
||||
work_queue.send(('%d' % i).encode('utf-8'))
|
||||
process.work_queue.send(('%d' % i).encode('utf-8'))
|
||||
i += 1
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
executor.running.set()
|
||||
p.join()
|
||||
|
||||
|
||||
def start_remote_pool(flags: argparse.Namespace) -> None:
|
||||
with ThreadlessPool(flags=flags, executor_klass=RemoteTaskExecutor) as pool:
|
||||
try:
|
||||
i = 0
|
||||
while True:
|
||||
work_queue = pool.work_queues[i % flags.num_workers]
|
||||
work_queue.send(('%d' % i).encode('utf-8'))
|
||||
i += 1
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
i = 0
|
||||
while True:
|
||||
work_queue = pool.work_queues[i % flags.num_workers]
|
||||
work_queue.send(('%d' % i).encode('utf-8'))
|
||||
i += 1
|
||||
|
||||
|
||||
def main() -> None:
|
||||
try:
|
||||
flags = FlagParser.initialize(
|
||||
sys.argv[2:] + ['--disable-http-proxy'],
|
||||
work_klass='proxy.core.work.task.TaskHandler',
|
||||
)
|
||||
globals()['start_%s' % sys.argv[1]](flags)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
|
||||
# TODO: TaskWork, LocalTaskExecutor, RemoteTaskExecutor
|
||||
# should not be needed, abstract those pieces out in the core
|
||||
# for stateless tasks.
|
||||
if __name__ == '__main__':
|
||||
flags = FlagParser.initialize(
|
||||
['--disable-http-proxy'],
|
||||
work_klass=TaskWork,
|
||||
)
|
||||
start_remote_pool(flags)
|
||||
# start_remote(flags)
|
||||
# start_local(flags)
|
||||
if len(sys.argv) < 2:
|
||||
print(
|
||||
'\n'.join([
|
||||
'Usage:',
|
||||
' %s <execution-mode>' % sys.argv[0],
|
||||
' execution-mode can be one of the following:',
|
||||
' "remote_pool", "remote_process", "local_thread"',
|
||||
]),
|
||||
)
|
||||
sys.exit(1)
|
||||
main()
|
||||
|
|
|
@ -75,19 +75,18 @@ class TcpUpstreamConnectionHandler(ABC):
|
|||
self.upstream.connection.fileno() in r:
|
||||
try:
|
||||
raw = self.upstream.recv(self.server_recvbuf_size)
|
||||
if raw is not None:
|
||||
self.total_size += len(raw)
|
||||
self.handle_upstream_data(raw)
|
||||
else:
|
||||
if raw is None: # pragma: no cover
|
||||
# Tear down because upstream proxy closed the connection
|
||||
return True
|
||||
except TimeoutError:
|
||||
self.total_size += len(raw)
|
||||
self.handle_upstream_data(raw)
|
||||
except TimeoutError: # pragma: no cover
|
||||
logger.info('Upstream recv timeout error')
|
||||
return True
|
||||
except ssl.SSLWantReadError:
|
||||
except ssl.SSLWantReadError: # pragma: no cover
|
||||
logger.info('Upstream SSLWantReadError, will retry')
|
||||
return False
|
||||
except ConnectionResetError:
|
||||
except ConnectionResetError: # pragma: no cover
|
||||
logger.debug('Connection reset by upstream')
|
||||
return True
|
||||
return False
|
||||
|
@ -98,10 +97,10 @@ class TcpUpstreamConnectionHandler(ABC):
|
|||
self.upstream.has_buffer():
|
||||
try:
|
||||
self.upstream.flush()
|
||||
except ssl.SSLWantWriteError:
|
||||
except ssl.SSLWantWriteError: # pragma: no cover
|
||||
logger.info('Upstream SSLWantWriteError, will retry')
|
||||
return False
|
||||
except BrokenPipeError:
|
||||
except BrokenPipeError: # pragma: no cover
|
||||
logger.debug('BrokenPipeError when flushing to upstream')
|
||||
return True
|
||||
return False
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
proxy.py
|
||||
~~~~~~~~
|
||||
⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
|
||||
Network monitoring, controls & Application development, testing, debugging.
|
||||
|
||||
:copyright: (c) 2013-present by Abhinav Singh and contributors.
|
||||
:license: BSD, see LICENSE for more details.
|
||||
"""
|
||||
from .task import Task
|
||||
from .local import LocalTaskExecutor, ThreadedTaskExecutor
|
||||
from .remote import RemoteTaskExecutor, SingleProcessTaskExecutor
|
||||
from .handler import TaskHandler
|
||||
|
||||
|
||||
__all__ = [
|
||||
'Task',
|
||||
'TaskHandler',
|
||||
'LocalTaskExecutor',
|
||||
'ThreadedTaskExecutor',
|
||||
'RemoteTaskExecutor',
|
||||
'SingleProcessTaskExecutor',
|
||||
]
|
|
@ -0,0 +1,25 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
proxy.py
|
||||
~~~~~~~~
|
||||
⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
|
||||
Network monitoring, controls & Application development, testing, debugging.
|
||||
|
||||
:copyright: (c) 2013-present by Abhinav Singh and contributors.
|
||||
:license: BSD, see LICENSE for more details.
|
||||
"""
|
||||
from typing import Any
|
||||
|
||||
from .task import Task
|
||||
from ..work import Work
|
||||
|
||||
|
||||
class TaskHandler(Work[Task]):
|
||||
"""Task handler."""
|
||||
|
||||
@staticmethod
|
||||
def create(*args: Any) -> Task:
|
||||
"""Work core doesn't know how to create work objects for us.
|
||||
Example, for task module scenario, it doesn't know how to create
|
||||
Task objects for us."""
|
||||
return Task(*args)
|
|
@ -0,0 +1,50 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
proxy.py
|
||||
~~~~~~~~
|
||||
⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
|
||||
Network monitoring, controls & Application development, testing, debugging.
|
||||
|
||||
:copyright: (c) 2013-present by Abhinav Singh and contributors.
|
||||
:license: BSD, see LICENSE for more details.
|
||||
"""
|
||||
import time
|
||||
import uuid
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
from ..local import BaseLocalExecutor
|
||||
from ....common.backports import NonBlockingQueue
|
||||
|
||||
|
||||
class LocalTaskExecutor(BaseLocalExecutor):
|
||||
"""We'll define a local executor which is capable of receiving
|
||||
log lines over a non blocking queue."""
|
||||
|
||||
def work(self, *args: Any) -> None:
|
||||
task_id = int(time.time())
|
||||
uid = '%s-%s' % (self.iid, task_id)
|
||||
self.works[task_id] = self.create(uid, *args)
|
||||
|
||||
|
||||
class ThreadedTaskExecutor(threading.Thread):
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
super().__init__()
|
||||
self.daemon = True
|
||||
self.executor = LocalTaskExecutor(
|
||||
iid=uuid.uuid4().hex,
|
||||
work_queue=NonBlockingQueue(),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def __enter__(self) -> 'ThreadedTaskExecutor':
|
||||
self.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, *args: Any) -> None:
|
||||
self.executor.running.set()
|
||||
self.join()
|
||||
|
||||
def run(self) -> None:
|
||||
self.executor.run()
|
|
@ -0,0 +1,48 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
proxy.py
|
||||
~~~~~~~~
|
||||
⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
|
||||
Network monitoring, controls & Application development, testing, debugging.
|
||||
|
||||
:copyright: (c) 2013-present by Abhinav Singh and contributors.
|
||||
:license: BSD, see LICENSE for more details.
|
||||
"""
|
||||
import time
|
||||
import uuid
|
||||
import multiprocessing
|
||||
from typing import Any
|
||||
|
||||
from ..remote import BaseRemoteExecutor
|
||||
|
||||
|
||||
class RemoteTaskExecutor(BaseRemoteExecutor):
|
||||
|
||||
def work(self, *args: Any) -> None:
|
||||
task_id = int(time.time())
|
||||
uid = '%s-%s' % (self.iid, task_id)
|
||||
self.works[task_id] = self.create(uid, *args)
|
||||
|
||||
|
||||
class SingleProcessTaskExecutor(multiprocessing.Process):
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
super().__init__()
|
||||
self.daemon = True
|
||||
self.work_queue, remote = multiprocessing.Pipe()
|
||||
self.executor = RemoteTaskExecutor(
|
||||
iid=uuid.uuid4().hex,
|
||||
work_queue=remote,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def __enter__(self) -> 'SingleProcessTaskExecutor':
|
||||
self.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, *args: Any) -> None:
|
||||
self.executor.running.set()
|
||||
self.join()
|
||||
|
||||
def run(self) -> None:
|
||||
self.executor.run()
|
|
@ -0,0 +1,18 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
proxy.py
|
||||
~~~~~~~~
|
||||
⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
|
||||
Network monitoring, controls & Application development, testing, debugging.
|
||||
|
||||
:copyright: (c) 2013-present by Abhinav Singh and contributors.
|
||||
:license: BSD, see LICENSE for more details.
|
||||
"""
|
||||
|
||||
|
||||
class Task:
|
||||
"""Task object which known how to process the payload."""
|
||||
|
||||
def __init__(self, payload: bytes) -> None:
|
||||
self.payload = payload
|
||||
print(payload)
|
Loading…
Reference in New Issue