Convert HttpProxy and HttpWebServer itself into plugins.

Load plugins during test execution

Further decouple proxy/webserver logic outside of HttpProtocolHandler.

Per connection plugin instances to avoid locks

Handle BrokenPipeError and teardown if read_from_descriptors return True
This commit is contained in:
Abhinav Singh 2019-08-20 16:10:05 -07:00
parent 8732cb7151
commit 6ea42b0dd9
4 changed files with 437 additions and 258 deletions

View File

@ -1,30 +1,38 @@
from urllib import parse as urlparse from urllib import parse as urlparse
from proxy import HttpProxyPlugin, ProxyRejectRequest from proxy import HttpProtocolBasePlugin, ProxyRequestRejected
class RedirectToCustomServerPlugin(HttpProxyPlugin): class RedirectToCustomServerPlugin(HttpProtocolBasePlugin):
"""Modifies client request to redirect all incoming requests to a fixed server address.""" """Modifies client request to redirect all incoming requests to a fixed server address."""
def __init__(self): def __init__(self):
super(RedirectToCustomServerPlugin, self).__init__() super(RedirectToCustomServerPlugin, self).__init__()
def handle_request(self, request): def on_request_complete(self):
if request.method != 'CONNECT': if self.request.method != 'CONNECT':
request.url = urlparse.urlsplit(b'http://localhost:9999') self.request.url = urlparse.urlsplit(b'http://localhost:9999')
return request
class FilterByTargetDomainPlugin(HttpProxyPlugin): class FilterByTargetDomainPlugin(HttpProtocolBasePlugin):
"""Only accepts specific requests dropping all other requests.""" """Only accepts specific requests dropping all other requests."""
def __init__(self): def __init__(self):
super(FilterByTargetDomainPlugin, self).__init__() super(FilterByTargetDomainPlugin, self).__init__()
self.allowed_domains = [b'google.com', b'www.google.com', b'google.com:443', b'www.google.com:443'] self.allowed_domains = [b'google.com', b'www.google.com', b'google.com:443', b'www.google.com:443']
def handle_request(self, request): def on_request_complete(self):
# TODO: Refactor internals to cleanup mess below, due to how urlparse works, hostname/path attributes # TODO: Refactor internals to cleanup mess below, due to how urlparse works, hostname/path attributes
# are not consistent between CONNECT and non-CONNECT requests. # are not consistent between CONNECT and non-CONNECT requests.
if (request.method != b'CONNECT' and request.url.hostname not in self.allowed_domains) or \ if (self.request.method != b'CONNECT' and self.request.url.hostname not in self.allowed_domains) or \
(request.method == b'CONNECT' and request.url.path not in self.allowed_domains): (self.request.method == b'CONNECT' and self.request.url.path not in self.allowed_domains):
raise ProxyRejectRequest(status_code=418, body='I\'m a tea pot') raise ProxyRequestRejected(status_code=418, body='I\'m a tea pot')
return request
class SaveHttpResponses(HttpProtocolBasePlugin):
"""Saves Http Responses locally on disk."""
def __init__(self):
super(SaveHttpResponses, self).__init__()
def handle_response_chunk(self, chunk):
return chunk

614
proxy.py
View File

@ -22,6 +22,7 @@ import socket
import sys import sys
import threading import threading
from collections import namedtuple from collections import namedtuple
from typing import Dict, List
if os.name != 'nt': if os.name != 'nt':
import resource import resource
@ -53,7 +54,7 @@ else: # pragma: no cover
# Defaults # Defaults
DEFAULT_BACKLOG = 100 DEFAULT_BACKLOG = 100
DEFAULT_BASIC_AUTH = None DEFAULT_BASIC_AUTH = None
DEFAULT_BUFFER_SIZE = 8192 DEFAULT_BUFFER_SIZE = 1024 * 1024
DEFAULT_CLIENT_RECVBUF_SIZE = DEFAULT_BUFFER_SIZE DEFAULT_CLIENT_RECVBUF_SIZE = DEFAULT_BUFFER_SIZE
DEFAULT_SERVER_RECVBUF_SIZE = DEFAULT_BUFFER_SIZE DEFAULT_SERVER_RECVBUF_SIZE = DEFAULT_BUFFER_SIZE
DEFAULT_IPV4_HOSTNAME = '127.0.0.1' DEFAULT_IPV4_HOSTNAME = '127.0.0.1'
@ -63,10 +64,12 @@ DEFAULT_IPV4 = False
DEFAULT_LOG_LEVEL = 'INFO' DEFAULT_LOG_LEVEL = 'INFO'
DEFAULT_OPEN_FILE_LIMIT = 1024 DEFAULT_OPEN_FILE_LIMIT = 1024
DEFAULT_PAC_FILE = None DEFAULT_PAC_FILE = None
DEFAULT_PAC_FILE_URL_PATH = '/'
DEFAULT_NUM_WORKERS = 0 DEFAULT_NUM_WORKERS = 0
DEFAULT_PLUGINS = [] DEFAULT_PLUGINS = {}
DEFAULT_LOG_FORMAT = '%(asctime)s - %(levelname)s - pid:%(process)d - %(funcName)s:%(lineno)d - %(message)s' DEFAULT_LOG_FORMAT = '%(asctime)s - %(levelname)s - pid:%(process)d - %(funcName)s:%(lineno)d - %(message)s'
# Set to True if under test
UNDER_TEST = False UNDER_TEST = False
@ -91,7 +94,7 @@ def bytes_(s, encoding='utf-8', errors='strict'): # pragma: no cover
version = bytes_(__version__) version = bytes_(__version__)
CRLF, COLON, SP = b'\r\n', b':', b' ' CRLF, COLON, WHITESPACE = b'\r\n', b':', b' '
PROXY_AGENT_HEADER = b'Proxy-agent: proxy.py v' + version PROXY_AGENT_HEADER = b'Proxy-agent: proxy.py v' + version
PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT = CRLF.join([ PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT = CRLF.join([
@ -123,31 +126,31 @@ class ChunkParser(object):
self.chunk = b'' # Partial chunk received self.chunk = b'' # Partial chunk received
self.size = None # Expected size of next following chunk self.size = None # Expected size of next following chunk
def parse(self, data): def parse(self, raw):
more = True if len(data) > 0 else False more = True if len(raw) > 0 else False
while more: while more:
more, data = self.process(data) more, raw = self.process(raw)
def process(self, data): def process(self, raw):
if self.state == ChunkParser.states.WAITING_FOR_SIZE: if self.state == ChunkParser.states.WAITING_FOR_SIZE:
# Consume prior chunk in buffer # Consume prior chunk in buffer
# in case chunk size without CRLF was received # in case chunk size without CRLF was received
data = self.chunk + data raw = self.chunk + raw
self.chunk = b'' self.chunk = b''
# Extract following chunk data size # Extract following chunk data size
line, data = HttpParser.split(data) line, raw = HttpParser.split(raw)
if not line: # CRLF not received if not line: # CRLF not received
self.chunk = data self.chunk = raw
data = b'' raw = b''
else: else:
self.size = int(line, 16) self.size = int(line, 16)
self.state = ChunkParser.states.WAITING_FOR_DATA self.state = ChunkParser.states.WAITING_FOR_DATA
elif self.state == ChunkParser.states.WAITING_FOR_DATA: elif self.state == ChunkParser.states.WAITING_FOR_DATA:
remaining = self.size - len(self.chunk) remaining = self.size - len(self.chunk)
self.chunk += data[:remaining] self.chunk += raw[:remaining]
data = data[remaining:] raw = raw[remaining:]
if len(self.chunk) == self.size: if len(self.chunk) == self.size:
data = data[len(CRLF):] raw = raw[len(CRLF):]
self.body += self.chunk self.body += self.chunk
if self.size == 0: if self.size == 0:
self.state = ChunkParser.states.COMPLETE self.state = ChunkParser.states.COMPLETE
@ -155,7 +158,7 @@ class ChunkParser(object):
self.state = ChunkParser.states.WAITING_FOR_SIZE self.state = ChunkParser.states.WAITING_FOR_SIZE
self.chunk = b'' self.chunk = b''
self.size = None self.size = None
return len(data) > 0, data return len(raw) > 0, raw
class HttpParser(object): class HttpParser(object):
@ -179,7 +182,7 @@ class HttpParser(object):
self.type = parser_type self.type = parser_type
self.state = HttpParser.states.INITIALIZED self.state = HttpParser.states.INITIALIZED
self.raw = b'' self.bytes = b''
self.buffer = b'' self.buffer = b''
self.headers = dict() self.headers = dict()
@ -193,22 +196,38 @@ class HttpParser(object):
self.chunk_parser = None self.chunk_parser = None
# This cleans up developer APIs as Python urlparse.urlsplit behaves differently
# for incoming proxy request and incoming web request. Web request is the one
# which is broken.
self.host = None
self.port = None
def set_host_port(self):
if self.type == HttpParser.types.REQUEST_PARSER:
if self.method == b'CONNECT':
self.host, self.port = self.url.path.split(COLON)
elif self.url:
self.host, self.port = self.url.hostname, self.url.port \
if self.url.port else 80
else:
raise Exception('Invalid request\n%s' % self.bytes)
def is_chunked_encoded_response(self): def is_chunked_encoded_response(self):
return self.type == HttpParser.types.RESPONSE_PARSER and \ return self.type == HttpParser.types.RESPONSE_PARSER and \
b'transfer-encoding' in self.headers and \ b'transfer-encoding' in self.headers and \
self.headers[b'transfer-encoding'][1].lower() == b'chunked' self.headers[b'transfer-encoding'][1].lower() == b'chunked'
def parse(self, data): def parse(self, raw):
self.raw += data self.bytes += raw
data = self.buffer + data raw = self.buffer + raw
self.buffer = b'' self.buffer = b''
more = True if len(data) > 0 else False more = True if len(raw) > 0 else False
while more: while more:
more, data = self.process(data) more, raw = self.process(raw)
self.buffer = data self.buffer = raw
def process(self, data): def process(self, raw):
if self.state in (HttpParser.states.HEADERS_COMPLETE, if self.state in (HttpParser.states.HEADERS_COMPLETE,
HttpParser.states.RCVING_BODY, HttpParser.states.RCVING_BODY,
HttpParser.states.COMPLETE) and \ HttpParser.states.COMPLETE) and \
@ -218,22 +237,22 @@ class HttpParser(object):
if b'content-length' in self.headers: if b'content-length' in self.headers:
self.state = HttpParser.states.RCVING_BODY self.state = HttpParser.states.RCVING_BODY
self.body += data self.body += raw
if len(self.body) >= int(self.headers[b'content-length'][1]): if len(self.body) >= int(self.headers[b'content-length'][1]):
self.state = HttpParser.states.COMPLETE self.state = HttpParser.states.COMPLETE
elif self.is_chunked_encoded_response(): elif self.is_chunked_encoded_response():
if not self.chunk_parser: if not self.chunk_parser:
self.chunk_parser = ChunkParser() self.chunk_parser = ChunkParser()
self.chunk_parser.parse(data) self.chunk_parser.parse(raw)
if self.chunk_parser.state == ChunkParser.states.COMPLETE: if self.chunk_parser.state == ChunkParser.states.COMPLETE:
self.body = self.chunk_parser.body self.body = self.chunk_parser.body
self.state = HttpParser.states.COMPLETE self.state = HttpParser.states.COMPLETE
return False, b'' return False, b''
line, data = HttpParser.split(data) line, raw = HttpParser.split(raw)
if line is False: if line is False:
return line, data return line, raw
if self.state == HttpParser.states.INITIALIZED: if self.state == HttpParser.states.INITIALIZED:
self.process_line(line) self.process_line(line)
@ -245,7 +264,7 @@ class HttpParser(object):
if self.state == HttpParser.states.LINE_RCVD and \ if self.state == HttpParser.states.LINE_RCVD and \
self.type == HttpParser.types.REQUEST_PARSER and \ self.type == HttpParser.types.REQUEST_PARSER and \
self.method == b'CONNECT' and \ self.method == b'CONNECT' and \
data == CRLF: raw == CRLF:
self.state = HttpParser.states.COMPLETE self.state = HttpParser.states.COMPLETE
# When raw request has ended with \r\n\r\n and no more http headers are expected # When raw request has ended with \r\n\r\n and no more http headers are expected
@ -254,7 +273,7 @@ class HttpParser(object):
elif self.state == HttpParser.states.HEADERS_COMPLETE and \ elif self.state == HttpParser.states.HEADERS_COMPLETE and \
self.type == HttpParser.types.REQUEST_PARSER and \ self.type == HttpParser.types.REQUEST_PARSER and \
self.method != b'POST' and \ self.method != b'POST' and \
self.raw.endswith(CRLF * 2): self.bytes.endswith(CRLF * 2):
self.state = HttpParser.states.COMPLETE self.state = HttpParser.states.COMPLETE
elif self.state == HttpParser.states.HEADERS_COMPLETE and \ elif self.state == HttpParser.states.HEADERS_COMPLETE and \
self.type == HttpParser.types.REQUEST_PARSER and \ self.type == HttpParser.types.REQUEST_PARSER and \
@ -262,13 +281,13 @@ class HttpParser(object):
(b'content-length' not in self.headers or (b'content-length' not in self.headers or
(b'content-length' in self.headers and (b'content-length' in self.headers and
int(self.headers[b'content-length'][1]) == 0)) and \ int(self.headers[b'content-length'][1]) == 0)) and \
self.raw.endswith(CRLF * 2): self.bytes.endswith(CRLF * 2):
self.state = HttpParser.states.COMPLETE self.state = HttpParser.states.COMPLETE
return len(data) > 0, data return len(raw) > 0, raw
def process_line(self, data): def process_line(self, raw):
line = data.split(SP) line = raw.split(WHITESPACE)
if self.type == HttpParser.types.REQUEST_PARSER: if self.type == HttpParser.types.REQUEST_PARSER:
self.method = line[0].upper() self.method = line[0].upper()
self.url = urlparse.urlsplit(line[1]) self.url = urlparse.urlsplit(line[1])
@ -277,17 +296,18 @@ class HttpParser(object):
self.version = line[0] self.version = line[0]
self.code = line[1] self.code = line[1]
self.reason = b' '.join(line[2:]) self.reason = b' '.join(line[2:])
self.set_host_port()
self.state = HttpParser.states.LINE_RCVD self.state = HttpParser.states.LINE_RCVD
def process_header(self, data): def process_header(self, raw):
if len(data) == 0: if len(raw) == 0:
if self.state == HttpParser.states.RCVING_HEADERS: if self.state == HttpParser.states.RCVING_HEADERS:
self.state = HttpParser.states.HEADERS_COMPLETE self.state = HttpParser.states.HEADERS_COMPLETE
elif self.state == HttpParser.states.LINE_RCVD: elif self.state == HttpParser.states.LINE_RCVD:
self.state = HttpParser.states.RCVING_HEADERS self.state = HttpParser.states.RCVING_HEADERS
else: else:
self.state = HttpParser.states.RCVING_HEADERS self.state = HttpParser.states.RCVING_HEADERS
parts = data.split(COLON) parts = raw.split(COLON)
key = parts[0].strip() key = parts[0].strip()
value = COLON.join(parts[1:]).strip() value = COLON.join(parts[1:]).strip()
self.headers[key.lower()] = (key, value) self.headers[key.lower()] = (key, value)
@ -331,13 +351,23 @@ class HttpParser(object):
return k + b': ' + v return k + b': ' + v
@staticmethod @staticmethod
def split(data): def split(raw):
pos = data.find(CRLF) pos = raw.find(CRLF)
if pos == -1: if pos == -1:
return False, data return False, raw
line = data[:pos] line = raw[:pos]
data = data[pos + len(CRLF):] raw = raw[pos + len(CRLF):]
return line, data return line, raw
###################################################################################
# HttpParser was originally written to parse the incoming raw Http requests.
# Since request / response objects passed to HttpProtocolBasePlugin methods
# are also HttpParser objects, methods below were added to simplify developer API.
####################################################################################
def has_upstream_server(self):
"""Host field SHOULD be None for incoming local WebServer requests."""
return True if self.host is not None else False
class TcpConnection(object): class TcpConnection(object):
@ -488,69 +518,224 @@ class ProxyRequestRejected(ProxyError):
return CRLF.join(pkt) if len(pkt) > 0 else None return CRLF.join(pkt) if len(pkt) > 0 else None
class HttpProxyPlugin(object): class HttpProtocolConfig(object):
"""Base HttpProxy Plugin class.""" """Holds various configuration values applicable to HttpProtocolHandler.
def __init__(self):
pass
def handle_request(self, request):
"""Handle client request (HttpParser).
Return optionally modified client request (HttpParser) object.
"""
return request
def handle_response(self, data):
"""Handle data chunks as received from the server."""
return data
class HttpProxyConfig(object):
"""Holds various configuration values applicable to HttpProxy.
This config class helps us avoid passing around bunch of key/value pairs across methods. This config class helps us avoid passing around bunch of key/value pairs across methods.
""" """
def __init__(self, auth_code=DEFAULT_BASIC_AUTH, server_recvbuf_size=DEFAULT_SERVER_RECVBUF_SIZE, def __init__(self, auth_code=DEFAULT_BASIC_AUTH, server_recvbuf_size=DEFAULT_SERVER_RECVBUF_SIZE,
client_recvbuf_size=DEFAULT_CLIENT_RECVBUF_SIZE, pac_file=DEFAULT_PAC_FILE, plugins=None): client_recvbuf_size=DEFAULT_CLIENT_RECVBUF_SIZE, pac_file=DEFAULT_PAC_FILE,
pac_file_url_path=DEFAULT_PAC_FILE_URL_PATH, plugins=None):
self.auth_code = auth_code self.auth_code = auth_code
self.server_recvbuf_size = server_recvbuf_size self.server_recvbuf_size = server_recvbuf_size
self.client_recvbuf_size = client_recvbuf_size self.client_recvbuf_size = client_recvbuf_size
self.pac_file = pac_file self.pac_file = pac_file
self.pac_file_url_path = pac_file_url_path
if plugins is None: if plugins is None:
plugins = DEFAULT_PLUGINS plugins = DEFAULT_PLUGINS
self.plugins = plugins self.plugins: Dict = plugins
class HttpProxy(threading.Thread): class HttpProtocolBasePlugin(object):
"""HTTP proxy implementation. """Base HttpProtocolHandler Plugin class.
Accepts `Client` connection object and act as a proxy between client and server. Implement various lifecycle event methods to customize behavior."""
"""
def __init__(self, client, config=None):
super(HttpProxy, self).__init__()
self.start_time = self.now()
self.last_activity = self.start_time
def __init__(self, config: HttpProtocolConfig, client: TcpClientConnection, request: HttpParser):
self.config = config
self.client = client self.client = client
self.server = None self.request = request
self.config = config if config else HttpProxyConfig()
self.request = HttpParser(HttpParser.types.REQUEST_PARSER) def name(self) -> str:
"""A unique name for your plugin.
Defaults to name of the class. This helps plugin developers to directly
access a specific plugin by its name."""
return self.__class__.__name__
def get_descriptors(self):
return [], [], []
def flush_to_descriptors(self, w):
pass
def read_from_descriptors(self, r):
pass
def on_client_data(self, raw: bytes):
return raw
def on_request_complete(self):
"""Called right after client request parser has reached COMPLETE state."""
pass
def handle_response_chunk(self, chunk: bytes):
"""Handle data chunks as received from the server.
Return optionally modified chunk to return back to client."""
return chunk
def access_log(self):
pass
def on_client_connection_close(self):
pass
class HttpProxyPlugin(HttpProtocolBasePlugin):
"""HttpProtocolHandler plugin which implements HttpProxy specifications."""
def __init__(self, config: HttpProtocolConfig, client: TcpClientConnection, request: HttpParser):
super(HttpProxyPlugin, self).__init__(config, client, request)
self.server = None
self.response = HttpParser(HttpParser.types.RESPONSE_PARSER) self.response = HttpParser(HttpParser.types.RESPONSE_PARSER)
@staticmethod def get_descriptors(self):
def now(): if not self.request.has_upstream_server():
return datetime.datetime.utcnow() return [], [], []
def connection_inactive_for(self): r, w, x = [], [], []
return (self.now() - self.last_activity).seconds if self.server and not self.server.closed:
r.append(self.server.conn)
if self.server and not self.server.closed and self.server.has_buffer():
w.append(self.server.conn)
return r, w, x
def is_connection_inactive(self): def flush_to_descriptors(self, w):
return self.connection_inactive_for() > 30 if not self.request.has_upstream_server():
return
if self.server and not self.server.closed and self.server.conn in w:
logger.debug('Server is ready for writes, flushing server buffer')
self.server.flush()
def read_from_descriptors(self, r):
if not self.request.has_upstream_server():
return
if self.server and not self.server.closed and self.server.conn in r:
logger.debug('Server is ready for reads, reading')
raw = self.server.recv(self.config.server_recvbuf_size)
# self.last_activity = HttpProtocolHandler.now()
if not raw:
logger.debug('Server closed connection, tearing down...')
return True
# parse incoming response packet
# only for non-https requests
if not self.request.method == b'CONNECT':
self.response.parse(raw)
else:
# Only purpose of increasing memory footprint is to print response length in access log
# Not worth it? Optimize to only persist lengths?
self.response.bytes += raw
# queue raw data for client
self.client.queue(raw)
def on_client_connection_close(self):
if not self.request.has_upstream_server():
return
if self.server:
logger.debug(
'Closed server connection with pending server buffer size %d bytes' % self.server.buffer_size())
if not self.server.closed:
self.server.close()
def on_client_data(self, raw):
if not self.request.has_upstream_server():
return raw
if self.server and not self.server.closed:
self.server.queue(raw)
return None
else:
return raw
def on_request_complete(self):
if not self.request.has_upstream_server():
return
self.authenticate(self.request.headers)
self.connect_upstream(self.request.host, self.request.port)
# for http connect methods (https requests)
# queue appropriate response for client
# notifying about established connection
if self.request.method == b'CONNECT':
self.client.queue(PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT)
# for general http requests, re-build request packet
# and queue for the server with appropriate headers
else:
self.server.queue(self.request.build(
del_headers=[b'proxy-authorization', b'proxy-connection', b'connection',
b'keep-alive'],
add_headers=[(b'Via', b'1.1 proxy.py v%s' % version), (b'Connection', b'Close')]
))
def access_log(self):
if not self.request.has_upstream_server():
return
host, port = self.server.addr if self.server else (None, None)
if self.request.method == b'CONNECT':
logger.info(
'%s:%s - %s %s:%s - %s bytes' % (self.client.addr[0], self.client.addr[1],
text_(self.request.method), text_(host),
text_(port), len(self.response.bytes)))
elif self.request.method:
logger.info('%s:%s - %s %s:%s%s - %s %s - %s bytes' % (
self.client.addr[0], self.client.addr[1], text_(self.request.method), text_(host), port,
text_(self.request.build_url()), text_(self.response.code), text_(self.response.reason),
len(self.response.bytes)))
def authenticate(self, headers):
if self.config.auth_code:
if b'proxy-authorization' not in headers or \
headers[b'proxy-authorization'][1] != self.config.auth_code:
raise ProxyAuthenticationFailed()
def connect_upstream(self, host, port):
self.server = TcpServerConnection(host, port)
try:
logger.debug('Connecting to upstream %s:%s' % (host, port))
self.server.connect()
logger.debug('Connected to upstream %s:%s' % (host, port))
except Exception as e: # TimeoutError, socket.gaierror
self.server.closed = True
raise ProxyConnectionFailed(host, port, repr(e))
class HttpWebServerPlugin(HttpProtocolBasePlugin):
"""HttpProtocolHandler plugin which handles incoming requests to local webserver."""
def __init__(self, config: HttpProtocolConfig, client: TcpClientConnection, request: HttpParser):
super(HttpWebServerPlugin, self).__init__(config, client, request)
def on_request_complete(self):
if self.request.has_upstream_server():
return
if self.config.pac_file and \
self.request.url.path == self.config.pac_file_url_path:
self.serve_pac_file()
else:
self.client.queue(CRLF.join([
b'HTTP/1.1 200 OK',
b'Server: proxy.py v%s' % version,
b'Content-Length: 42',
b'Connection: Close',
CRLF,
b'https://github.com/abhinavsingh/proxy.py'
]))
self.client.flush()
return True
def access_log(self):
if self.request.has_upstream_server():
return
logger.info('%s:%s - %s %s' % (self.client.addr[0], self.client.addr[1],
text_(self.request.method), text_(self.request.build_url())))
def serve_pac_file(self): def serve_pac_file(self):
self.client.queue(PAC_FILE_RESPONSE_PREFIX) self.client.queue(PAC_FILE_RESPONSE_PREFIX)
@ -563,129 +748,103 @@ class HttpProxy(threading.Thread):
self.client.queue(self.config.pac_file) self.client.queue(self.config.pac_file)
self.client.flush() self.client.flush()
def access_log(self):
host, port = self.server.addr if self.server else (None, None) class HttpProtocolHandler(threading.Thread):
if self.request.method == b'CONNECT': """HTTP, HTTPS, HTTP2, WebSockets protocol handler.
logger.info(
'%s:%s - %s %s:%s - %s bytes' % (self.client.addr[0], self.client.addr[1], Accepts `Client` connection object, manages plugin invocations.
text_(self.request.method), text_(host), """
text_(port), len(self.response.raw)))
elif self.request.method: def __init__(self, client, config=None):
logger.info('%s:%s - %s %s:%s%s - %s %s - %s bytes' % ( super(HttpProtocolHandler, self).__init__()
self.client.addr[0], self.client.addr[1], text_(self.request.method), text_(host), port,
text_(self.request.build_url()), text_(self.response.code), text_(self.response.reason), self.start_time = self.now()
len(self.response.raw))) self.last_activity = self.start_time
self.client = client
self.config = config if config else HttpProtocolConfig()
self.request = HttpParser(HttpParser.types.REQUEST_PARSER)
self.plugins: Dict[str, HttpProtocolBasePlugin] = {}
for klass in self.config.plugins:
instance = klass(self.config, self.client, self.request)
self.plugins[instance.name()] = instance
@staticmethod
def now():
return datetime.datetime.utcnow()
def connection_inactive_for(self):
return (self.now() - self.last_activity).seconds
def is_connection_inactive(self):
return self.connection_inactive_for() > 30
def run_once(self): def run_once(self):
"""Returns True if proxy must teardown.""" """Returns True if proxy must teardown."""
# Prepare list of descriptors # Prepare list of descriptors
rlist, wlist, xlist = [self.client.conn], [], [] read_desc, write_desc, err_desc = [self.client.conn], [], []
if self.client.has_buffer(): if self.client.has_buffer():
wlist.append(self.client.conn) write_desc.append(self.client.conn)
if self.server and not self.server.closed:
rlist.append(self.server.conn)
if self.server and not self.server.closed and self.server.has_buffer():
wlist.append(self.server.conn)
r, w, x = select.select(rlist, wlist, xlist, 1) for plugin in self.plugins.values():
plugin_read_desc, plugin_write_desc, plugin_err_desc = plugin.get_descriptors()
read_desc += plugin_read_desc
write_desc += plugin_write_desc
err_desc += plugin_err_desc
# Flush buffer from ready to write sockets readable, writable, errored = select.select(read_desc, write_desc, err_desc, 1)
if self.client.conn in w:
# Flush buffer for ready to write sockets
if self.client.conn in writable:
logger.debug('Client is ready for writes, flushing client buffer') logger.debug('Client is ready for writes, flushing client buffer')
self.client.flush() try:
if self.server and not self.server.closed and self.server.conn in w: self.client.flush()
logger.debug('Server is ready for writes, flushing server buffer') except BrokenPipeError:
self.server.flush() logging.error('BrokenPipeError when flushing buffer for client')
return True
for plugin in self.plugins.values():
plugin.flush_to_descriptors(writable)
# Read from ready to read sockets # Read from ready to read sockets
if self.client.conn in r: if self.client.conn in readable:
logger.debug('Client is ready for reads, reading') logger.debug('Client is ready for reads, reading')
data = self.client.recv(self.config.client_recvbuf_size) client_data = self.client.recv(self.config.client_recvbuf_size)
self.last_activity = self.now() self.last_activity = self.now()
if not data: if not client_data:
logger.debug('Client closed connection, tearing down...') logger.debug('Client closed connection, tearing down...')
return True return True
try:
# Once we have connection to the server plugin_index = 0
# we don't parse the http request packets plugins = list(self.plugins.values())
# any further, instead just pipe incoming while plugin_index < len(plugins) and client_data:
# data from client to server client_data = plugins[plugin_index].on_client_data(client_data)
if self.server and not self.server.closed: plugin_index += 1
self.server.queue(data)
else: if client_data:
try:
# Parse http request # Parse http request
self.request.parse(data) self.request.parse(client_data)
if self.request.state == HttpParser.states.COMPLETE: if self.request.state == HttpParser.states.COMPLETE:
logger.debug('Request parser is in state complete') for plugin in self.plugins.values():
# TODO: Cleanup by not returning True for teardown cases
plugin_response = plugin.on_request_complete()
if type(plugin_response) is bool:
return True
except ProxyError as e: # ProxyAuthenticationFailed, ProxyConnectionFailed, ProxyRequestRejected
# logger.exception(e)
response = e.response(self.request)
if response:
self.client.queue(response)
# But is client also ready for writes?
self.client.flush()
raise e
if self.config.auth_code: for plugin in self.plugins.values():
if b'proxy-authorization' not in self.request.headers or \ teardown = plugin.read_from_descriptors(readable)
self.request.headers[b'proxy-authorization'][1] != self.config.auth_code: if teardown:
raise ProxyAuthenticationFailed()
# Invoke HttpProxyPlugin.handle_request
for plugin in self.config.plugins:
self.request = plugin.handle_request(self.request)
if self.request.method == b'CONNECT':
host, port = self.request.url.path.split(COLON)
elif self.request.url:
host, port = self.request.url.hostname, self.request.url.port \
if self.request.url.port else 80
else:
raise Exception('Invalid request\n%s' % self.request.raw)
if host is None and self.config.pac_file:
self.serve_pac_file()
return True
self.server = TcpServerConnection(host, port)
try:
logger.debug('Connecting to server %s:%s' % (host, port))
self.server.connect()
logger.debug('Connected to server %s:%s' % (host, port))
except Exception as e: # TimeoutError, socket.gaierror
self.server.closed = True
raise ProxyConnectionFailed(host, port, repr(e))
# for http connect methods (https requests)
# queue appropriate response for client
# notifying about established connection
if self.request.method == b'CONNECT':
self.client.queue(PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT)
# for usual http requests, re-build request packet
# and queue for the server with appropriate headers
else:
self.server.queue(self.request.build(
del_headers=[b'proxy-authorization', b'proxy-connection', b'connection',
b'keep-alive'],
add_headers=[(b'Via', b'1.1 proxy.py v%s' % version), (b'Connection', b'Close')]
))
except (ProxyAuthenticationFailed, ProxyConnectionFailed, ProxyRequestRejected) as e:
# logger.exception(e)
response = e.response(self.request)
if response:
self.client.queue(response)
# But is client also ready for writes?
self.client.flush()
raise e
if self.server and not self.server.closed and self.server.conn in r:
logger.debug('Server is ready for reads, reading')
data = self.server.recv(self.config.server_recvbuf_size)
self.last_activity = self.now()
if not data:
logger.debug('Server closed connection, tearing down...')
return True return True
# parse incoming response packet
# only for non-https requests
if not self.request.method == b'CONNECT':
self.response.parse(data)
else:
# Only purpose of increasing memory footprint is to print response length in access log
# Not worth it? Optimize to only persist lengths?
self.response.raw += data
# queue data for client
self.client.queue(data)
# Teardown if client buffer is empty and connection is inactive # Teardown if client buffer is empty and connection is inactive
if self.client.buffer_size() == 0: if self.client.buffer_size() == 0:
@ -706,16 +865,17 @@ class HttpProxy(threading.Thread):
except Exception as e: except Exception as e:
logger.exception('Exception while handling connection %r with reason %r' % (self.client.conn, e)) logger.exception('Exception while handling connection %r with reason %r' % (self.client.conn, e))
finally: finally:
for plugin in self.plugins.values():
plugin.access_log()
self.client.close() self.client.close()
logger.debug( logger.debug('Closed client connection with pending '
'Closed client connection with pending client buffer size %d bytes' % self.client.buffer_size()) 'client buffer size %d bytes' % self.client.buffer_size())
if self.server: for plugin in self.plugins.values():
logger.debug( plugin.on_client_connection_close()
'Closed server connection with pending server buffer size %d bytes' % self.server.buffer_size())
if not self.server.closed: logger.debug('Closed proxy for connection %r '
self.server.close() 'at address %r' % (self.client.conn, self.client.addr))
self.access_log()
logger.debug('Closed proxy for connection %r at address %r' % (self.client.conn, self.client.addr))
class TcpServer(object): class TcpServer(object):
@ -797,7 +957,7 @@ class MultiCoreRequestDispatcher(TcpServer):
self.workers.append(worker) self.workers.append(worker)
def handle(self, client): def handle(self, client):
self.worker_queue.put((Worker.operations.HTTP_PROXY, client)) self.worker_queue.put((Worker.operations.HTTP_PROTOCOL, client))
def shutdown(self): def shutdown(self):
logger.info('Shutting down %d workers' % self.num_workers) logger.info('Shutting down %d workers' % self.num_workers)
@ -815,7 +975,7 @@ class Worker(multiprocessing.Process):
""" """
operations = namedtuple('WorkerOperations', ( operations = namedtuple('WorkerOperations', (
'HTTP_PROXY', 'HTTP_PROTOCOL',
'SHUTDOWN', 'SHUTDOWN',
))(1, 2) ))(1, 2)
@ -828,8 +988,8 @@ class Worker(multiprocessing.Process):
while True: while True:
try: try:
op, payload = self.work_queue.get(True, 1) op, payload = self.work_queue.get(True, 1)
if op == Worker.operations.HTTP_PROXY: if op == Worker.operations.HTTP_PROTOCOL:
proxy = HttpProxy(payload, config=self.config) proxy = HttpProtocolHandler(payload, config=self.config)
proxy.setDaemon(True) proxy.setDaemon(True)
proxy.start() proxy.start()
elif op == Worker.operations.SHUTDOWN: elif op == Worker.operations.SHUTDOWN:
@ -852,6 +1012,22 @@ def set_open_file_limit(soft_limit):
logger.debug('Open file descriptor soft limit set to %d' % soft_limit) logger.debug('Open file descriptor soft limit set to %d' % soft_limit)
def load_plugins(plugins) -> List:
p = []
plugins = plugins.split(',')
for plugin in plugins:
plugin = plugin.strip()
if plugin == '':
continue
logging.debug('Loading plugin %s', plugin)
module_name, klass_name = plugin.rsplit('.', 1)
module = importlib.import_module(module_name)
klass = getattr(module, klass_name)
logging.info('%s initialized' % klass.__name__)
p.append(klass)
return p
def init_parser(): def init_parser():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='proxy.py v%s' % __version__, description='proxy.py v%s' % __version__,
@ -882,12 +1058,14 @@ def init_parser():
parser.add_argument('--open-file-limit', type=int, default=DEFAULT_OPEN_FILE_LIMIT, parser.add_argument('--open-file-limit', type=int, default=DEFAULT_OPEN_FILE_LIMIT,
help='Default: 1024. Maximum number of files (TCP connections) ' help='Default: 1024. Maximum number of files (TCP connections) '
'that proxy.py can open concurrently.') 'that proxy.py can open concurrently.')
parser.add_argument('--port', type=int, default=DEFAULT_PORT,
help='Default: 8899. Server port.')
parser.add_argument('--pac-file', type=str, default=DEFAULT_PAC_FILE, parser.add_argument('--pac-file', type=str, default=DEFAULT_PAC_FILE,
help='A file (Proxy Auto Configuration) or string to serve when ' help='A file (Proxy Auto Configuration) or string to serve when '
'the server receives a direct file request.') 'the server receives a direct file request.')
parser.add_argument('--plugins', type=str, default=None, help='Comma separated plugins') parser.add_argument('--pac-file-url-path', type=str, default=DEFAULT_PAC_FILE_URL_PATH,
help='Web server path to serve the PAC file.')
parser.add_argument('--plugins', type=str, default='', help='Comma separated plugins')
parser.add_argument('--port', type=int, default=DEFAULT_PORT,
help='Default: 8899. Server port.')
parser.add_argument('--server-recvbuf-size', type=int, default=DEFAULT_SERVER_RECVBUF_SIZE, parser.add_argument('--server-recvbuf-size', type=int, default=DEFAULT_SERVER_RECVBUF_SIZE,
help='Default: 8 KB. Maximum amount of data received from the ' help='Default: 8 KB. Maximum amount of data received from the '
'server in a single recv() operation. Bump this ' 'server in a single recv() operation. Bump this '
@ -898,18 +1076,6 @@ def init_parser():
return parser return parser
def load_plugins(plugins):
p = []
plugins = plugins.split(',')
for plugin in plugins:
module_name, klass_name = plugin.rsplit('.', 1)
module = importlib.import_module(module_name)
klass = getattr(module, klass_name)
logging.info('%s initialized' % klass.__name__)
p.append(klass())
return p
def main(): def main():
if not PY3 and not UNDER_TEST: if not PY3 and not UNDER_TEST:
print( print(
@ -924,7 +1090,6 @@ def main():
parser = init_parser() parser = init_parser()
args = parser.parse_args(sys.argv[1:]) args = parser.parse_args(sys.argv[1:])
logging.basicConfig(level=getattr(logging, logging.basicConfig(level=getattr(logging,
{ {
'D': 'DEBUG', 'D': 'DEBUG',
@ -934,7 +1099,6 @@ def main():
'C': 'CRITICAL' 'C': 'CRITICAL'
}[args.log_level.upper()[0]]), }[args.log_level.upper()[0]]),
format=args.log_format) format=args.log_format)
plugins = load_plugins(args.plugins) if args.plugins else DEFAULT_PLUGINS
try: try:
set_open_file_limit(args.open_file_limit) set_open_file_limit(args.open_file_limit)
@ -943,16 +1107,18 @@ def main():
if args.basic_auth: if args.basic_auth:
auth_code = b'Basic %s' % base64.b64encode(bytes_(args.basic_auth)) auth_code = b'Basic %s' % base64.b64encode(bytes_(args.basic_auth))
config = HttpProtocolConfig(auth_code=auth_code,
server_recvbuf_size=args.server_recvbuf_size,
client_recvbuf_size=args.client_recvbuf_size,
pac_file=args.pac_file,
pac_file_url_path=args.pac_file_url_path)
config.plugins = load_plugins('proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin,%s' % args.plugins)
server = MultiCoreRequestDispatcher(hostname=args.hostname, server = MultiCoreRequestDispatcher(hostname=args.hostname,
port=args.port, port=args.port,
backlog=args.backlog, backlog=args.backlog,
ipv4=args.ipv4, ipv4=args.ipv4,
num_workers=args.num_workers, num_workers=args.num_workers,
config=HttpProxyConfig(auth_code=auth_code, config=config)
server_recvbuf_size=args.server_recvbuf_size,
client_recvbuf_size=args.client_recvbuf_size,
pac_file=args.pac_file,
plugins=plugins))
server.run() server.run()
except KeyboardInterrupt: except KeyboardInterrupt:
pass pass

View File

@ -1 +0,0 @@
mock==3.0.5

View File

@ -3,9 +3,9 @@
proxy.py proxy.py
~~~~~~~~ ~~~~~~~~
HTTP Proxy Server in Python. HTTP, HTTPS, HTTP2 and WebSockets Proxy Server in Python.
:copyright: (c) 2013-2018 by Abhinav Singh. :copyright: (c) 2013-2020 by Abhinav Singh.
:license: BSD, see LICENSE for more details. :license: BSD, see LICENSE for more details.
""" """
import os import os
@ -21,8 +21,8 @@ from threading import Thread
import proxy import proxy
# logging.basicConfig(level=logging.DEBUG, logging.basicConfig(level=logging.DEBUG,
# format='%(asctime)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s') format='%(asctime)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s')
if sys.version_info[0] == 3: # Python3 specific imports if sys.version_info[0] == 3: # Python3 specific imports
from http.server import HTTPServer, BaseHTTPRequestHandler from http.server import HTTPServer, BaseHTTPRequestHandler
@ -123,7 +123,7 @@ class TestMultiCoreRequestDispatcher(unittest.TestCase):
tcp_server = None tcp_server = None
tcp_thread = None tcp_thread = None
@mock.patch.object(proxy, 'HttpProxy', side_effect=mock_tcp_proxy_side_effect) @mock.patch.object(proxy, 'HttpProtocolHandler', side_effect=mock_tcp_proxy_side_effect)
def testHttpProxyConnection(self, mock_tcp_proxy): def testHttpProxyConnection(self, mock_tcp_proxy):
try: try:
self.tcp_port = get_available_port() self.tcp_port = get_available_port()
@ -523,6 +523,7 @@ class TestProxy(unittest.TestCase):
http_server = None http_server = None
http_server_port = None http_server_port = None
http_server_thread = None http_server_thread = None
config = None
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
@ -531,6 +532,8 @@ class TestProxy(unittest.TestCase):
cls.http_server_thread = Thread(target=cls.http_server.serve_forever) cls.http_server_thread = Thread(target=cls.http_server.serve_forever)
cls.http_server_thread.setDaemon(True) cls.http_server_thread.setDaemon(True)
cls.http_server_thread.start() cls.http_server_thread.start()
cls.config = proxy.HttpProtocolConfig()
cls.config.plugins = proxy.load_plugins('proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin')
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
@ -541,7 +544,7 @@ class TestProxy(unittest.TestCase):
def setUp(self): def setUp(self):
self._conn = MockTcpConnection() self._conn = MockTcpConnection()
self._addr = ('127.0.0.1', 54382) self._addr = ('127.0.0.1', 54382)
self.proxy = proxy.HttpProxy(proxy.TcpClientConnection(self._conn, self._addr)) self.proxy = proxy.HttpProtocolHandler(proxy.TcpClientConnection(self._conn, self._addr), config=self.config)
@mock.patch('select.select') @mock.patch('select.select')
@mock.patch('proxy.TcpServerConnection') @mock.patch('proxy.TcpServerConnection')
@ -610,7 +613,7 @@ class TestProxy(unittest.TestCase):
proxy.CRLF proxy.CRLF
])) ]))
self.proxy.run_once() self.proxy.run_once()
self.assertFalse(self.proxy.server is None) self.assertFalse(self.proxy.plugins['HttpProxyPlugin'].server is None)
self.assertEqual(self.proxy.client.buffer, proxy.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT) self.assertEqual(self.proxy.client.buffer, proxy.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT)
mock_server_connection.assert_called_once() mock_server_connection.assert_called_once()
server.connect.assert_called_once() server.connect.assert_called_once()
@ -659,9 +662,10 @@ class TestProxy(unittest.TestCase):
@mock.patch('select.select') @mock.patch('select.select')
def test_proxy_authentication_failed(self, mock_select): def test_proxy_authentication_failed(self, mock_select):
mock_select.return_value = ([self._conn], [], []) mock_select.return_value = ([self._conn], [], [])
self.proxy = proxy.HttpProxy(proxy.TcpClientConnection(self._conn, self._addr), config = proxy.HttpProtocolConfig(auth_code=b'Basic %s' % base64.b64encode(b'user:pass'))
config=proxy.HttpProxyConfig( config.plugins = proxy.load_plugins('proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin')
auth_code=b'Basic %s' % base64.b64encode(b'user:pass'))) self.proxy = proxy.HttpProtocolHandler(proxy.TcpClientConnection(self._conn, self._addr),
config=config)
self.proxy.client.conn.queue(proxy.CRLF.join([ self.proxy.client.conn.queue(proxy.CRLF.join([
b'GET http://abhinavsingh.com HTTP/1.1', b'GET http://abhinavsingh.com HTTP/1.1',
b'Host: abhinavsingh.com', b'Host: abhinavsingh.com',
@ -673,14 +677,15 @@ class TestProxy(unittest.TestCase):
@mock.patch('select.select') @mock.patch('select.select')
@mock.patch('proxy.TcpServerConnection') @mock.patch('proxy.TcpServerConnection')
def test_authenticated_proxy_http_get(self, mock_server_connection, mock_select): def test_authenticated_proxy_http_get(self, mock_server_connection, mock_select):
client = proxy.TcpClientConnection(self._conn, self._addr)
config = proxy.HttpProxyConfig(auth_code=b'Basic %s' % base64.b64encode(b'user:pass'))
mock_select.return_value = ([self._conn], [], []) mock_select.return_value = ([self._conn], [], [])
server = mock_server_connection.return_value server = mock_server_connection.return_value
server.connect.return_value = True server.connect.return_value = True
self.proxy = proxy.HttpProxy(client, config=config) client = proxy.TcpClientConnection(self._conn, self._addr)
config = proxy.HttpProtocolConfig(auth_code=b'Basic %s' % base64.b64encode(b'user:pass'))
config.plugins = proxy.load_plugins('proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin')
self.proxy = proxy.HttpProtocolHandler(client, config=config)
self.proxy.client.conn.queue(b'GET http://localhost:%d HTTP/1.1' % self.http_server_port) self.proxy.client.conn.queue(b'GET http://localhost:%d HTTP/1.1' % self.http_server_port)
self.proxy.run_once() self.proxy.run_once()
self.assertEqual(self.proxy.request.state, proxy.HttpParser.states.INITIALIZED) self.assertEqual(self.proxy.request.state, proxy.HttpParser.states.INITIALIZED)
@ -731,9 +736,10 @@ class TestProxy(unittest.TestCase):
server.connect.return_value = True server.connect.return_value = True
mock_select.side_effect = [([self._conn], [], []), ([self._conn], [], []), ([], [server.conn], [])] mock_select.side_effect = [([self._conn], [], []), ([self._conn], [], []), ([], [server.conn], [])]
self.proxy = proxy.HttpProxy(proxy.TcpClientConnection(self._conn, self._addr), config = proxy.HttpProtocolConfig(auth_code=b'Basic %s' % base64.b64encode(b'user:pass'))
config=proxy.HttpProxyConfig( config.plugins = proxy.load_plugins('proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin')
auth_code=b'Basic %s' % base64.b64encode(b'user:pass'))) self.proxy = proxy.HttpProtocolHandler(proxy.TcpClientConnection(self._conn, self._addr),
config=config)
self.proxy.client.conn.queue(proxy.CRLF.join([ self.proxy.client.conn.queue(proxy.CRLF.join([
b'CONNECT localhost:%d HTTP/1.1' % self.http_server_port, b'CONNECT localhost:%d HTTP/1.1' % self.http_server_port,
b'Host: localhost:%d' % self.http_server_port, b'Host: localhost:%d' % self.http_server_port,
@ -743,7 +749,7 @@ class TestProxy(unittest.TestCase):
proxy.CRLF proxy.CRLF
])) ]))
self.proxy.run_once() self.proxy.run_once()
self.assertFalse(self.proxy.server is None) self.assertFalse(self.proxy.plugins['HttpProxyPlugin'].server is None)
self.assertEqual(self.proxy.client.buffer, proxy.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT) self.assertEqual(self.proxy.client.buffer, proxy.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT)
mock_server_connection.assert_called_once() mock_server_connection.assert_called_once()
server.connect.assert_called_once() server.connect.assert_called_once()
@ -780,15 +786,15 @@ class TestWorker(unittest.TestCase):
self.queue = multiprocessing.Queue() self.queue = multiprocessing.Queue()
self.worker = proxy.Worker(self.queue) self.worker = proxy.Worker(self.queue)
@mock.patch('proxy.HttpProxy') @mock.patch('proxy.HttpProtocolHandler')
def test_shutdown_op(self, mock_http_proxy): def test_shutdown_op(self, mock_http_proxy):
self.queue.put((proxy.Worker.operations.SHUTDOWN, None)) self.queue.put((proxy.Worker.operations.SHUTDOWN, None))
self.worker.run() # Worker should consume the prior shutdown operation self.worker.run() # Worker should consume the prior shutdown operation
self.assertFalse(mock_http_proxy.called) self.assertFalse(mock_http_proxy.called)
@mock.patch('proxy.HttpProxy') @mock.patch('proxy.HttpProtocolHandler')
def test_spawns_http_proxy_threads(self, mock_http_proxy): def test_spawns_http_proxy_threads(self, mock_http_proxy):
self.queue.put((proxy.Worker.operations.HTTP_PROXY, None)) self.queue.put((proxy.Worker.operations.HTTP_PROTOCOL, None))
self.queue.put((proxy.Worker.operations.SHUTDOWN, None)) self.queue.put((proxy.Worker.operations.SHUTDOWN, None))
self.worker.run() self.worker.run()
self.assertTrue(mock_http_proxy.called) self.assertTrue(mock_http_proxy.called)