Convert HttpProxy and HttpWebServer itself into plugins.
Load plugins during test execution Further decouple proxy/webserver logic outside of HttpProtocolHandler. Per connection plugin instances to avoid locks Handle BrokenPipeError and teardown if read_from_descriptors return True
This commit is contained in:
parent
8732cb7151
commit
6ea42b0dd9
|
@ -1,30 +1,38 @@
|
|||
from urllib import parse as urlparse
|
||||
from proxy import HttpProxyPlugin, ProxyRejectRequest
|
||||
from proxy import HttpProtocolBasePlugin, ProxyRequestRejected
|
||||
|
||||
|
||||
class RedirectToCustomServerPlugin(HttpProxyPlugin):
|
||||
class RedirectToCustomServerPlugin(HttpProtocolBasePlugin):
|
||||
"""Modifies client request to redirect all incoming requests to a fixed server address."""
|
||||
|
||||
def __init__(self):
|
||||
super(RedirectToCustomServerPlugin, self).__init__()
|
||||
|
||||
def handle_request(self, request):
|
||||
if request.method != 'CONNECT':
|
||||
request.url = urlparse.urlsplit(b'http://localhost:9999')
|
||||
return request
|
||||
def on_request_complete(self):
|
||||
if self.request.method != 'CONNECT':
|
||||
self.request.url = urlparse.urlsplit(b'http://localhost:9999')
|
||||
|
||||
|
||||
class FilterByTargetDomainPlugin(HttpProxyPlugin):
|
||||
class FilterByTargetDomainPlugin(HttpProtocolBasePlugin):
|
||||
"""Only accepts specific requests dropping all other requests."""
|
||||
|
||||
def __init__(self):
|
||||
super(FilterByTargetDomainPlugin, self).__init__()
|
||||
self.allowed_domains = [b'google.com', b'www.google.com', b'google.com:443', b'www.google.com:443']
|
||||
|
||||
def handle_request(self, request):
|
||||
def on_request_complete(self):
|
||||
# TODO: Refactor internals to cleanup mess below, due to how urlparse works, hostname/path attributes
|
||||
# are not consistent between CONNECT and non-CONNECT requests.
|
||||
if (request.method != b'CONNECT' and request.url.hostname not in self.allowed_domains) or \
|
||||
(request.method == b'CONNECT' and request.url.path not in self.allowed_domains):
|
||||
raise ProxyRejectRequest(status_code=418, body='I\'m a tea pot')
|
||||
return request
|
||||
if (self.request.method != b'CONNECT' and self.request.url.hostname not in self.allowed_domains) or \
|
||||
(self.request.method == b'CONNECT' and self.request.url.path not in self.allowed_domains):
|
||||
raise ProxyRequestRejected(status_code=418, body='I\'m a tea pot')
|
||||
|
||||
|
||||
class SaveHttpResponses(HttpProtocolBasePlugin):
|
||||
"""Saves Http Responses locally on disk."""
|
||||
|
||||
def __init__(self):
|
||||
super(SaveHttpResponses, self).__init__()
|
||||
|
||||
def handle_response_chunk(self, chunk):
|
||||
return chunk
|
||||
|
|
596
proxy.py
596
proxy.py
|
@ -22,6 +22,7 @@ import socket
|
|||
import sys
|
||||
import threading
|
||||
from collections import namedtuple
|
||||
from typing import Dict, List
|
||||
|
||||
if os.name != 'nt':
|
||||
import resource
|
||||
|
@ -53,7 +54,7 @@ else: # pragma: no cover
|
|||
# Defaults
|
||||
DEFAULT_BACKLOG = 100
|
||||
DEFAULT_BASIC_AUTH = None
|
||||
DEFAULT_BUFFER_SIZE = 8192
|
||||
DEFAULT_BUFFER_SIZE = 1024 * 1024
|
||||
DEFAULT_CLIENT_RECVBUF_SIZE = DEFAULT_BUFFER_SIZE
|
||||
DEFAULT_SERVER_RECVBUF_SIZE = DEFAULT_BUFFER_SIZE
|
||||
DEFAULT_IPV4_HOSTNAME = '127.0.0.1'
|
||||
|
@ -63,10 +64,12 @@ DEFAULT_IPV4 = False
|
|||
DEFAULT_LOG_LEVEL = 'INFO'
|
||||
DEFAULT_OPEN_FILE_LIMIT = 1024
|
||||
DEFAULT_PAC_FILE = None
|
||||
DEFAULT_PAC_FILE_URL_PATH = '/'
|
||||
DEFAULT_NUM_WORKERS = 0
|
||||
DEFAULT_PLUGINS = []
|
||||
DEFAULT_PLUGINS = {}
|
||||
DEFAULT_LOG_FORMAT = '%(asctime)s - %(levelname)s - pid:%(process)d - %(funcName)s:%(lineno)d - %(message)s'
|
||||
|
||||
# Set to True if under test
|
||||
UNDER_TEST = False
|
||||
|
||||
|
||||
|
@ -91,7 +94,7 @@ def bytes_(s, encoding='utf-8', errors='strict'): # pragma: no cover
|
|||
|
||||
|
||||
version = bytes_(__version__)
|
||||
CRLF, COLON, SP = b'\r\n', b':', b' '
|
||||
CRLF, COLON, WHITESPACE = b'\r\n', b':', b' '
|
||||
PROXY_AGENT_HEADER = b'Proxy-agent: proxy.py v' + version
|
||||
|
||||
PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT = CRLF.join([
|
||||
|
@ -123,31 +126,31 @@ class ChunkParser(object):
|
|||
self.chunk = b'' # Partial chunk received
|
||||
self.size = None # Expected size of next following chunk
|
||||
|
||||
def parse(self, data):
|
||||
more = True if len(data) > 0 else False
|
||||
def parse(self, raw):
|
||||
more = True if len(raw) > 0 else False
|
||||
while more:
|
||||
more, data = self.process(data)
|
||||
more, raw = self.process(raw)
|
||||
|
||||
def process(self, data):
|
||||
def process(self, raw):
|
||||
if self.state == ChunkParser.states.WAITING_FOR_SIZE:
|
||||
# Consume prior chunk in buffer
|
||||
# in case chunk size without CRLF was received
|
||||
data = self.chunk + data
|
||||
raw = self.chunk + raw
|
||||
self.chunk = b''
|
||||
# Extract following chunk data size
|
||||
line, data = HttpParser.split(data)
|
||||
line, raw = HttpParser.split(raw)
|
||||
if not line: # CRLF not received
|
||||
self.chunk = data
|
||||
data = b''
|
||||
self.chunk = raw
|
||||
raw = b''
|
||||
else:
|
||||
self.size = int(line, 16)
|
||||
self.state = ChunkParser.states.WAITING_FOR_DATA
|
||||
elif self.state == ChunkParser.states.WAITING_FOR_DATA:
|
||||
remaining = self.size - len(self.chunk)
|
||||
self.chunk += data[:remaining]
|
||||
data = data[remaining:]
|
||||
self.chunk += raw[:remaining]
|
||||
raw = raw[remaining:]
|
||||
if len(self.chunk) == self.size:
|
||||
data = data[len(CRLF):]
|
||||
raw = raw[len(CRLF):]
|
||||
self.body += self.chunk
|
||||
if self.size == 0:
|
||||
self.state = ChunkParser.states.COMPLETE
|
||||
|
@ -155,7 +158,7 @@ class ChunkParser(object):
|
|||
self.state = ChunkParser.states.WAITING_FOR_SIZE
|
||||
self.chunk = b''
|
||||
self.size = None
|
||||
return len(data) > 0, data
|
||||
return len(raw) > 0, raw
|
||||
|
||||
|
||||
class HttpParser(object):
|
||||
|
@ -179,7 +182,7 @@ class HttpParser(object):
|
|||
self.type = parser_type
|
||||
self.state = HttpParser.states.INITIALIZED
|
||||
|
||||
self.raw = b''
|
||||
self.bytes = b''
|
||||
self.buffer = b''
|
||||
|
||||
self.headers = dict()
|
||||
|
@ -193,22 +196,38 @@ class HttpParser(object):
|
|||
|
||||
self.chunk_parser = None
|
||||
|
||||
# This cleans up developer APIs as Python urlparse.urlsplit behaves differently
|
||||
# for incoming proxy request and incoming web request. Web request is the one
|
||||
# which is broken.
|
||||
self.host = None
|
||||
self.port = None
|
||||
|
||||
def set_host_port(self):
|
||||
if self.type == HttpParser.types.REQUEST_PARSER:
|
||||
if self.method == b'CONNECT':
|
||||
self.host, self.port = self.url.path.split(COLON)
|
||||
elif self.url:
|
||||
self.host, self.port = self.url.hostname, self.url.port \
|
||||
if self.url.port else 80
|
||||
else:
|
||||
raise Exception('Invalid request\n%s' % self.bytes)
|
||||
|
||||
def is_chunked_encoded_response(self):
|
||||
return self.type == HttpParser.types.RESPONSE_PARSER and \
|
||||
b'transfer-encoding' in self.headers and \
|
||||
self.headers[b'transfer-encoding'][1].lower() == b'chunked'
|
||||
|
||||
def parse(self, data):
|
||||
self.raw += data
|
||||
data = self.buffer + data
|
||||
def parse(self, raw):
|
||||
self.bytes += raw
|
||||
raw = self.buffer + raw
|
||||
self.buffer = b''
|
||||
|
||||
more = True if len(data) > 0 else False
|
||||
more = True if len(raw) > 0 else False
|
||||
while more:
|
||||
more, data = self.process(data)
|
||||
self.buffer = data
|
||||
more, raw = self.process(raw)
|
||||
self.buffer = raw
|
||||
|
||||
def process(self, data):
|
||||
def process(self, raw):
|
||||
if self.state in (HttpParser.states.HEADERS_COMPLETE,
|
||||
HttpParser.states.RCVING_BODY,
|
||||
HttpParser.states.COMPLETE) and \
|
||||
|
@ -218,22 +237,22 @@ class HttpParser(object):
|
|||
|
||||
if b'content-length' in self.headers:
|
||||
self.state = HttpParser.states.RCVING_BODY
|
||||
self.body += data
|
||||
self.body += raw
|
||||
if len(self.body) >= int(self.headers[b'content-length'][1]):
|
||||
self.state = HttpParser.states.COMPLETE
|
||||
elif self.is_chunked_encoded_response():
|
||||
if not self.chunk_parser:
|
||||
self.chunk_parser = ChunkParser()
|
||||
self.chunk_parser.parse(data)
|
||||
self.chunk_parser.parse(raw)
|
||||
if self.chunk_parser.state == ChunkParser.states.COMPLETE:
|
||||
self.body = self.chunk_parser.body
|
||||
self.state = HttpParser.states.COMPLETE
|
||||
|
||||
return False, b''
|
||||
|
||||
line, data = HttpParser.split(data)
|
||||
line, raw = HttpParser.split(raw)
|
||||
if line is False:
|
||||
return line, data
|
||||
return line, raw
|
||||
|
||||
if self.state == HttpParser.states.INITIALIZED:
|
||||
self.process_line(line)
|
||||
|
@ -245,7 +264,7 @@ class HttpParser(object):
|
|||
if self.state == HttpParser.states.LINE_RCVD and \
|
||||
self.type == HttpParser.types.REQUEST_PARSER and \
|
||||
self.method == b'CONNECT' and \
|
||||
data == CRLF:
|
||||
raw == CRLF:
|
||||
self.state = HttpParser.states.COMPLETE
|
||||
|
||||
# When raw request has ended with \r\n\r\n and no more http headers are expected
|
||||
|
@ -254,7 +273,7 @@ class HttpParser(object):
|
|||
elif self.state == HttpParser.states.HEADERS_COMPLETE and \
|
||||
self.type == HttpParser.types.REQUEST_PARSER and \
|
||||
self.method != b'POST' and \
|
||||
self.raw.endswith(CRLF * 2):
|
||||
self.bytes.endswith(CRLF * 2):
|
||||
self.state = HttpParser.states.COMPLETE
|
||||
elif self.state == HttpParser.states.HEADERS_COMPLETE and \
|
||||
self.type == HttpParser.types.REQUEST_PARSER and \
|
||||
|
@ -262,13 +281,13 @@ class HttpParser(object):
|
|||
(b'content-length' not in self.headers or
|
||||
(b'content-length' in self.headers and
|
||||
int(self.headers[b'content-length'][1]) == 0)) and \
|
||||
self.raw.endswith(CRLF * 2):
|
||||
self.bytes.endswith(CRLF * 2):
|
||||
self.state = HttpParser.states.COMPLETE
|
||||
|
||||
return len(data) > 0, data
|
||||
return len(raw) > 0, raw
|
||||
|
||||
def process_line(self, data):
|
||||
line = data.split(SP)
|
||||
def process_line(self, raw):
|
||||
line = raw.split(WHITESPACE)
|
||||
if self.type == HttpParser.types.REQUEST_PARSER:
|
||||
self.method = line[0].upper()
|
||||
self.url = urlparse.urlsplit(line[1])
|
||||
|
@ -277,17 +296,18 @@ class HttpParser(object):
|
|||
self.version = line[0]
|
||||
self.code = line[1]
|
||||
self.reason = b' '.join(line[2:])
|
||||
self.set_host_port()
|
||||
self.state = HttpParser.states.LINE_RCVD
|
||||
|
||||
def process_header(self, data):
|
||||
if len(data) == 0:
|
||||
def process_header(self, raw):
|
||||
if len(raw) == 0:
|
||||
if self.state == HttpParser.states.RCVING_HEADERS:
|
||||
self.state = HttpParser.states.HEADERS_COMPLETE
|
||||
elif self.state == HttpParser.states.LINE_RCVD:
|
||||
self.state = HttpParser.states.RCVING_HEADERS
|
||||
else:
|
||||
self.state = HttpParser.states.RCVING_HEADERS
|
||||
parts = data.split(COLON)
|
||||
parts = raw.split(COLON)
|
||||
key = parts[0].strip()
|
||||
value = COLON.join(parts[1:]).strip()
|
||||
self.headers[key.lower()] = (key, value)
|
||||
|
@ -331,13 +351,23 @@ class HttpParser(object):
|
|||
return k + b': ' + v
|
||||
|
||||
@staticmethod
|
||||
def split(data):
|
||||
pos = data.find(CRLF)
|
||||
def split(raw):
|
||||
pos = raw.find(CRLF)
|
||||
if pos == -1:
|
||||
return False, data
|
||||
line = data[:pos]
|
||||
data = data[pos + len(CRLF):]
|
||||
return line, data
|
||||
return False, raw
|
||||
line = raw[:pos]
|
||||
raw = raw[pos + len(CRLF):]
|
||||
return line, raw
|
||||
|
||||
###################################################################################
|
||||
# HttpParser was originally written to parse the incoming raw Http requests.
|
||||
# Since request / response objects passed to HttpProtocolBasePlugin methods
|
||||
# are also HttpParser objects, methods below were added to simplify developer API.
|
||||
####################################################################################
|
||||
|
||||
def has_upstream_server(self):
|
||||
"""Host field SHOULD be None for incoming local WebServer requests."""
|
||||
return True if self.host is not None else False
|
||||
|
||||
|
||||
class TcpConnection(object):
|
||||
|
@ -488,69 +518,224 @@ class ProxyRequestRejected(ProxyError):
|
|||
return CRLF.join(pkt) if len(pkt) > 0 else None
|
||||
|
||||
|
||||
class HttpProxyPlugin(object):
|
||||
"""Base HttpProxy Plugin class."""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def handle_request(self, request):
|
||||
"""Handle client request (HttpParser).
|
||||
|
||||
Return optionally modified client request (HttpParser) object.
|
||||
"""
|
||||
return request
|
||||
|
||||
def handle_response(self, data):
|
||||
"""Handle data chunks as received from the server."""
|
||||
return data
|
||||
|
||||
|
||||
class HttpProxyConfig(object):
|
||||
"""Holds various configuration values applicable to HttpProxy.
|
||||
class HttpProtocolConfig(object):
|
||||
"""Holds various configuration values applicable to HttpProtocolHandler.
|
||||
|
||||
This config class helps us avoid passing around bunch of key/value pairs across methods.
|
||||
"""
|
||||
|
||||
def __init__(self, auth_code=DEFAULT_BASIC_AUTH, server_recvbuf_size=DEFAULT_SERVER_RECVBUF_SIZE,
|
||||
client_recvbuf_size=DEFAULT_CLIENT_RECVBUF_SIZE, pac_file=DEFAULT_PAC_FILE, plugins=None):
|
||||
client_recvbuf_size=DEFAULT_CLIENT_RECVBUF_SIZE, pac_file=DEFAULT_PAC_FILE,
|
||||
pac_file_url_path=DEFAULT_PAC_FILE_URL_PATH, plugins=None):
|
||||
self.auth_code = auth_code
|
||||
self.server_recvbuf_size = server_recvbuf_size
|
||||
self.client_recvbuf_size = client_recvbuf_size
|
||||
self.pac_file = pac_file
|
||||
self.pac_file_url_path = pac_file_url_path
|
||||
if plugins is None:
|
||||
plugins = DEFAULT_PLUGINS
|
||||
self.plugins = plugins
|
||||
self.plugins: Dict = plugins
|
||||
|
||||
|
||||
class HttpProxy(threading.Thread):
|
||||
"""HTTP proxy implementation.
|
||||
class HttpProtocolBasePlugin(object):
|
||||
"""Base HttpProtocolHandler Plugin class.
|
||||
|
||||
Accepts `Client` connection object and act as a proxy between client and server.
|
||||
"""
|
||||
|
||||
def __init__(self, client, config=None):
|
||||
super(HttpProxy, self).__init__()
|
||||
|
||||
self.start_time = self.now()
|
||||
self.last_activity = self.start_time
|
||||
Implement various lifecycle event methods to customize behavior."""
|
||||
|
||||
def __init__(self, config: HttpProtocolConfig, client: TcpClientConnection, request: HttpParser):
|
||||
self.config = config
|
||||
self.client = client
|
||||
self.server = None
|
||||
self.config = config if config else HttpProxyConfig()
|
||||
self.request = request
|
||||
|
||||
self.request = HttpParser(HttpParser.types.REQUEST_PARSER)
|
||||
def name(self) -> str:
|
||||
"""A unique name for your plugin.
|
||||
|
||||
Defaults to name of the class. This helps plugin developers to directly
|
||||
access a specific plugin by its name."""
|
||||
return self.__class__.__name__
|
||||
|
||||
def get_descriptors(self):
|
||||
return [], [], []
|
||||
|
||||
def flush_to_descriptors(self, w):
|
||||
pass
|
||||
|
||||
def read_from_descriptors(self, r):
|
||||
pass
|
||||
|
||||
def on_client_data(self, raw: bytes):
|
||||
return raw
|
||||
|
||||
def on_request_complete(self):
|
||||
"""Called right after client request parser has reached COMPLETE state."""
|
||||
pass
|
||||
|
||||
def handle_response_chunk(self, chunk: bytes):
|
||||
"""Handle data chunks as received from the server.
|
||||
|
||||
Return optionally modified chunk to return back to client."""
|
||||
return chunk
|
||||
|
||||
def access_log(self):
|
||||
pass
|
||||
|
||||
def on_client_connection_close(self):
|
||||
pass
|
||||
|
||||
|
||||
class HttpProxyPlugin(HttpProtocolBasePlugin):
|
||||
"""HttpProtocolHandler plugin which implements HttpProxy specifications."""
|
||||
|
||||
def __init__(self, config: HttpProtocolConfig, client: TcpClientConnection, request: HttpParser):
|
||||
super(HttpProxyPlugin, self).__init__(config, client, request)
|
||||
self.server = None
|
||||
self.response = HttpParser(HttpParser.types.RESPONSE_PARSER)
|
||||
|
||||
@staticmethod
|
||||
def now():
|
||||
return datetime.datetime.utcnow()
|
||||
def get_descriptors(self):
|
||||
if not self.request.has_upstream_server():
|
||||
return [], [], []
|
||||
|
||||
def connection_inactive_for(self):
|
||||
return (self.now() - self.last_activity).seconds
|
||||
r, w, x = [], [], []
|
||||
if self.server and not self.server.closed:
|
||||
r.append(self.server.conn)
|
||||
if self.server and not self.server.closed and self.server.has_buffer():
|
||||
w.append(self.server.conn)
|
||||
return r, w, x
|
||||
|
||||
def is_connection_inactive(self):
|
||||
return self.connection_inactive_for() > 30
|
||||
def flush_to_descriptors(self, w):
|
||||
if not self.request.has_upstream_server():
|
||||
return
|
||||
|
||||
if self.server and not self.server.closed and self.server.conn in w:
|
||||
logger.debug('Server is ready for writes, flushing server buffer')
|
||||
self.server.flush()
|
||||
|
||||
def read_from_descriptors(self, r):
|
||||
if not self.request.has_upstream_server():
|
||||
return
|
||||
|
||||
if self.server and not self.server.closed and self.server.conn in r:
|
||||
logger.debug('Server is ready for reads, reading')
|
||||
raw = self.server.recv(self.config.server_recvbuf_size)
|
||||
# self.last_activity = HttpProtocolHandler.now()
|
||||
if not raw:
|
||||
logger.debug('Server closed connection, tearing down...')
|
||||
return True
|
||||
# parse incoming response packet
|
||||
# only for non-https requests
|
||||
if not self.request.method == b'CONNECT':
|
||||
self.response.parse(raw)
|
||||
else:
|
||||
# Only purpose of increasing memory footprint is to print response length in access log
|
||||
# Not worth it? Optimize to only persist lengths?
|
||||
self.response.bytes += raw
|
||||
# queue raw data for client
|
||||
self.client.queue(raw)
|
||||
|
||||
def on_client_connection_close(self):
|
||||
if not self.request.has_upstream_server():
|
||||
return
|
||||
|
||||
if self.server:
|
||||
logger.debug(
|
||||
'Closed server connection with pending server buffer size %d bytes' % self.server.buffer_size())
|
||||
if not self.server.closed:
|
||||
self.server.close()
|
||||
|
||||
def on_client_data(self, raw):
|
||||
if not self.request.has_upstream_server():
|
||||
return raw
|
||||
|
||||
if self.server and not self.server.closed:
|
||||
self.server.queue(raw)
|
||||
return None
|
||||
else:
|
||||
return raw
|
||||
|
||||
def on_request_complete(self):
|
||||
if not self.request.has_upstream_server():
|
||||
return
|
||||
|
||||
self.authenticate(self.request.headers)
|
||||
self.connect_upstream(self.request.host, self.request.port)
|
||||
# for http connect methods (https requests)
|
||||
# queue appropriate response for client
|
||||
# notifying about established connection
|
||||
if self.request.method == b'CONNECT':
|
||||
self.client.queue(PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT)
|
||||
# for general http requests, re-build request packet
|
||||
# and queue for the server with appropriate headers
|
||||
else:
|
||||
self.server.queue(self.request.build(
|
||||
del_headers=[b'proxy-authorization', b'proxy-connection', b'connection',
|
||||
b'keep-alive'],
|
||||
add_headers=[(b'Via', b'1.1 proxy.py v%s' % version), (b'Connection', b'Close')]
|
||||
))
|
||||
|
||||
def access_log(self):
|
||||
if not self.request.has_upstream_server():
|
||||
return
|
||||
|
||||
host, port = self.server.addr if self.server else (None, None)
|
||||
if self.request.method == b'CONNECT':
|
||||
logger.info(
|
||||
'%s:%s - %s %s:%s - %s bytes' % (self.client.addr[0], self.client.addr[1],
|
||||
text_(self.request.method), text_(host),
|
||||
text_(port), len(self.response.bytes)))
|
||||
elif self.request.method:
|
||||
logger.info('%s:%s - %s %s:%s%s - %s %s - %s bytes' % (
|
||||
self.client.addr[0], self.client.addr[1], text_(self.request.method), text_(host), port,
|
||||
text_(self.request.build_url()), text_(self.response.code), text_(self.response.reason),
|
||||
len(self.response.bytes)))
|
||||
|
||||
def authenticate(self, headers):
|
||||
if self.config.auth_code:
|
||||
if b'proxy-authorization' not in headers or \
|
||||
headers[b'proxy-authorization'][1] != self.config.auth_code:
|
||||
raise ProxyAuthenticationFailed()
|
||||
|
||||
def connect_upstream(self, host, port):
|
||||
self.server = TcpServerConnection(host, port)
|
||||
try:
|
||||
logger.debug('Connecting to upstream %s:%s' % (host, port))
|
||||
self.server.connect()
|
||||
logger.debug('Connected to upstream %s:%s' % (host, port))
|
||||
except Exception as e: # TimeoutError, socket.gaierror
|
||||
self.server.closed = True
|
||||
raise ProxyConnectionFailed(host, port, repr(e))
|
||||
|
||||
|
||||
class HttpWebServerPlugin(HttpProtocolBasePlugin):
|
||||
"""HttpProtocolHandler plugin which handles incoming requests to local webserver."""
|
||||
|
||||
def __init__(self, config: HttpProtocolConfig, client: TcpClientConnection, request: HttpParser):
|
||||
super(HttpWebServerPlugin, self).__init__(config, client, request)
|
||||
|
||||
def on_request_complete(self):
|
||||
if self.request.has_upstream_server():
|
||||
return
|
||||
|
||||
if self.config.pac_file and \
|
||||
self.request.url.path == self.config.pac_file_url_path:
|
||||
self.serve_pac_file()
|
||||
else:
|
||||
self.client.queue(CRLF.join([
|
||||
b'HTTP/1.1 200 OK',
|
||||
b'Server: proxy.py v%s' % version,
|
||||
b'Content-Length: 42',
|
||||
b'Connection: Close',
|
||||
CRLF,
|
||||
b'https://github.com/abhinavsingh/proxy.py'
|
||||
]))
|
||||
self.client.flush()
|
||||
|
||||
return True
|
||||
|
||||
def access_log(self):
|
||||
if self.request.has_upstream_server():
|
||||
return
|
||||
logger.info('%s:%s - %s %s' % (self.client.addr[0], self.client.addr[1],
|
||||
text_(self.request.method), text_(self.request.build_url())))
|
||||
|
||||
def serve_pac_file(self):
|
||||
self.client.queue(PAC_FILE_RESPONSE_PREFIX)
|
||||
|
@ -563,105 +748,91 @@ class HttpProxy(threading.Thread):
|
|||
self.client.queue(self.config.pac_file)
|
||||
self.client.flush()
|
||||
|
||||
def access_log(self):
|
||||
host, port = self.server.addr if self.server else (None, None)
|
||||
if self.request.method == b'CONNECT':
|
||||
logger.info(
|
||||
'%s:%s - %s %s:%s - %s bytes' % (self.client.addr[0], self.client.addr[1],
|
||||
text_(self.request.method), text_(host),
|
||||
text_(port), len(self.response.raw)))
|
||||
elif self.request.method:
|
||||
logger.info('%s:%s - %s %s:%s%s - %s %s - %s bytes' % (
|
||||
self.client.addr[0], self.client.addr[1], text_(self.request.method), text_(host), port,
|
||||
text_(self.request.build_url()), text_(self.response.code), text_(self.response.reason),
|
||||
len(self.response.raw)))
|
||||
|
||||
class HttpProtocolHandler(threading.Thread):
|
||||
"""HTTP, HTTPS, HTTP2, WebSockets protocol handler.
|
||||
|
||||
Accepts `Client` connection object, manages plugin invocations.
|
||||
"""
|
||||
|
||||
def __init__(self, client, config=None):
|
||||
super(HttpProtocolHandler, self).__init__()
|
||||
|
||||
self.start_time = self.now()
|
||||
self.last_activity = self.start_time
|
||||
|
||||
self.client = client
|
||||
self.config = config if config else HttpProtocolConfig()
|
||||
self.request = HttpParser(HttpParser.types.REQUEST_PARSER)
|
||||
|
||||
self.plugins: Dict[str, HttpProtocolBasePlugin] = {}
|
||||
for klass in self.config.plugins:
|
||||
instance = klass(self.config, self.client, self.request)
|
||||
self.plugins[instance.name()] = instance
|
||||
|
||||
@staticmethod
|
||||
def now():
|
||||
return datetime.datetime.utcnow()
|
||||
|
||||
def connection_inactive_for(self):
|
||||
return (self.now() - self.last_activity).seconds
|
||||
|
||||
def is_connection_inactive(self):
|
||||
return self.connection_inactive_for() > 30
|
||||
|
||||
def run_once(self):
|
||||
"""Returns True if proxy must teardown."""
|
||||
# Prepare list of descriptors
|
||||
rlist, wlist, xlist = [self.client.conn], [], []
|
||||
read_desc, write_desc, err_desc = [self.client.conn], [], []
|
||||
if self.client.has_buffer():
|
||||
wlist.append(self.client.conn)
|
||||
if self.server and not self.server.closed:
|
||||
rlist.append(self.server.conn)
|
||||
if self.server and not self.server.closed and self.server.has_buffer():
|
||||
wlist.append(self.server.conn)
|
||||
write_desc.append(self.client.conn)
|
||||
|
||||
r, w, x = select.select(rlist, wlist, xlist, 1)
|
||||
for plugin in self.plugins.values():
|
||||
plugin_read_desc, plugin_write_desc, plugin_err_desc = plugin.get_descriptors()
|
||||
read_desc += plugin_read_desc
|
||||
write_desc += plugin_write_desc
|
||||
err_desc += plugin_err_desc
|
||||
|
||||
# Flush buffer from ready to write sockets
|
||||
if self.client.conn in w:
|
||||
readable, writable, errored = select.select(read_desc, write_desc, err_desc, 1)
|
||||
|
||||
# Flush buffer for ready to write sockets
|
||||
if self.client.conn in writable:
|
||||
logger.debug('Client is ready for writes, flushing client buffer')
|
||||
try:
|
||||
self.client.flush()
|
||||
if self.server and not self.server.closed and self.server.conn in w:
|
||||
logger.debug('Server is ready for writes, flushing server buffer')
|
||||
self.server.flush()
|
||||
except BrokenPipeError:
|
||||
logging.error('BrokenPipeError when flushing buffer for client')
|
||||
return True
|
||||
|
||||
for plugin in self.plugins.values():
|
||||
plugin.flush_to_descriptors(writable)
|
||||
|
||||
# Read from ready to read sockets
|
||||
if self.client.conn in r:
|
||||
if self.client.conn in readable:
|
||||
logger.debug('Client is ready for reads, reading')
|
||||
data = self.client.recv(self.config.client_recvbuf_size)
|
||||
client_data = self.client.recv(self.config.client_recvbuf_size)
|
||||
self.last_activity = self.now()
|
||||
if not data:
|
||||
if not client_data:
|
||||
logger.debug('Client closed connection, tearing down...')
|
||||
return True
|
||||
|
||||
plugin_index = 0
|
||||
plugins = list(self.plugins.values())
|
||||
while plugin_index < len(plugins) and client_data:
|
||||
client_data = plugins[plugin_index].on_client_data(client_data)
|
||||
plugin_index += 1
|
||||
|
||||
if client_data:
|
||||
try:
|
||||
# Once we have connection to the server
|
||||
# we don't parse the http request packets
|
||||
# any further, instead just pipe incoming
|
||||
# data from client to server
|
||||
if self.server and not self.server.closed:
|
||||
self.server.queue(data)
|
||||
else:
|
||||
# Parse http request
|
||||
self.request.parse(data)
|
||||
self.request.parse(client_data)
|
||||
if self.request.state == HttpParser.states.COMPLETE:
|
||||
logger.debug('Request parser is in state complete')
|
||||
|
||||
if self.config.auth_code:
|
||||
if b'proxy-authorization' not in self.request.headers or \
|
||||
self.request.headers[b'proxy-authorization'][1] != self.config.auth_code:
|
||||
raise ProxyAuthenticationFailed()
|
||||
|
||||
# Invoke HttpProxyPlugin.handle_request
|
||||
for plugin in self.config.plugins:
|
||||
self.request = plugin.handle_request(self.request)
|
||||
|
||||
if self.request.method == b'CONNECT':
|
||||
host, port = self.request.url.path.split(COLON)
|
||||
elif self.request.url:
|
||||
host, port = self.request.url.hostname, self.request.url.port \
|
||||
if self.request.url.port else 80
|
||||
else:
|
||||
raise Exception('Invalid request\n%s' % self.request.raw)
|
||||
|
||||
if host is None and self.config.pac_file:
|
||||
self.serve_pac_file()
|
||||
for plugin in self.plugins.values():
|
||||
# TODO: Cleanup by not returning True for teardown cases
|
||||
plugin_response = plugin.on_request_complete()
|
||||
if type(plugin_response) is bool:
|
||||
return True
|
||||
|
||||
self.server = TcpServerConnection(host, port)
|
||||
try:
|
||||
logger.debug('Connecting to server %s:%s' % (host, port))
|
||||
self.server.connect()
|
||||
logger.debug('Connected to server %s:%s' % (host, port))
|
||||
except Exception as e: # TimeoutError, socket.gaierror
|
||||
self.server.closed = True
|
||||
raise ProxyConnectionFailed(host, port, repr(e))
|
||||
|
||||
# for http connect methods (https requests)
|
||||
# queue appropriate response for client
|
||||
# notifying about established connection
|
||||
if self.request.method == b'CONNECT':
|
||||
self.client.queue(PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT)
|
||||
# for usual http requests, re-build request packet
|
||||
# and queue for the server with appropriate headers
|
||||
else:
|
||||
self.server.queue(self.request.build(
|
||||
del_headers=[b'proxy-authorization', b'proxy-connection', b'connection',
|
||||
b'keep-alive'],
|
||||
add_headers=[(b'Via', b'1.1 proxy.py v%s' % version), (b'Connection', b'Close')]
|
||||
))
|
||||
except (ProxyAuthenticationFailed, ProxyConnectionFailed, ProxyRequestRejected) as e:
|
||||
except ProxyError as e: # ProxyAuthenticationFailed, ProxyConnectionFailed, ProxyRequestRejected
|
||||
# logger.exception(e)
|
||||
response = e.response(self.request)
|
||||
if response:
|
||||
|
@ -669,23 +840,11 @@ class HttpProxy(threading.Thread):
|
|||
# But is client also ready for writes?
|
||||
self.client.flush()
|
||||
raise e
|
||||
if self.server and not self.server.closed and self.server.conn in r:
|
||||
logger.debug('Server is ready for reads, reading')
|
||||
data = self.server.recv(self.config.server_recvbuf_size)
|
||||
self.last_activity = self.now()
|
||||
if not data:
|
||||
logger.debug('Server closed connection, tearing down...')
|
||||
|
||||
for plugin in self.plugins.values():
|
||||
teardown = plugin.read_from_descriptors(readable)
|
||||
if teardown:
|
||||
return True
|
||||
# parse incoming response packet
|
||||
# only for non-https requests
|
||||
if not self.request.method == b'CONNECT':
|
||||
self.response.parse(data)
|
||||
else:
|
||||
# Only purpose of increasing memory footprint is to print response length in access log
|
||||
# Not worth it? Optimize to only persist lengths?
|
||||
self.response.raw += data
|
||||
# queue data for client
|
||||
self.client.queue(data)
|
||||
|
||||
# Teardown if client buffer is empty and connection is inactive
|
||||
if self.client.buffer_size() == 0:
|
||||
|
@ -706,16 +865,17 @@ class HttpProxy(threading.Thread):
|
|||
except Exception as e:
|
||||
logger.exception('Exception while handling connection %r with reason %r' % (self.client.conn, e))
|
||||
finally:
|
||||
for plugin in self.plugins.values():
|
||||
plugin.access_log()
|
||||
|
||||
self.client.close()
|
||||
logger.debug(
|
||||
'Closed client connection with pending client buffer size %d bytes' % self.client.buffer_size())
|
||||
if self.server:
|
||||
logger.debug(
|
||||
'Closed server connection with pending server buffer size %d bytes' % self.server.buffer_size())
|
||||
if not self.server.closed:
|
||||
self.server.close()
|
||||
self.access_log()
|
||||
logger.debug('Closed proxy for connection %r at address %r' % (self.client.conn, self.client.addr))
|
||||
logger.debug('Closed client connection with pending '
|
||||
'client buffer size %d bytes' % self.client.buffer_size())
|
||||
for plugin in self.plugins.values():
|
||||
plugin.on_client_connection_close()
|
||||
|
||||
logger.debug('Closed proxy for connection %r '
|
||||
'at address %r' % (self.client.conn, self.client.addr))
|
||||
|
||||
|
||||
class TcpServer(object):
|
||||
|
@ -797,7 +957,7 @@ class MultiCoreRequestDispatcher(TcpServer):
|
|||
self.workers.append(worker)
|
||||
|
||||
def handle(self, client):
|
||||
self.worker_queue.put((Worker.operations.HTTP_PROXY, client))
|
||||
self.worker_queue.put((Worker.operations.HTTP_PROTOCOL, client))
|
||||
|
||||
def shutdown(self):
|
||||
logger.info('Shutting down %d workers' % self.num_workers)
|
||||
|
@ -815,7 +975,7 @@ class Worker(multiprocessing.Process):
|
|||
"""
|
||||
|
||||
operations = namedtuple('WorkerOperations', (
|
||||
'HTTP_PROXY',
|
||||
'HTTP_PROTOCOL',
|
||||
'SHUTDOWN',
|
||||
))(1, 2)
|
||||
|
||||
|
@ -828,8 +988,8 @@ class Worker(multiprocessing.Process):
|
|||
while True:
|
||||
try:
|
||||
op, payload = self.work_queue.get(True, 1)
|
||||
if op == Worker.operations.HTTP_PROXY:
|
||||
proxy = HttpProxy(payload, config=self.config)
|
||||
if op == Worker.operations.HTTP_PROTOCOL:
|
||||
proxy = HttpProtocolHandler(payload, config=self.config)
|
||||
proxy.setDaemon(True)
|
||||
proxy.start()
|
||||
elif op == Worker.operations.SHUTDOWN:
|
||||
|
@ -852,6 +1012,22 @@ def set_open_file_limit(soft_limit):
|
|||
logger.debug('Open file descriptor soft limit set to %d' % soft_limit)
|
||||
|
||||
|
||||
def load_plugins(plugins) -> List:
|
||||
p = []
|
||||
plugins = plugins.split(',')
|
||||
for plugin in plugins:
|
||||
plugin = plugin.strip()
|
||||
if plugin == '':
|
||||
continue
|
||||
logging.debug('Loading plugin %s', plugin)
|
||||
module_name, klass_name = plugin.rsplit('.', 1)
|
||||
module = importlib.import_module(module_name)
|
||||
klass = getattr(module, klass_name)
|
||||
logging.info('%s initialized' % klass.__name__)
|
||||
p.append(klass)
|
||||
return p
|
||||
|
||||
|
||||
def init_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='proxy.py v%s' % __version__,
|
||||
|
@ -882,12 +1058,14 @@ def init_parser():
|
|||
parser.add_argument('--open-file-limit', type=int, default=DEFAULT_OPEN_FILE_LIMIT,
|
||||
help='Default: 1024. Maximum number of files (TCP connections) '
|
||||
'that proxy.py can open concurrently.')
|
||||
parser.add_argument('--port', type=int, default=DEFAULT_PORT,
|
||||
help='Default: 8899. Server port.')
|
||||
parser.add_argument('--pac-file', type=str, default=DEFAULT_PAC_FILE,
|
||||
help='A file (Proxy Auto Configuration) or string to serve when '
|
||||
'the server receives a direct file request.')
|
||||
parser.add_argument('--plugins', type=str, default=None, help='Comma separated plugins')
|
||||
parser.add_argument('--pac-file-url-path', type=str, default=DEFAULT_PAC_FILE_URL_PATH,
|
||||
help='Web server path to serve the PAC file.')
|
||||
parser.add_argument('--plugins', type=str, default='', help='Comma separated plugins')
|
||||
parser.add_argument('--port', type=int, default=DEFAULT_PORT,
|
||||
help='Default: 8899. Server port.')
|
||||
parser.add_argument('--server-recvbuf-size', type=int, default=DEFAULT_SERVER_RECVBUF_SIZE,
|
||||
help='Default: 8 KB. Maximum amount of data received from the '
|
||||
'server in a single recv() operation. Bump this '
|
||||
|
@ -898,18 +1076,6 @@ def init_parser():
|
|||
return parser
|
||||
|
||||
|
||||
def load_plugins(plugins):
|
||||
p = []
|
||||
plugins = plugins.split(',')
|
||||
for plugin in plugins:
|
||||
module_name, klass_name = plugin.rsplit('.', 1)
|
||||
module = importlib.import_module(module_name)
|
||||
klass = getattr(module, klass_name)
|
||||
logging.info('%s initialized' % klass.__name__)
|
||||
p.append(klass())
|
||||
return p
|
||||
|
||||
|
||||
def main():
|
||||
if not PY3 and not UNDER_TEST:
|
||||
print(
|
||||
|
@ -924,7 +1090,6 @@ def main():
|
|||
|
||||
parser = init_parser()
|
||||
args = parser.parse_args(sys.argv[1:])
|
||||
|
||||
logging.basicConfig(level=getattr(logging,
|
||||
{
|
||||
'D': 'DEBUG',
|
||||
|
@ -934,7 +1099,6 @@ def main():
|
|||
'C': 'CRITICAL'
|
||||
}[args.log_level.upper()[0]]),
|
||||
format=args.log_format)
|
||||
plugins = load_plugins(args.plugins) if args.plugins else DEFAULT_PLUGINS
|
||||
|
||||
try:
|
||||
set_open_file_limit(args.open_file_limit)
|
||||
|
@ -943,16 +1107,18 @@ def main():
|
|||
if args.basic_auth:
|
||||
auth_code = b'Basic %s' % base64.b64encode(bytes_(args.basic_auth))
|
||||
|
||||
config = HttpProtocolConfig(auth_code=auth_code,
|
||||
server_recvbuf_size=args.server_recvbuf_size,
|
||||
client_recvbuf_size=args.client_recvbuf_size,
|
||||
pac_file=args.pac_file,
|
||||
pac_file_url_path=args.pac_file_url_path)
|
||||
config.plugins = load_plugins('proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin,%s' % args.plugins)
|
||||
server = MultiCoreRequestDispatcher(hostname=args.hostname,
|
||||
port=args.port,
|
||||
backlog=args.backlog,
|
||||
ipv4=args.ipv4,
|
||||
num_workers=args.num_workers,
|
||||
config=HttpProxyConfig(auth_code=auth_code,
|
||||
server_recvbuf_size=args.server_recvbuf_size,
|
||||
client_recvbuf_size=args.client_recvbuf_size,
|
||||
pac_file=args.pac_file,
|
||||
plugins=plugins))
|
||||
config=config)
|
||||
server.run()
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
|
|
@ -1 +0,0 @@
|
|||
mock==3.0.5
|
48
tests.py
48
tests.py
|
@ -3,9 +3,9 @@
|
|||
proxy.py
|
||||
~~~~~~~~
|
||||
|
||||
HTTP Proxy Server in Python.
|
||||
HTTP, HTTPS, HTTP2 and WebSockets Proxy Server in Python.
|
||||
|
||||
:copyright: (c) 2013-2018 by Abhinav Singh.
|
||||
:copyright: (c) 2013-2020 by Abhinav Singh.
|
||||
:license: BSD, see LICENSE for more details.
|
||||
"""
|
||||
import os
|
||||
|
@ -21,8 +21,8 @@ from threading import Thread
|
|||
|
||||
import proxy
|
||||
|
||||
# logging.basicConfig(level=logging.DEBUG,
|
||||
# format='%(asctime)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s')
|
||||
logging.basicConfig(level=logging.DEBUG,
|
||||
format='%(asctime)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s')
|
||||
|
||||
if sys.version_info[0] == 3: # Python3 specific imports
|
||||
from http.server import HTTPServer, BaseHTTPRequestHandler
|
||||
|
@ -123,7 +123,7 @@ class TestMultiCoreRequestDispatcher(unittest.TestCase):
|
|||
tcp_server = None
|
||||
tcp_thread = None
|
||||
|
||||
@mock.patch.object(proxy, 'HttpProxy', side_effect=mock_tcp_proxy_side_effect)
|
||||
@mock.patch.object(proxy, 'HttpProtocolHandler', side_effect=mock_tcp_proxy_side_effect)
|
||||
def testHttpProxyConnection(self, mock_tcp_proxy):
|
||||
try:
|
||||
self.tcp_port = get_available_port()
|
||||
|
@ -523,6 +523,7 @@ class TestProxy(unittest.TestCase):
|
|||
http_server = None
|
||||
http_server_port = None
|
||||
http_server_thread = None
|
||||
config = None
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
|
@ -531,6 +532,8 @@ class TestProxy(unittest.TestCase):
|
|||
cls.http_server_thread = Thread(target=cls.http_server.serve_forever)
|
||||
cls.http_server_thread.setDaemon(True)
|
||||
cls.http_server_thread.start()
|
||||
cls.config = proxy.HttpProtocolConfig()
|
||||
cls.config.plugins = proxy.load_plugins('proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin')
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
|
@ -541,7 +544,7 @@ class TestProxy(unittest.TestCase):
|
|||
def setUp(self):
|
||||
self._conn = MockTcpConnection()
|
||||
self._addr = ('127.0.0.1', 54382)
|
||||
self.proxy = proxy.HttpProxy(proxy.TcpClientConnection(self._conn, self._addr))
|
||||
self.proxy = proxy.HttpProtocolHandler(proxy.TcpClientConnection(self._conn, self._addr), config=self.config)
|
||||
|
||||
@mock.patch('select.select')
|
||||
@mock.patch('proxy.TcpServerConnection')
|
||||
|
@ -610,7 +613,7 @@ class TestProxy(unittest.TestCase):
|
|||
proxy.CRLF
|
||||
]))
|
||||
self.proxy.run_once()
|
||||
self.assertFalse(self.proxy.server is None)
|
||||
self.assertFalse(self.proxy.plugins['HttpProxyPlugin'].server is None)
|
||||
self.assertEqual(self.proxy.client.buffer, proxy.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT)
|
||||
mock_server_connection.assert_called_once()
|
||||
server.connect.assert_called_once()
|
||||
|
@ -659,9 +662,10 @@ class TestProxy(unittest.TestCase):
|
|||
@mock.patch('select.select')
|
||||
def test_proxy_authentication_failed(self, mock_select):
|
||||
mock_select.return_value = ([self._conn], [], [])
|
||||
self.proxy = proxy.HttpProxy(proxy.TcpClientConnection(self._conn, self._addr),
|
||||
config=proxy.HttpProxyConfig(
|
||||
auth_code=b'Basic %s' % base64.b64encode(b'user:pass')))
|
||||
config = proxy.HttpProtocolConfig(auth_code=b'Basic %s' % base64.b64encode(b'user:pass'))
|
||||
config.plugins = proxy.load_plugins('proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin')
|
||||
self.proxy = proxy.HttpProtocolHandler(proxy.TcpClientConnection(self._conn, self._addr),
|
||||
config=config)
|
||||
self.proxy.client.conn.queue(proxy.CRLF.join([
|
||||
b'GET http://abhinavsingh.com HTTP/1.1',
|
||||
b'Host: abhinavsingh.com',
|
||||
|
@ -673,14 +677,15 @@ class TestProxy(unittest.TestCase):
|
|||
@mock.patch('select.select')
|
||||
@mock.patch('proxy.TcpServerConnection')
|
||||
def test_authenticated_proxy_http_get(self, mock_server_connection, mock_select):
|
||||
client = proxy.TcpClientConnection(self._conn, self._addr)
|
||||
config = proxy.HttpProxyConfig(auth_code=b'Basic %s' % base64.b64encode(b'user:pass'))
|
||||
|
||||
mock_select.return_value = ([self._conn], [], [])
|
||||
server = mock_server_connection.return_value
|
||||
server.connect.return_value = True
|
||||
|
||||
self.proxy = proxy.HttpProxy(client, config=config)
|
||||
client = proxy.TcpClientConnection(self._conn, self._addr)
|
||||
config = proxy.HttpProtocolConfig(auth_code=b'Basic %s' % base64.b64encode(b'user:pass'))
|
||||
config.plugins = proxy.load_plugins('proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin')
|
||||
|
||||
self.proxy = proxy.HttpProtocolHandler(client, config=config)
|
||||
self.proxy.client.conn.queue(b'GET http://localhost:%d HTTP/1.1' % self.http_server_port)
|
||||
self.proxy.run_once()
|
||||
self.assertEqual(self.proxy.request.state, proxy.HttpParser.states.INITIALIZED)
|
||||
|
@ -731,9 +736,10 @@ class TestProxy(unittest.TestCase):
|
|||
server.connect.return_value = True
|
||||
mock_select.side_effect = [([self._conn], [], []), ([self._conn], [], []), ([], [server.conn], [])]
|
||||
|
||||
self.proxy = proxy.HttpProxy(proxy.TcpClientConnection(self._conn, self._addr),
|
||||
config=proxy.HttpProxyConfig(
|
||||
auth_code=b'Basic %s' % base64.b64encode(b'user:pass')))
|
||||
config = proxy.HttpProtocolConfig(auth_code=b'Basic %s' % base64.b64encode(b'user:pass'))
|
||||
config.plugins = proxy.load_plugins('proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin')
|
||||
self.proxy = proxy.HttpProtocolHandler(proxy.TcpClientConnection(self._conn, self._addr),
|
||||
config=config)
|
||||
self.proxy.client.conn.queue(proxy.CRLF.join([
|
||||
b'CONNECT localhost:%d HTTP/1.1' % self.http_server_port,
|
||||
b'Host: localhost:%d' % self.http_server_port,
|
||||
|
@ -743,7 +749,7 @@ class TestProxy(unittest.TestCase):
|
|||
proxy.CRLF
|
||||
]))
|
||||
self.proxy.run_once()
|
||||
self.assertFalse(self.proxy.server is None)
|
||||
self.assertFalse(self.proxy.plugins['HttpProxyPlugin'].server is None)
|
||||
self.assertEqual(self.proxy.client.buffer, proxy.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT)
|
||||
mock_server_connection.assert_called_once()
|
||||
server.connect.assert_called_once()
|
||||
|
@ -780,15 +786,15 @@ class TestWorker(unittest.TestCase):
|
|||
self.queue = multiprocessing.Queue()
|
||||
self.worker = proxy.Worker(self.queue)
|
||||
|
||||
@mock.patch('proxy.HttpProxy')
|
||||
@mock.patch('proxy.HttpProtocolHandler')
|
||||
def test_shutdown_op(self, mock_http_proxy):
|
||||
self.queue.put((proxy.Worker.operations.SHUTDOWN, None))
|
||||
self.worker.run() # Worker should consume the prior shutdown operation
|
||||
self.assertFalse(mock_http_proxy.called)
|
||||
|
||||
@mock.patch('proxy.HttpProxy')
|
||||
@mock.patch('proxy.HttpProtocolHandler')
|
||||
def test_spawns_http_proxy_threads(self, mock_http_proxy):
|
||||
self.queue.put((proxy.Worker.operations.HTTP_PROXY, None))
|
||||
self.queue.put((proxy.Worker.operations.HTTP_PROTOCOL, None))
|
||||
self.queue.put((proxy.Worker.operations.SHUTDOWN, None))
|
||||
self.worker.run()
|
||||
self.assertTrue(mock_http_proxy.called)
|
||||
|
|
Loading…
Reference in New Issue