Make HTTP handler constructor free of socket file number (#219)
* Refactor into acceptor module * Add tunnel doc * Make fileno free * Autopep8
This commit is contained in:
parent
cc7e4f5cbf
commit
269484df2e
91
README.md
91
README.md
|
@ -3,9 +3,9 @@
|
|||
[![License](https://img.shields.io/github/license/abhinavsingh/proxy.py.svg)](https://opensource.org/licenses/BSD-3-Clause)
|
||||
[![Build Status](https://travis-ci.org/abhinavsingh/proxy.py.svg?branch=develop)](https://travis-ci.org/abhinavsingh/proxy.py/)
|
||||
[![No Dependencies](https://img.shields.io/static/v1?label=dependencies&message=none&color=green)](https://github.com/abhinavsingh/proxy.py)
|
||||
[![Coverage](https://codecov.io/gh/abhinavsingh/proxy.py/branch/develop/graph/badge.svg)](https://codecov.io/gh/abhinavsingh/proxy.py)
|
||||
[![PyPi Monthly](https://img.shields.io/pypi/dm/proxy.py.svg?color=green)](https://pypi.org/project/proxy.py/)
|
||||
[![Docker Pulls](https://img.shields.io/docker/pulls/abhinavsingh/proxy.py?color=green)](https://hub.docker.com/r/abhinavsingh/proxy.py)
|
||||
[![Coverage](https://codecov.io/gh/abhinavsingh/proxy.py/branch/develop/graph/badge.svg)](https://codecov.io/gh/abhinavsingh/proxy.py)
|
||||
|
||||
[![Tested With MacOS, Ubuntu, Windows, Android, Android Emulator, iOS, iOS Simulator](https://img.shields.io/static/v1?label=tested%20with&message=mac%20OS%20%F0%9F%92%BB%20%7C%20Ubuntu%20%F0%9F%96%A5%20%7C%20Windows%20%F0%9F%92%BB&color=brightgreen)](https://abhinavsingh.com/proxy-py-a-lightweight-single-file-http-proxy-server-in-python/)
|
||||
[![Android, Android Emulator](https://img.shields.io/static/v1?label=tested%20with&message=Android%20%F0%9F%93%B1%20%7C%20Android%20Emulator%20%F0%9F%93%B1&color=brightgreen)](https://abhinavsingh.com/proxy-py-a-lightweight-single-file-http-proxy-server-in-python/)
|
||||
|
@ -58,6 +58,9 @@ Table of Contents
|
|||
* [Plugin Ordering](#plugin-ordering)
|
||||
* [End-to-End Encryption](#end-to-end-encryption)
|
||||
* [TLS Interception](#tls-interception)
|
||||
* [Proxy Over SSH Tunnel](#proxy-over-ssh-tunnel)
|
||||
* [Proxy Remote Requests Locally](#proxy-remote-requests-locally)
|
||||
* [Proxy Local Requests Remotely](#proxy-local-requests-remotely)
|
||||
* [Embed proxy.py](#embed-proxypy)
|
||||
* [Blocking Mode](#blocking-mode)
|
||||
* [Non-blocking Mode](#non-blocking-mode)
|
||||
|
@ -798,6 +801,92 @@ cached file instead of plain text.
|
|||
Now use CA flags with other
|
||||
[plugin examples](#plugin-examples) to see them work with `https` traffic.
|
||||
|
||||
Proxy Over SSH Tunnel
|
||||
=====================
|
||||
|
||||
Requires `paramiko` to work. See [requirements-tunnel.txt](https://github.com/abhinavsingh/proxy.py/blob/develop/requirements-tunnel.txt)
|
||||
|
||||
## Proxy Remote Requests Locally
|
||||
|
||||
|
|
||||
+------------+ | +----------+
|
||||
| LOCAL | | | REMOTE |
|
||||
| HOST | <== SSH ==== :8900 == | SERVER |
|
||||
+------------+ | +----------+
|
||||
:8899 proxy.py |
|
||||
|
|
||||
FIREWALL
|
||||
(allow tcp/22)
|
||||
|
||||
## What
|
||||
|
||||
Proxy HTTP(s) requests made on a `remote` server through `proxy.py` server
|
||||
running on `localhost`.
|
||||
|
||||
### How
|
||||
|
||||
* Requested `remote` port is forwarded over the SSH connection.
|
||||
* `proxy.py` running on the `localhost` handles and responds to
|
||||
`remote` proxy requests.
|
||||
|
||||
### Requirements
|
||||
|
||||
1. `localhost` MUST have SSH access to the `remote` server
|
||||
2. `remote` server MUST be configured to proxy HTTP(s) requests
|
||||
through the forwarded port number e.g. `:8900`.
|
||||
- `remote` and `localhost` ports CAN be same e.g. `:8899`.
|
||||
- `:8900` is chosen in ascii art for differentiation purposes.
|
||||
|
||||
### Try it
|
||||
|
||||
Start `proxy.py` as:
|
||||
|
||||
```
|
||||
$ # On localhost
|
||||
$ proxy --enable-tunnel \
|
||||
--tunnel-username username \
|
||||
--tunnel-hostname ip.address.or.domain.name \
|
||||
--tunnel-port 22 \
|
||||
--tunnel-remote-host 127.0.0.1
|
||||
--tunnel-remote-port 8899
|
||||
```
|
||||
|
||||
Make a HTTP proxy request on `remote` server and
|
||||
verify that response contains public IP address of `localhost` as origin:
|
||||
|
||||
```
|
||||
$ # On remote
|
||||
$ curl -x 127.0.0.1:8899 http://httpbin.org/get
|
||||
{
|
||||
"args": {},
|
||||
"headers": {
|
||||
"Accept": "*/*",
|
||||
"Host": "httpbin.org",
|
||||
"User-Agent": "curl/7.54.0"
|
||||
},
|
||||
"origin": "x.x.x.x, y.y.y.y",
|
||||
"url": "https://httpbin.org/get"
|
||||
}
|
||||
```
|
||||
|
||||
Also, verify that `proxy.py` logs on `localhost` contains `remote` IP as client IP.
|
||||
|
||||
```
|
||||
access_log:328 - remote:52067 - GET httpbin.org:80
|
||||
```
|
||||
|
||||
## Proxy Local Requests Remotely
|
||||
|
||||
|
|
||||
+------------+ | +----------+
|
||||
| LOCAL | | | REMOTE |
|
||||
| HOST | === SSH =====> | SERVER |
|
||||
+------------+ | +----------+
|
||||
| :8899 proxy.py
|
||||
|
|
||||
FIREWALL
|
||||
(allow tcp/22)
|
||||
|
||||
Embed proxy.py
|
||||
==============
|
||||
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
proxy.py
|
||||
~~~~~~~~
|
||||
⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
|
||||
Network monitoring, controls & Application development, testing, debugging.
|
||||
|
||||
:copyright: (c) 2013-present by Abhinav Singh and contributors.
|
||||
:license: BSD, see LICENSE for more details.
|
||||
"""
|
||||
from .acceptor import Acceptor
|
||||
from .pool import AcceptorPool
|
||||
|
||||
__all__ = [
|
||||
'Acceptor',
|
||||
'AcceptorPool',
|
||||
]
|
|
@ -0,0 +1,137 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
proxy.py
|
||||
~~~~~~~~
|
||||
⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
|
||||
Network monitoring, controls & Application development, testing, debugging.
|
||||
|
||||
:copyright: (c) 2013-present by Abhinav Singh and contributors.
|
||||
:license: BSD, see LICENSE for more details.
|
||||
"""
|
||||
import logging
|
||||
import multiprocessing
|
||||
import selectors
|
||||
import socket
|
||||
import threading
|
||||
# import time
|
||||
from multiprocessing import connection
|
||||
from multiprocessing.reduction import send_handle, recv_handle
|
||||
from typing import Optional, Type, Tuple
|
||||
|
||||
from ..connection import TcpClientConnection
|
||||
from ..threadless import ThreadlessWork, Threadless
|
||||
from ..event import EventQueue, eventNames
|
||||
from ...common.flags import Flags
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Acceptor(multiprocessing.Process):
|
||||
"""Socket client acceptor.
|
||||
|
||||
Accepts client connection over received server socket handle and
|
||||
starts a new work thread.
|
||||
"""
|
||||
|
||||
lock = multiprocessing.Lock()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
idd: int,
|
||||
work_queue: connection.Connection,
|
||||
flags: Flags,
|
||||
work_klass: Type[ThreadlessWork],
|
||||
event_queue: Optional[EventQueue] = None) -> None:
|
||||
super().__init__()
|
||||
self.idd = idd
|
||||
self.work_queue: connection.Connection = work_queue
|
||||
self.flags = flags
|
||||
self.work_klass = work_klass
|
||||
self.event_queue = event_queue
|
||||
|
||||
self.running = multiprocessing.Event()
|
||||
self.selector: Optional[selectors.DefaultSelector] = None
|
||||
self.sock: Optional[socket.socket] = None
|
||||
self.threadless_process: Optional[Threadless] = None
|
||||
self.threadless_client_queue: Optional[connection.Connection] = None
|
||||
|
||||
def start_threadless_process(self) -> None:
|
||||
pipe = multiprocessing.Pipe()
|
||||
self.threadless_client_queue = pipe[0]
|
||||
self.threadless_process = Threadless(
|
||||
client_queue=pipe[1],
|
||||
flags=self.flags,
|
||||
work_klass=self.work_klass,
|
||||
event_queue=self.event_queue
|
||||
)
|
||||
self.threadless_process.start()
|
||||
logger.debug('Started process %d', self.threadless_process.pid)
|
||||
|
||||
def shutdown_threadless_process(self) -> None:
|
||||
assert self.threadless_process and self.threadless_client_queue
|
||||
logger.debug('Stopped process %d', self.threadless_process.pid)
|
||||
self.threadless_process.running.set()
|
||||
self.threadless_process.join()
|
||||
self.threadless_client_queue.close()
|
||||
|
||||
def start_work(self, conn: socket.socket, addr: Tuple[str, int]) -> None:
|
||||
if self.flags.threadless and \
|
||||
self.threadless_client_queue and \
|
||||
self.threadless_process:
|
||||
self.threadless_client_queue.send(addr)
|
||||
send_handle(
|
||||
self.threadless_client_queue,
|
||||
conn.fileno(),
|
||||
self.threadless_process.pid
|
||||
)
|
||||
conn.close()
|
||||
else:
|
||||
work = self.work_klass(
|
||||
TcpClientConnection(conn, addr),
|
||||
flags=self.flags,
|
||||
event_queue=self.event_queue
|
||||
)
|
||||
work_thread = threading.Thread(target=work.run)
|
||||
work_thread.daemon = True
|
||||
work.publish_event(
|
||||
event_name=eventNames.WORK_STARTED,
|
||||
event_payload={'fileno': conn.fileno(), 'addr': addr},
|
||||
publisher_id=self.__class__.__name__
|
||||
)
|
||||
work_thread.start()
|
||||
|
||||
def run_once(self) -> None:
|
||||
assert self.selector and self.sock
|
||||
with self.lock:
|
||||
events = self.selector.select(timeout=1)
|
||||
if len(events) == 0:
|
||||
return
|
||||
conn, addr = self.sock.accept()
|
||||
# now = time.time()
|
||||
# fileno: int = conn.fileno()
|
||||
self.start_work(conn, addr)
|
||||
# logger.info('Work started for fd %d in %f seconds', fileno, time.time() - now)
|
||||
|
||||
def run(self) -> None:
|
||||
self.selector = selectors.DefaultSelector()
|
||||
fileno = recv_handle(self.work_queue)
|
||||
self.work_queue.close()
|
||||
self.sock = socket.fromfd(
|
||||
fileno,
|
||||
family=self.flags.family,
|
||||
type=socket.SOCK_STREAM
|
||||
)
|
||||
try:
|
||||
self.selector.register(self.sock, selectors.EVENT_READ)
|
||||
if self.flags.threadless:
|
||||
self.start_threadless_process()
|
||||
while not self.running.is_set():
|
||||
self.run_once()
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
self.selector.unregister(self.sock)
|
||||
if self.flags.threadless:
|
||||
self.shutdown_threadless_process()
|
||||
self.sock.close()
|
||||
logger.debug('Acceptor#%d shutdown', self.idd)
|
|
@ -10,17 +10,17 @@
|
|||
"""
|
||||
import logging
|
||||
import multiprocessing
|
||||
import selectors
|
||||
import socket
|
||||
import threading
|
||||
# import time
|
||||
from multiprocessing import connection
|
||||
from multiprocessing.reduction import send_handle, recv_handle
|
||||
from typing import List, Optional, Type, Tuple
|
||||
from multiprocessing.reduction import send_handle
|
||||
from typing import List, Optional, Type
|
||||
|
||||
from .threadless import ThreadlessWork, Threadless
|
||||
from .event import EventQueue, EventDispatcher, eventNames
|
||||
from ..common.flags import Flags
|
||||
from .acceptor import Acceptor
|
||||
from ..threadless import ThreadlessWork
|
||||
from ..event import EventQueue, EventDispatcher
|
||||
from ...common.flags import Flags
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -125,115 +125,3 @@ class AcceptorPool:
|
|||
)
|
||||
self.work_queues[index].close()
|
||||
self.socket.close()
|
||||
|
||||
|
||||
class Acceptor(multiprocessing.Process):
|
||||
"""Socket client acceptor.
|
||||
|
||||
Accepts client connection over received server socket handle and
|
||||
starts a new work thread.
|
||||
"""
|
||||
|
||||
lock = multiprocessing.Lock()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
idd: int,
|
||||
work_queue: connection.Connection,
|
||||
flags: Flags,
|
||||
work_klass: Type[ThreadlessWork],
|
||||
event_queue: Optional[EventQueue] = None) -> None:
|
||||
super().__init__()
|
||||
self.idd = idd
|
||||
self.work_queue: connection.Connection = work_queue
|
||||
self.flags = flags
|
||||
self.work_klass = work_klass
|
||||
self.event_queue = event_queue
|
||||
|
||||
self.running = multiprocessing.Event()
|
||||
self.selector: Optional[selectors.DefaultSelector] = None
|
||||
self.sock: Optional[socket.socket] = None
|
||||
self.threadless_process: Optional[Threadless] = None
|
||||
self.threadless_client_queue: Optional[connection.Connection] = None
|
||||
|
||||
def start_threadless_process(self) -> None:
|
||||
pipe = multiprocessing.Pipe()
|
||||
self.threadless_client_queue = pipe[0]
|
||||
self.threadless_process = Threadless(
|
||||
client_queue=pipe[1],
|
||||
flags=self.flags,
|
||||
work_klass=self.work_klass,
|
||||
event_queue=self.event_queue
|
||||
)
|
||||
self.threadless_process.start()
|
||||
logger.debug('Started process %d', self.threadless_process.pid)
|
||||
|
||||
def shutdown_threadless_process(self) -> None:
|
||||
assert self.threadless_process and self.threadless_client_queue
|
||||
logger.debug('Stopped process %d', self.threadless_process.pid)
|
||||
self.threadless_process.running.set()
|
||||
self.threadless_process.join()
|
||||
self.threadless_client_queue.close()
|
||||
|
||||
def start_work(self, conn: socket.socket, addr: Tuple[str, int]) -> None:
|
||||
if self.flags.threadless and \
|
||||
self.threadless_client_queue and \
|
||||
self.threadless_process:
|
||||
self.threadless_client_queue.send(addr)
|
||||
send_handle(
|
||||
self.threadless_client_queue,
|
||||
conn.fileno(),
|
||||
self.threadless_process.pid
|
||||
)
|
||||
conn.close()
|
||||
else:
|
||||
work = self.work_klass(
|
||||
fileno=conn.fileno(),
|
||||
addr=addr,
|
||||
flags=self.flags,
|
||||
event_queue=self.event_queue
|
||||
)
|
||||
work_thread = threading.Thread(target=work.run)
|
||||
work_thread.daemon = True
|
||||
work.publish_event(
|
||||
event_name=eventNames.WORK_STARTED,
|
||||
event_payload={'fileno': conn.fileno(), 'addr': addr},
|
||||
publisher_id=self.__class__.__name__
|
||||
)
|
||||
work_thread.start()
|
||||
|
||||
def run_once(self) -> None:
|
||||
assert self.selector and self.sock
|
||||
with self.lock:
|
||||
events = self.selector.select(timeout=1)
|
||||
if len(events) == 0:
|
||||
return
|
||||
conn, addr = self.sock.accept()
|
||||
# now = time.time()
|
||||
# fileno: int = conn.fileno()
|
||||
self.start_work(conn, addr)
|
||||
# logger.info('Work started for fd %d in %f seconds', fileno, time.time() - now)
|
||||
|
||||
def run(self) -> None:
|
||||
self.selector = selectors.DefaultSelector()
|
||||
fileno = recv_handle(self.work_queue)
|
||||
self.work_queue.close()
|
||||
self.sock = socket.fromfd(
|
||||
fileno,
|
||||
family=self.flags.family,
|
||||
type=socket.SOCK_STREAM
|
||||
)
|
||||
try:
|
||||
self.selector.register(self.sock, selectors.EVENT_READ)
|
||||
if self.flags.threadless:
|
||||
self.start_threadless_process()
|
||||
while not self.running.is_set():
|
||||
self.run_once()
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
self.selector.unregister(self.sock)
|
||||
if self.flags.threadless:
|
||||
self.shutdown_threadless_process()
|
||||
self.sock.close()
|
||||
logger.debug('Acceptor#%d shutdown', self.idd)
|
|
@ -22,6 +22,7 @@ from multiprocessing.reduction import recv_handle
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Optional, Tuple, List, Union, Generator, Any, Type
|
||||
|
||||
from .connection import TcpClientConnection
|
||||
from .event import EventQueue, eventNames
|
||||
|
||||
from ..common.flags import Flags
|
||||
|
@ -37,15 +38,12 @@ class ThreadlessWork(ABC):
|
|||
@abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
fileno: int,
|
||||
addr: Tuple[str, int],
|
||||
client: TcpClientConnection,
|
||||
flags: Optional[Flags],
|
||||
event_queue: Optional[EventQueue] = None,
|
||||
uid: Optional[str] = None) -> None:
|
||||
self.fileno = fileno
|
||||
self.addr = addr
|
||||
self.client = client
|
||||
self.flags = flags if flags else Flags()
|
||||
|
||||
self.event_queue = event_queue
|
||||
self.uid: str = uid if uid is not None else uuid.uuid4().hex
|
||||
|
||||
|
@ -167,12 +165,16 @@ class Threadless(multiprocessing.Process):
|
|||
except asyncio.TimeoutError:
|
||||
self.cleanup(work_id)
|
||||
|
||||
def fromfd(self, fileno: int) -> socket.socket:
|
||||
return socket.fromfd(
|
||||
fileno, family=socket.AF_INET if self.flags.hostname.version == 4 else socket.AF_INET6,
|
||||
type=socket.SOCK_STREAM)
|
||||
|
||||
def accept_client(self) -> None:
|
||||
addr = self.client_queue.recv()
|
||||
fileno = recv_handle(self.client_queue)
|
||||
self.works[fileno] = self.work_klass(
|
||||
fileno=fileno,
|
||||
addr=addr,
|
||||
TcpClientConnection(conn=self.fromfd(fileno), addr=addr),
|
||||
flags=self.flags,
|
||||
event_queue=self.event_queue
|
||||
)
|
||||
|
|
|
@ -113,20 +113,18 @@ class HttpProtocolHandler(ThreadlessWork):
|
|||
Accepts `Client` connection object and manages HttpProtocolHandlerPlugin invocations.
|
||||
"""
|
||||
|
||||
def __init__(self, fileno: int, addr: Tuple[str, int],
|
||||
def __init__(self, client: TcpClientConnection,
|
||||
flags: Optional[Flags] = None,
|
||||
event_queue: Optional[EventQueue] = None,
|
||||
uid: Optional[str] = None):
|
||||
super().__init__(fileno, addr, flags, event_queue, uid)
|
||||
super().__init__(client, flags, event_queue, uid)
|
||||
|
||||
self.start_time: float = time.time()
|
||||
self.last_activity: float = self.start_time
|
||||
self.request: HttpParser = HttpParser(httpParserTypes.REQUEST_PARSER)
|
||||
self.response: HttpParser = HttpParser(httpParserTypes.RESPONSE_PARSER)
|
||||
self.selector = selectors.DefaultSelector()
|
||||
self.client: TcpClientConnection = TcpClientConnection(
|
||||
self.fromfd(self.fileno), self.addr
|
||||
)
|
||||
self.client: TcpClientConnection = client
|
||||
self.plugins: Dict[str, HttpProtocolHandlerPlugin] = {}
|
||||
|
||||
def initialize(self) -> None:
|
||||
|
@ -134,7 +132,7 @@ class HttpProtocolHandler(ThreadlessWork):
|
|||
conn = self.optionally_wrap_socket(self.client.connection)
|
||||
conn.setblocking(False)
|
||||
if self.flags.encryption_enabled():
|
||||
self.client = TcpClientConnection(conn=conn, addr=self.addr)
|
||||
self.client = TcpClientConnection(conn=conn, addr=self.client.addr)
|
||||
if b'HttpProtocolHandlerPlugin' in self.flags.plugins:
|
||||
for klass in self.flags.plugins[b'HttpProtocolHandlerPlugin']:
|
||||
instance = klass(
|
||||
|
@ -232,12 +230,6 @@ class HttpProtocolHandler(ThreadlessWork):
|
|||
logger.debug('Client connection closed')
|
||||
super().shutdown()
|
||||
|
||||
def fromfd(self, fileno: int) -> socket.socket:
|
||||
conn = socket.fromfd(
|
||||
fileno, family=socket.AF_INET if self.flags.hostname.version == 4 else socket.AF_INET6,
|
||||
type=socket.SOCK_STREAM)
|
||||
return conn
|
||||
|
||||
def optionally_wrap_socket(
|
||||
self, conn: socket.socket) -> Union[ssl.SSLSocket, socket.socket]:
|
||||
"""Attempts to wrap accepted client connection using provided certificates.
|
||||
|
|
|
@ -106,7 +106,9 @@ class HttpProxyPlugin(HttpProtocolHandlerPlugin):
|
|||
raw = self.server.recv(self.flags.server_recvbuf_size)
|
||||
except TimeoutError as e:
|
||||
if e.errno == errno.ETIMEDOUT:
|
||||
logger.warning('%s:%d timed out on recv' % self.server.addr)
|
||||
logger.warning(
|
||||
'%s:%d timed out on recv' %
|
||||
self.server.addr)
|
||||
return True
|
||||
else:
|
||||
raise e
|
||||
|
@ -115,7 +117,9 @@ class HttpProxyPlugin(HttpProtocolHandlerPlugin):
|
|||
return False
|
||||
except OSError as e:
|
||||
if e.errno == errno.EHOSTUNREACH:
|
||||
logger.warning('%s:%d unreachable on recv' % self.server.addr)
|
||||
logger.warning(
|
||||
'%s:%d unreachable on recv' %
|
||||
self.server.addr)
|
||||
return True
|
||||
elif e.errno == errno.ECONNRESET:
|
||||
logger.warning('Connection reset by upstream: %r' % e)
|
||||
|
|
|
@ -33,7 +33,7 @@ class TestAcceptor(unittest.TestCase):
|
|||
|
||||
@mock.patch('selectors.DefaultSelector')
|
||||
@mock.patch('socket.fromfd')
|
||||
@mock.patch('proxy.core.acceptor.recv_handle')
|
||||
@mock.patch('proxy.core.acceptor.acceptor.recv_handle')
|
||||
def test_continues_when_no_events(
|
||||
self,
|
||||
mock_recv_handle: mock.Mock,
|
||||
|
@ -54,16 +54,18 @@ class TestAcceptor(unittest.TestCase):
|
|||
sock.accept.assert_not_called()
|
||||
self.mock_protocol_handler.assert_not_called()
|
||||
|
||||
@mock.patch('proxy.core.acceptor.acceptor.TcpClientConnection')
|
||||
@mock.patch('threading.Thread')
|
||||
@mock.patch('selectors.DefaultSelector')
|
||||
@mock.patch('socket.fromfd')
|
||||
@mock.patch('proxy.core.acceptor.recv_handle')
|
||||
@mock.patch('proxy.core.acceptor.acceptor.recv_handle')
|
||||
def test_accepts_client_from_server_socket(
|
||||
self,
|
||||
mock_recv_handle: mock.Mock,
|
||||
mock_fromfd: mock.Mock,
|
||||
mock_selector: mock.Mock,
|
||||
mock_thread: mock.Mock) -> None:
|
||||
mock_thread: mock.Mock,
|
||||
mock_client: mock.Mock) -> None:
|
||||
fileno = 10
|
||||
conn = mock.MagicMock()
|
||||
addr = mock.MagicMock()
|
||||
|
@ -87,8 +89,7 @@ class TestAcceptor(unittest.TestCase):
|
|||
type=socket.SOCK_STREAM
|
||||
)
|
||||
self.mock_protocol_handler.assert_called_with(
|
||||
fileno=conn.fileno(),
|
||||
addr=addr,
|
||||
mock_client.return_value,
|
||||
flags=self.flags,
|
||||
event_queue=None,
|
||||
)
|
||||
|
|
|
@ -18,49 +18,50 @@ from proxy.core.acceptor import AcceptorPool
|
|||
|
||||
class TestAcceptorPool(unittest.TestCase):
|
||||
|
||||
@mock.patch('proxy.core.acceptor.send_handle')
|
||||
@mock.patch('proxy.core.acceptor.pool.send_handle')
|
||||
@mock.patch('multiprocessing.Pipe')
|
||||
@mock.patch('socket.socket')
|
||||
@mock.patch('proxy.core.acceptor.Acceptor')
|
||||
@mock.patch('proxy.core.acceptor.pool.Acceptor')
|
||||
def test_setup_and_shutdown(
|
||||
self,
|
||||
mock_worker: mock.Mock,
|
||||
mock_acceptor: mock.Mock,
|
||||
mock_socket: mock.Mock,
|
||||
mock_pipe: mock.Mock,
|
||||
_mock_send_handle: mock.Mock) -> None:
|
||||
mock_worker1 = mock.MagicMock()
|
||||
mock_worker2 = mock.MagicMock()
|
||||
mock_worker.side_effect = [mock_worker1, mock_worker2]
|
||||
mock_send_handle: mock.Mock) -> None:
|
||||
acceptor1 = mock.MagicMock()
|
||||
acceptor2 = mock.MagicMock()
|
||||
mock_acceptor.side_effect = [acceptor1, acceptor2]
|
||||
|
||||
num_workers = 2
|
||||
sock = mock_socket.return_value
|
||||
work_klass = mock.MagicMock()
|
||||
flags = Flags(num_workers=2)
|
||||
acceptor = AcceptorPool(flags=flags, work_klass=work_klass)
|
||||
|
||||
acceptor.setup()
|
||||
pool = AcceptorPool(flags=flags, work_klass=work_klass)
|
||||
pool.setup()
|
||||
mock_send_handle.assert_called()
|
||||
|
||||
work_klass.assert_not_called()
|
||||
mock_socket.assert_called_with(
|
||||
socket.AF_INET6 if acceptor.flags.hostname.version == 6 else socket.AF_INET,
|
||||
socket.AF_INET6 if pool.flags.hostname.version == 6 else socket.AF_INET,
|
||||
socket.SOCK_STREAM
|
||||
)
|
||||
sock.setsockopt.assert_called_with(
|
||||
socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
sock.bind.assert_called_with(
|
||||
(str(acceptor.flags.hostname), acceptor.flags.port))
|
||||
sock.listen.assert_called_with(acceptor.flags.backlog)
|
||||
(str(pool.flags.hostname), pool.flags.port))
|
||||
sock.listen.assert_called_with(pool.flags.backlog)
|
||||
sock.setblocking.assert_called_with(False)
|
||||
|
||||
self.assertTrue(mock_pipe.call_count, num_workers)
|
||||
self.assertTrue(mock_worker.call_count, num_workers)
|
||||
mock_worker1.start.assert_called()
|
||||
mock_worker1.join.assert_not_called()
|
||||
mock_worker2.start.assert_called()
|
||||
mock_worker2.join.assert_not_called()
|
||||
self.assertTrue(mock_acceptor.call_count, num_workers)
|
||||
acceptor1.start.assert_called()
|
||||
acceptor2.start.assert_called()
|
||||
acceptor1.join.assert_not_called()
|
||||
acceptor2.join.assert_not_called()
|
||||
|
||||
sock.close.assert_called()
|
||||
|
||||
acceptor.shutdown()
|
||||
mock_worker1.join.assert_called()
|
||||
mock_worker2.join.assert_called()
|
||||
pool.shutdown()
|
||||
acceptor1.join.assert_called()
|
||||
acceptor2.join.assert_called()
|
||||
|
|
|
@ -14,6 +14,7 @@ from unittest import mock
|
|||
|
||||
from proxy.common.constants import DEFAULT_HTTP_PORT
|
||||
from proxy.common.flags import Flags
|
||||
from proxy.core.connection import TcpClientConnection
|
||||
from proxy.http.proxy import HttpProxyPlugin
|
||||
from proxy.http.handler import HttpProtocolHandler
|
||||
from proxy.http.exception import HttpProtocolException
|
||||
|
@ -40,7 +41,8 @@ class TestHttpProxyPlugin(unittest.TestCase):
|
|||
}
|
||||
self._conn = mock_fromfd.return_value
|
||||
self.protocol_handler = HttpProtocolHandler(
|
||||
self.fileno, self._addr, flags=self.flags)
|
||||
TcpClientConnection(self._conn, self._addr),
|
||||
flags=self.flags)
|
||||
self.protocol_handler.initialize()
|
||||
|
||||
def test_proxy_plugin_initialized(self) -> None:
|
||||
|
|
|
@ -17,6 +17,7 @@ import selectors
|
|||
from typing import Any
|
||||
from unittest import mock
|
||||
|
||||
from proxy.core.connection import TcpClientConnection
|
||||
from proxy.http.handler import HttpProtocolHandler
|
||||
from proxy.http.proxy import HttpProxyPlugin
|
||||
from proxy.http.methods import httpMethods
|
||||
|
@ -78,7 +79,8 @@ class TestHttpProxyTlsInterception(unittest.TestCase):
|
|||
}
|
||||
self._conn = mock_fromfd.return_value
|
||||
self.protocol_handler = HttpProtocolHandler(
|
||||
self.fileno, self._addr, flags=self.flags)
|
||||
TcpClientConnection(self._conn, self._addr),
|
||||
flags=self.flags)
|
||||
self.protocol_handler.initialize()
|
||||
|
||||
self.plugin.assert_called()
|
||||
|
|
|
@ -15,15 +15,16 @@ import base64
|
|||
from typing import cast
|
||||
from unittest import mock
|
||||
|
||||
from proxy.common.version import __version__
|
||||
from proxy.common.flags import Flags
|
||||
from proxy.common.utils import bytes_
|
||||
from proxy.common.constants import CRLF
|
||||
from proxy.core.connection import TcpClientConnection
|
||||
from proxy.http.parser import HttpParser
|
||||
from proxy.http.proxy import HttpProxyPlugin
|
||||
from proxy.http.parser import httpParserStates, httpParserTypes
|
||||
from proxy.http.exception import ProxyAuthenticationFailed, ProxyConnectionFailed
|
||||
from proxy.http.handler import HttpProtocolHandler
|
||||
from proxy.common.version import __version__
|
||||
|
||||
|
||||
class TestHttpProtocolHandler(unittest.TestCase):
|
||||
|
@ -44,7 +45,7 @@ class TestHttpProtocolHandler(unittest.TestCase):
|
|||
|
||||
self.mock_selector = mock_selector
|
||||
self.protocol_handler = HttpProtocolHandler(
|
||||
self.fileno, self._addr, flags=self.flags)
|
||||
TcpClientConnection(self._conn, self._addr), flags=self.flags)
|
||||
self.protocol_handler.initialize()
|
||||
|
||||
@mock.patch('proxy.http.proxy.server.TcpServerConnection')
|
||||
|
@ -175,7 +176,7 @@ class TestHttpProtocolHandler(unittest.TestCase):
|
|||
flags.plugins = Flags.load_plugins(
|
||||
b'proxy.http.proxy.HttpProxyPlugin,proxy.http.server.HttpWebServerPlugin')
|
||||
self.protocol_handler = HttpProtocolHandler(
|
||||
self.fileno, self._addr, flags=flags)
|
||||
TcpClientConnection(self._conn, self._addr), flags=flags)
|
||||
self.protocol_handler.initialize()
|
||||
self._conn.recv.return_value = CRLF.join([
|
||||
b'GET http://abhinavsingh.com HTTP/1.1',
|
||||
|
@ -208,7 +209,7 @@ class TestHttpProtocolHandler(unittest.TestCase):
|
|||
b'proxy.http.proxy.HttpProxyPlugin,proxy.http.server.HttpWebServerPlugin')
|
||||
|
||||
self.protocol_handler = HttpProtocolHandler(
|
||||
self.fileno, addr=self._addr, flags=flags)
|
||||
TcpClientConnection(self._conn, self._addr), flags=flags)
|
||||
self.protocol_handler.initialize()
|
||||
assert self.http_server_port is not None
|
||||
|
||||
|
@ -256,7 +257,7 @@ class TestHttpProtocolHandler(unittest.TestCase):
|
|||
b'proxy.http.proxy.HttpProxyPlugin,proxy.http.server.HttpWebServerPlugin')
|
||||
|
||||
self.protocol_handler = HttpProtocolHandler(
|
||||
self.fileno, self._addr, flags=flags)
|
||||
TcpClientConnection(self._conn, self._addr), flags=flags)
|
||||
self.protocol_handler.initialize()
|
||||
|
||||
assert self.http_server_port is not None
|
||||
|
|
|
@ -16,6 +16,7 @@ import selectors
|
|||
from unittest import mock
|
||||
|
||||
from proxy.common.flags import Flags
|
||||
from proxy.core.connection import TcpClientConnection
|
||||
from proxy.http.handler import HttpProtocolHandler
|
||||
from proxy.http.parser import httpParserStates
|
||||
from proxy.common.utils import build_http_response, build_http_request, bytes_, text_
|
||||
|
@ -36,7 +37,8 @@ class TestWebServerPlugin(unittest.TestCase):
|
|||
self.flags.plugins = Flags.load_plugins(
|
||||
b'proxy.http.proxy.HttpProxyPlugin,proxy.http.server.HttpWebServerPlugin')
|
||||
self.protocol_handler = HttpProtocolHandler(
|
||||
self.fileno, self._addr, flags=self.flags)
|
||||
TcpClientConnection(self._conn, self._addr),
|
||||
flags=self.flags)
|
||||
self.protocol_handler.initialize()
|
||||
|
||||
@mock.patch('selectors.DefaultSelector')
|
||||
|
@ -96,7 +98,8 @@ class TestWebServerPlugin(unittest.TestCase):
|
|||
flags.plugins = Flags.load_plugins(
|
||||
b'proxy.http.proxy.HttpProxyPlugin,proxy.http.server.HttpWebServerPlugin')
|
||||
self.protocol_handler = HttpProtocolHandler(
|
||||
self.fileno, self._addr, flags=flags)
|
||||
TcpClientConnection(self._conn, self._addr),
|
||||
flags=flags)
|
||||
self.protocol_handler.initialize()
|
||||
self._conn.recv.return_value = CRLF.join([
|
||||
b'GET /hello HTTP/1.1',
|
||||
|
@ -147,7 +150,8 @@ class TestWebServerPlugin(unittest.TestCase):
|
|||
b'proxy.http.proxy.HttpProxyPlugin,proxy.http.server.HttpWebServerPlugin')
|
||||
|
||||
self.protocol_handler = HttpProtocolHandler(
|
||||
self.fileno, self._addr, flags=flags)
|
||||
TcpClientConnection(self._conn, self._addr),
|
||||
flags=flags)
|
||||
self.protocol_handler.initialize()
|
||||
|
||||
self.protocol_handler.run_once()
|
||||
|
@ -194,7 +198,8 @@ class TestWebServerPlugin(unittest.TestCase):
|
|||
b'proxy.http.proxy.HttpProxyPlugin,proxy.http.server.HttpWebServerPlugin')
|
||||
|
||||
self.protocol_handler = HttpProtocolHandler(
|
||||
self.fileno, self._addr, flags=flags)
|
||||
TcpClientConnection(self._conn, self._addr),
|
||||
flags=flags)
|
||||
self.protocol_handler.initialize()
|
||||
|
||||
self.protocol_handler.run_once()
|
||||
|
@ -213,7 +218,8 @@ class TestWebServerPlugin(unittest.TestCase):
|
|||
flags.plugins = {b'HttpProtocolHandlerPlugin': [plugin]}
|
||||
self._conn = mock_fromfd.return_value
|
||||
self.protocol_handler = HttpProtocolHandler(
|
||||
self.fileno, self._addr, flags=flags)
|
||||
TcpClientConnection(self._conn, self._addr),
|
||||
flags=flags)
|
||||
self.protocol_handler.initialize()
|
||||
plugin.assert_called()
|
||||
with mock.patch.object(self.protocol_handler, 'run_once') as mock_run_once:
|
||||
|
@ -228,7 +234,8 @@ class TestWebServerPlugin(unittest.TestCase):
|
|||
b'proxy.http.proxy.HttpProxyPlugin,proxy.http.server.HttpWebServerPlugin,'
|
||||
b'proxy.http.server.HttpWebServerPacFilePlugin')
|
||||
self.protocol_handler = HttpProtocolHandler(
|
||||
self.fileno, self._addr, flags=flags)
|
||||
TcpClientConnection(self._conn, self._addr),
|
||||
flags=flags)
|
||||
self.protocol_handler.initialize()
|
||||
self._conn.recv.return_value = CRLF.join([
|
||||
b'GET / HTTP/1.1',
|
||||
|
|
|
@ -17,6 +17,7 @@ from unittest import mock
|
|||
from typing import cast
|
||||
|
||||
from proxy.common.flags import Flags
|
||||
from proxy.core.connection import TcpClientConnection
|
||||
from proxy.http.handler import HttpProtocolHandler
|
||||
from proxy.http.proxy import HttpProxyPlugin
|
||||
from proxy.common.utils import build_http_request, bytes_, build_http_response
|
||||
|
@ -51,7 +52,8 @@ class TestHttpProxyPluginExamples(unittest.TestCase):
|
|||
}
|
||||
self._conn = mock_fromfd.return_value
|
||||
self.protocol_handler = HttpProtocolHandler(
|
||||
self.fileno, self._addr, flags=self.flags)
|
||||
TcpClientConnection(self._conn, self._addr),
|
||||
flags=self.flags)
|
||||
self.protocol_handler.initialize()
|
||||
|
||||
@mock.patch('proxy.http.proxy.server.TcpServerConnection')
|
||||
|
|
|
@ -19,6 +19,7 @@ from typing import Any, cast
|
|||
from proxy.common.utils import bytes_
|
||||
from proxy.common.flags import Flags
|
||||
from proxy.common.utils import build_http_request, build_http_response
|
||||
from proxy.core.connection import TcpClientConnection
|
||||
from proxy.http.codes import httpStatusCodes
|
||||
from proxy.http.methods import httpMethods
|
||||
from proxy.http.handler import HttpProtocolHandler
|
||||
|
@ -66,7 +67,7 @@ class TestHttpProxyPluginExamplesWithTlsInterception(unittest.TestCase):
|
|||
self._conn = mock.MagicMock(spec=socket.socket)
|
||||
mock_fromfd.return_value = self._conn
|
||||
self.protocol_handler = HttpProtocolHandler(
|
||||
self.fileno, self._addr, flags=self.flags)
|
||||
TcpClientConnection(self._conn, self._addr), flags=self.flags)
|
||||
self.protocol_handler.initialize()
|
||||
|
||||
self.server = self.mock_server_conn.return_value
|
||||
|
|
Loading…
Reference in New Issue