Convert HttpProxy and HttpWebServer itself into plugins.

Load plugins during test execution

Further decouple proxy/webserver logic outside of HttpProtocolHandler.

Per connection plugin instances to avoid locks

Handle BrokenPipeError and teardown if read_from_descriptors return True
This commit is contained in:
Abhinav Singh 2019-08-20 16:10:05 -07:00
parent 8732cb7151
commit 6ea42b0dd9
4 changed files with 437 additions and 258 deletions

View File

@ -1,30 +1,38 @@
from urllib import parse as urlparse
from proxy import HttpProxyPlugin, ProxyRejectRequest
from proxy import HttpProtocolBasePlugin, ProxyRequestRejected
class RedirectToCustomServerPlugin(HttpProxyPlugin):
class RedirectToCustomServerPlugin(HttpProtocolBasePlugin):
"""Modifies client request to redirect all incoming requests to a fixed server address."""
def __init__(self):
super(RedirectToCustomServerPlugin, self).__init__()
def handle_request(self, request):
if request.method != 'CONNECT':
request.url = urlparse.urlsplit(b'http://localhost:9999')
return request
def on_request_complete(self):
if self.request.method != 'CONNECT':
self.request.url = urlparse.urlsplit(b'http://localhost:9999')
class FilterByTargetDomainPlugin(HttpProxyPlugin):
class FilterByTargetDomainPlugin(HttpProtocolBasePlugin):
"""Only accepts specific requests dropping all other requests."""
def __init__(self):
super(FilterByTargetDomainPlugin, self).__init__()
self.allowed_domains = [b'google.com', b'www.google.com', b'google.com:443', b'www.google.com:443']
def handle_request(self, request):
def on_request_complete(self):
# TODO: Refactor internals to cleanup mess below, due to how urlparse works, hostname/path attributes
# are not consistent between CONNECT and non-CONNECT requests.
if (request.method != b'CONNECT' and request.url.hostname not in self.allowed_domains) or \
(request.method == b'CONNECT' and request.url.path not in self.allowed_domains):
raise ProxyRejectRequest(status_code=418, body='I\'m a tea pot')
return request
if (self.request.method != b'CONNECT' and self.request.url.hostname not in self.allowed_domains) or \
(self.request.method == b'CONNECT' and self.request.url.path not in self.allowed_domains):
raise ProxyRequestRejected(status_code=418, body='I\'m a tea pot')
class SaveHttpResponses(HttpProtocolBasePlugin):
"""Saves Http Responses locally on disk."""
def __init__(self):
super(SaveHttpResponses, self).__init__()
def handle_response_chunk(self, chunk):
return chunk

596
proxy.py
View File

@ -22,6 +22,7 @@ import socket
import sys
import threading
from collections import namedtuple
from typing import Dict, List
if os.name != 'nt':
import resource
@ -53,7 +54,7 @@ else: # pragma: no cover
# Defaults
DEFAULT_BACKLOG = 100
DEFAULT_BASIC_AUTH = None
DEFAULT_BUFFER_SIZE = 8192
DEFAULT_BUFFER_SIZE = 1024 * 1024
DEFAULT_CLIENT_RECVBUF_SIZE = DEFAULT_BUFFER_SIZE
DEFAULT_SERVER_RECVBUF_SIZE = DEFAULT_BUFFER_SIZE
DEFAULT_IPV4_HOSTNAME = '127.0.0.1'
@ -63,10 +64,12 @@ DEFAULT_IPV4 = False
DEFAULT_LOG_LEVEL = 'INFO'
DEFAULT_OPEN_FILE_LIMIT = 1024
DEFAULT_PAC_FILE = None
DEFAULT_PAC_FILE_URL_PATH = '/'
DEFAULT_NUM_WORKERS = 0
DEFAULT_PLUGINS = []
DEFAULT_PLUGINS = {}
DEFAULT_LOG_FORMAT = '%(asctime)s - %(levelname)s - pid:%(process)d - %(funcName)s:%(lineno)d - %(message)s'
# Set to True if under test
UNDER_TEST = False
@ -91,7 +94,7 @@ def bytes_(s, encoding='utf-8', errors='strict'): # pragma: no cover
version = bytes_(__version__)
CRLF, COLON, SP = b'\r\n', b':', b' '
CRLF, COLON, WHITESPACE = b'\r\n', b':', b' '
PROXY_AGENT_HEADER = b'Proxy-agent: proxy.py v' + version
PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT = CRLF.join([
@ -123,31 +126,31 @@ class ChunkParser(object):
self.chunk = b'' # Partial chunk received
self.size = None # Expected size of next following chunk
def parse(self, data):
more = True if len(data) > 0 else False
def parse(self, raw):
more = True if len(raw) > 0 else False
while more:
more, data = self.process(data)
more, raw = self.process(raw)
def process(self, data):
def process(self, raw):
if self.state == ChunkParser.states.WAITING_FOR_SIZE:
# Consume prior chunk in buffer
# in case chunk size without CRLF was received
data = self.chunk + data
raw = self.chunk + raw
self.chunk = b''
# Extract following chunk data size
line, data = HttpParser.split(data)
line, raw = HttpParser.split(raw)
if not line: # CRLF not received
self.chunk = data
data = b''
self.chunk = raw
raw = b''
else:
self.size = int(line, 16)
self.state = ChunkParser.states.WAITING_FOR_DATA
elif self.state == ChunkParser.states.WAITING_FOR_DATA:
remaining = self.size - len(self.chunk)
self.chunk += data[:remaining]
data = data[remaining:]
self.chunk += raw[:remaining]
raw = raw[remaining:]
if len(self.chunk) == self.size:
data = data[len(CRLF):]
raw = raw[len(CRLF):]
self.body += self.chunk
if self.size == 0:
self.state = ChunkParser.states.COMPLETE
@ -155,7 +158,7 @@ class ChunkParser(object):
self.state = ChunkParser.states.WAITING_FOR_SIZE
self.chunk = b''
self.size = None
return len(data) > 0, data
return len(raw) > 0, raw
class HttpParser(object):
@ -179,7 +182,7 @@ class HttpParser(object):
self.type = parser_type
self.state = HttpParser.states.INITIALIZED
self.raw = b''
self.bytes = b''
self.buffer = b''
self.headers = dict()
@ -193,22 +196,38 @@ class HttpParser(object):
self.chunk_parser = None
# This cleans up developer APIs as Python urlparse.urlsplit behaves differently
# for incoming proxy request and incoming web request. Web request is the one
# which is broken.
self.host = None
self.port = None
def set_host_port(self):
if self.type == HttpParser.types.REQUEST_PARSER:
if self.method == b'CONNECT':
self.host, self.port = self.url.path.split(COLON)
elif self.url:
self.host, self.port = self.url.hostname, self.url.port \
if self.url.port else 80
else:
raise Exception('Invalid request\n%s' % self.bytes)
def is_chunked_encoded_response(self):
return self.type == HttpParser.types.RESPONSE_PARSER and \
b'transfer-encoding' in self.headers and \
self.headers[b'transfer-encoding'][1].lower() == b'chunked'
def parse(self, data):
self.raw += data
data = self.buffer + data
def parse(self, raw):
self.bytes += raw
raw = self.buffer + raw
self.buffer = b''
more = True if len(data) > 0 else False
more = True if len(raw) > 0 else False
while more:
more, data = self.process(data)
self.buffer = data
more, raw = self.process(raw)
self.buffer = raw
def process(self, data):
def process(self, raw):
if self.state in (HttpParser.states.HEADERS_COMPLETE,
HttpParser.states.RCVING_BODY,
HttpParser.states.COMPLETE) and \
@ -218,22 +237,22 @@ class HttpParser(object):
if b'content-length' in self.headers:
self.state = HttpParser.states.RCVING_BODY
self.body += data
self.body += raw
if len(self.body) >= int(self.headers[b'content-length'][1]):
self.state = HttpParser.states.COMPLETE
elif self.is_chunked_encoded_response():
if not self.chunk_parser:
self.chunk_parser = ChunkParser()
self.chunk_parser.parse(data)
self.chunk_parser.parse(raw)
if self.chunk_parser.state == ChunkParser.states.COMPLETE:
self.body = self.chunk_parser.body
self.state = HttpParser.states.COMPLETE
return False, b''
line, data = HttpParser.split(data)
line, raw = HttpParser.split(raw)
if line is False:
return line, data
return line, raw
if self.state == HttpParser.states.INITIALIZED:
self.process_line(line)
@ -245,7 +264,7 @@ class HttpParser(object):
if self.state == HttpParser.states.LINE_RCVD and \
self.type == HttpParser.types.REQUEST_PARSER and \
self.method == b'CONNECT' and \
data == CRLF:
raw == CRLF:
self.state = HttpParser.states.COMPLETE
# When raw request has ended with \r\n\r\n and no more http headers are expected
@ -254,7 +273,7 @@ class HttpParser(object):
elif self.state == HttpParser.states.HEADERS_COMPLETE and \
self.type == HttpParser.types.REQUEST_PARSER and \
self.method != b'POST' and \
self.raw.endswith(CRLF * 2):
self.bytes.endswith(CRLF * 2):
self.state = HttpParser.states.COMPLETE
elif self.state == HttpParser.states.HEADERS_COMPLETE and \
self.type == HttpParser.types.REQUEST_PARSER and \
@ -262,13 +281,13 @@ class HttpParser(object):
(b'content-length' not in self.headers or
(b'content-length' in self.headers and
int(self.headers[b'content-length'][1]) == 0)) and \
self.raw.endswith(CRLF * 2):
self.bytes.endswith(CRLF * 2):
self.state = HttpParser.states.COMPLETE
return len(data) > 0, data
return len(raw) > 0, raw
def process_line(self, data):
line = data.split(SP)
def process_line(self, raw):
line = raw.split(WHITESPACE)
if self.type == HttpParser.types.REQUEST_PARSER:
self.method = line[0].upper()
self.url = urlparse.urlsplit(line[1])
@ -277,17 +296,18 @@ class HttpParser(object):
self.version = line[0]
self.code = line[1]
self.reason = b' '.join(line[2:])
self.set_host_port()
self.state = HttpParser.states.LINE_RCVD
def process_header(self, data):
if len(data) == 0:
def process_header(self, raw):
if len(raw) == 0:
if self.state == HttpParser.states.RCVING_HEADERS:
self.state = HttpParser.states.HEADERS_COMPLETE
elif self.state == HttpParser.states.LINE_RCVD:
self.state = HttpParser.states.RCVING_HEADERS
else:
self.state = HttpParser.states.RCVING_HEADERS
parts = data.split(COLON)
parts = raw.split(COLON)
key = parts[0].strip()
value = COLON.join(parts[1:]).strip()
self.headers[key.lower()] = (key, value)
@ -331,13 +351,23 @@ class HttpParser(object):
return k + b': ' + v
@staticmethod
def split(data):
pos = data.find(CRLF)
def split(raw):
pos = raw.find(CRLF)
if pos == -1:
return False, data
line = data[:pos]
data = data[pos + len(CRLF):]
return line, data
return False, raw
line = raw[:pos]
raw = raw[pos + len(CRLF):]
return line, raw
###################################################################################
# HttpParser was originally written to parse the incoming raw Http requests.
# Since request / response objects passed to HttpProtocolBasePlugin methods
# are also HttpParser objects, methods below were added to simplify developer API.
####################################################################################
def has_upstream_server(self):
"""Host field SHOULD be None for incoming local WebServer requests."""
return True if self.host is not None else False
class TcpConnection(object):
@ -488,69 +518,224 @@ class ProxyRequestRejected(ProxyError):
return CRLF.join(pkt) if len(pkt) > 0 else None
class HttpProxyPlugin(object):
"""Base HttpProxy Plugin class."""
def __init__(self):
pass
def handle_request(self, request):
"""Handle client request (HttpParser).
Return optionally modified client request (HttpParser) object.
"""
return request
def handle_response(self, data):
"""Handle data chunks as received from the server."""
return data
class HttpProxyConfig(object):
"""Holds various configuration values applicable to HttpProxy.
class HttpProtocolConfig(object):
"""Holds various configuration values applicable to HttpProtocolHandler.
This config class helps us avoid passing around bunch of key/value pairs across methods.
"""
def __init__(self, auth_code=DEFAULT_BASIC_AUTH, server_recvbuf_size=DEFAULT_SERVER_RECVBUF_SIZE,
client_recvbuf_size=DEFAULT_CLIENT_RECVBUF_SIZE, pac_file=DEFAULT_PAC_FILE, plugins=None):
client_recvbuf_size=DEFAULT_CLIENT_RECVBUF_SIZE, pac_file=DEFAULT_PAC_FILE,
pac_file_url_path=DEFAULT_PAC_FILE_URL_PATH, plugins=None):
self.auth_code = auth_code
self.server_recvbuf_size = server_recvbuf_size
self.client_recvbuf_size = client_recvbuf_size
self.pac_file = pac_file
self.pac_file_url_path = pac_file_url_path
if plugins is None:
plugins = DEFAULT_PLUGINS
self.plugins = plugins
self.plugins: Dict = plugins
class HttpProxy(threading.Thread):
"""HTTP proxy implementation.
class HttpProtocolBasePlugin(object):
"""Base HttpProtocolHandler Plugin class.
Accepts `Client` connection object and act as a proxy between client and server.
"""
def __init__(self, client, config=None):
super(HttpProxy, self).__init__()
self.start_time = self.now()
self.last_activity = self.start_time
Implement various lifecycle event methods to customize behavior."""
def __init__(self, config: HttpProtocolConfig, client: TcpClientConnection, request: HttpParser):
self.config = config
self.client = client
self.server = None
self.config = config if config else HttpProxyConfig()
self.request = request
self.request = HttpParser(HttpParser.types.REQUEST_PARSER)
def name(self) -> str:
"""A unique name for your plugin.
Defaults to name of the class. This helps plugin developers to directly
access a specific plugin by its name."""
return self.__class__.__name__
def get_descriptors(self):
return [], [], []
def flush_to_descriptors(self, w):
pass
def read_from_descriptors(self, r):
pass
def on_client_data(self, raw: bytes):
return raw
def on_request_complete(self):
"""Called right after client request parser has reached COMPLETE state."""
pass
def handle_response_chunk(self, chunk: bytes):
"""Handle data chunks as received from the server.
Return optionally modified chunk to return back to client."""
return chunk
def access_log(self):
pass
def on_client_connection_close(self):
pass
class HttpProxyPlugin(HttpProtocolBasePlugin):
"""HttpProtocolHandler plugin which implements HttpProxy specifications."""
def __init__(self, config: HttpProtocolConfig, client: TcpClientConnection, request: HttpParser):
super(HttpProxyPlugin, self).__init__(config, client, request)
self.server = None
self.response = HttpParser(HttpParser.types.RESPONSE_PARSER)
@staticmethod
def now():
return datetime.datetime.utcnow()
def get_descriptors(self):
if not self.request.has_upstream_server():
return [], [], []
def connection_inactive_for(self):
return (self.now() - self.last_activity).seconds
r, w, x = [], [], []
if self.server and not self.server.closed:
r.append(self.server.conn)
if self.server and not self.server.closed and self.server.has_buffer():
w.append(self.server.conn)
return r, w, x
def is_connection_inactive(self):
return self.connection_inactive_for() > 30
def flush_to_descriptors(self, w):
if not self.request.has_upstream_server():
return
if self.server and not self.server.closed and self.server.conn in w:
logger.debug('Server is ready for writes, flushing server buffer')
self.server.flush()
def read_from_descriptors(self, r):
if not self.request.has_upstream_server():
return
if self.server and not self.server.closed and self.server.conn in r:
logger.debug('Server is ready for reads, reading')
raw = self.server.recv(self.config.server_recvbuf_size)
# self.last_activity = HttpProtocolHandler.now()
if not raw:
logger.debug('Server closed connection, tearing down...')
return True
# parse incoming response packet
# only for non-https requests
if not self.request.method == b'CONNECT':
self.response.parse(raw)
else:
# Only purpose of increasing memory footprint is to print response length in access log
# Not worth it? Optimize to only persist lengths?
self.response.bytes += raw
# queue raw data for client
self.client.queue(raw)
def on_client_connection_close(self):
if not self.request.has_upstream_server():
return
if self.server:
logger.debug(
'Closed server connection with pending server buffer size %d bytes' % self.server.buffer_size())
if not self.server.closed:
self.server.close()
def on_client_data(self, raw):
if not self.request.has_upstream_server():
return raw
if self.server and not self.server.closed:
self.server.queue(raw)
return None
else:
return raw
def on_request_complete(self):
if not self.request.has_upstream_server():
return
self.authenticate(self.request.headers)
self.connect_upstream(self.request.host, self.request.port)
# for http connect methods (https requests)
# queue appropriate response for client
# notifying about established connection
if self.request.method == b'CONNECT':
self.client.queue(PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT)
# for general http requests, re-build request packet
# and queue for the server with appropriate headers
else:
self.server.queue(self.request.build(
del_headers=[b'proxy-authorization', b'proxy-connection', b'connection',
b'keep-alive'],
add_headers=[(b'Via', b'1.1 proxy.py v%s' % version), (b'Connection', b'Close')]
))
def access_log(self):
if not self.request.has_upstream_server():
return
host, port = self.server.addr if self.server else (None, None)
if self.request.method == b'CONNECT':
logger.info(
'%s:%s - %s %s:%s - %s bytes' % (self.client.addr[0], self.client.addr[1],
text_(self.request.method), text_(host),
text_(port), len(self.response.bytes)))
elif self.request.method:
logger.info('%s:%s - %s %s:%s%s - %s %s - %s bytes' % (
self.client.addr[0], self.client.addr[1], text_(self.request.method), text_(host), port,
text_(self.request.build_url()), text_(self.response.code), text_(self.response.reason),
len(self.response.bytes)))
def authenticate(self, headers):
if self.config.auth_code:
if b'proxy-authorization' not in headers or \
headers[b'proxy-authorization'][1] != self.config.auth_code:
raise ProxyAuthenticationFailed()
def connect_upstream(self, host, port):
self.server = TcpServerConnection(host, port)
try:
logger.debug('Connecting to upstream %s:%s' % (host, port))
self.server.connect()
logger.debug('Connected to upstream %s:%s' % (host, port))
except Exception as e: # TimeoutError, socket.gaierror
self.server.closed = True
raise ProxyConnectionFailed(host, port, repr(e))
class HttpWebServerPlugin(HttpProtocolBasePlugin):
"""HttpProtocolHandler plugin which handles incoming requests to local webserver."""
def __init__(self, config: HttpProtocolConfig, client: TcpClientConnection, request: HttpParser):
super(HttpWebServerPlugin, self).__init__(config, client, request)
def on_request_complete(self):
if self.request.has_upstream_server():
return
if self.config.pac_file and \
self.request.url.path == self.config.pac_file_url_path:
self.serve_pac_file()
else:
self.client.queue(CRLF.join([
b'HTTP/1.1 200 OK',
b'Server: proxy.py v%s' % version,
b'Content-Length: 42',
b'Connection: Close',
CRLF,
b'https://github.com/abhinavsingh/proxy.py'
]))
self.client.flush()
return True
def access_log(self):
if self.request.has_upstream_server():
return
logger.info('%s:%s - %s %s' % (self.client.addr[0], self.client.addr[1],
text_(self.request.method), text_(self.request.build_url())))
def serve_pac_file(self):
self.client.queue(PAC_FILE_RESPONSE_PREFIX)
@ -563,105 +748,91 @@ class HttpProxy(threading.Thread):
self.client.queue(self.config.pac_file)
self.client.flush()
def access_log(self):
host, port = self.server.addr if self.server else (None, None)
if self.request.method == b'CONNECT':
logger.info(
'%s:%s - %s %s:%s - %s bytes' % (self.client.addr[0], self.client.addr[1],
text_(self.request.method), text_(host),
text_(port), len(self.response.raw)))
elif self.request.method:
logger.info('%s:%s - %s %s:%s%s - %s %s - %s bytes' % (
self.client.addr[0], self.client.addr[1], text_(self.request.method), text_(host), port,
text_(self.request.build_url()), text_(self.response.code), text_(self.response.reason),
len(self.response.raw)))
class HttpProtocolHandler(threading.Thread):
"""HTTP, HTTPS, HTTP2, WebSockets protocol handler.
Accepts `Client` connection object, manages plugin invocations.
"""
def __init__(self, client, config=None):
super(HttpProtocolHandler, self).__init__()
self.start_time = self.now()
self.last_activity = self.start_time
self.client = client
self.config = config if config else HttpProtocolConfig()
self.request = HttpParser(HttpParser.types.REQUEST_PARSER)
self.plugins: Dict[str, HttpProtocolBasePlugin] = {}
for klass in self.config.plugins:
instance = klass(self.config, self.client, self.request)
self.plugins[instance.name()] = instance
@staticmethod
def now():
return datetime.datetime.utcnow()
def connection_inactive_for(self):
return (self.now() - self.last_activity).seconds
def is_connection_inactive(self):
return self.connection_inactive_for() > 30
def run_once(self):
"""Returns True if proxy must teardown."""
# Prepare list of descriptors
rlist, wlist, xlist = [self.client.conn], [], []
read_desc, write_desc, err_desc = [self.client.conn], [], []
if self.client.has_buffer():
wlist.append(self.client.conn)
if self.server and not self.server.closed:
rlist.append(self.server.conn)
if self.server and not self.server.closed and self.server.has_buffer():
wlist.append(self.server.conn)
write_desc.append(self.client.conn)
r, w, x = select.select(rlist, wlist, xlist, 1)
for plugin in self.plugins.values():
plugin_read_desc, plugin_write_desc, plugin_err_desc = plugin.get_descriptors()
read_desc += plugin_read_desc
write_desc += plugin_write_desc
err_desc += plugin_err_desc
# Flush buffer from ready to write sockets
if self.client.conn in w:
readable, writable, errored = select.select(read_desc, write_desc, err_desc, 1)
# Flush buffer for ready to write sockets
if self.client.conn in writable:
logger.debug('Client is ready for writes, flushing client buffer')
try:
self.client.flush()
if self.server and not self.server.closed and self.server.conn in w:
logger.debug('Server is ready for writes, flushing server buffer')
self.server.flush()
except BrokenPipeError:
logging.error('BrokenPipeError when flushing buffer for client')
return True
for plugin in self.plugins.values():
plugin.flush_to_descriptors(writable)
# Read from ready to read sockets
if self.client.conn in r:
if self.client.conn in readable:
logger.debug('Client is ready for reads, reading')
data = self.client.recv(self.config.client_recvbuf_size)
client_data = self.client.recv(self.config.client_recvbuf_size)
self.last_activity = self.now()
if not data:
if not client_data:
logger.debug('Client closed connection, tearing down...')
return True
plugin_index = 0
plugins = list(self.plugins.values())
while plugin_index < len(plugins) and client_data:
client_data = plugins[plugin_index].on_client_data(client_data)
plugin_index += 1
if client_data:
try:
# Once we have connection to the server
# we don't parse the http request packets
# any further, instead just pipe incoming
# data from client to server
if self.server and not self.server.closed:
self.server.queue(data)
else:
# Parse http request
self.request.parse(data)
self.request.parse(client_data)
if self.request.state == HttpParser.states.COMPLETE:
logger.debug('Request parser is in state complete')
if self.config.auth_code:
if b'proxy-authorization' not in self.request.headers or \
self.request.headers[b'proxy-authorization'][1] != self.config.auth_code:
raise ProxyAuthenticationFailed()
# Invoke HttpProxyPlugin.handle_request
for plugin in self.config.plugins:
self.request = plugin.handle_request(self.request)
if self.request.method == b'CONNECT':
host, port = self.request.url.path.split(COLON)
elif self.request.url:
host, port = self.request.url.hostname, self.request.url.port \
if self.request.url.port else 80
else:
raise Exception('Invalid request\n%s' % self.request.raw)
if host is None and self.config.pac_file:
self.serve_pac_file()
for plugin in self.plugins.values():
# TODO: Cleanup by not returning True for teardown cases
plugin_response = plugin.on_request_complete()
if type(plugin_response) is bool:
return True
self.server = TcpServerConnection(host, port)
try:
logger.debug('Connecting to server %s:%s' % (host, port))
self.server.connect()
logger.debug('Connected to server %s:%s' % (host, port))
except Exception as e: # TimeoutError, socket.gaierror
self.server.closed = True
raise ProxyConnectionFailed(host, port, repr(e))
# for http connect methods (https requests)
# queue appropriate response for client
# notifying about established connection
if self.request.method == b'CONNECT':
self.client.queue(PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT)
# for usual http requests, re-build request packet
# and queue for the server with appropriate headers
else:
self.server.queue(self.request.build(
del_headers=[b'proxy-authorization', b'proxy-connection', b'connection',
b'keep-alive'],
add_headers=[(b'Via', b'1.1 proxy.py v%s' % version), (b'Connection', b'Close')]
))
except (ProxyAuthenticationFailed, ProxyConnectionFailed, ProxyRequestRejected) as e:
except ProxyError as e: # ProxyAuthenticationFailed, ProxyConnectionFailed, ProxyRequestRejected
# logger.exception(e)
response = e.response(self.request)
if response:
@ -669,23 +840,11 @@ class HttpProxy(threading.Thread):
# But is client also ready for writes?
self.client.flush()
raise e
if self.server and not self.server.closed and self.server.conn in r:
logger.debug('Server is ready for reads, reading')
data = self.server.recv(self.config.server_recvbuf_size)
self.last_activity = self.now()
if not data:
logger.debug('Server closed connection, tearing down...')
for plugin in self.plugins.values():
teardown = plugin.read_from_descriptors(readable)
if teardown:
return True
# parse incoming response packet
# only for non-https requests
if not self.request.method == b'CONNECT':
self.response.parse(data)
else:
# Only purpose of increasing memory footprint is to print response length in access log
# Not worth it? Optimize to only persist lengths?
self.response.raw += data
# queue data for client
self.client.queue(data)
# Teardown if client buffer is empty and connection is inactive
if self.client.buffer_size() == 0:
@ -706,16 +865,17 @@ class HttpProxy(threading.Thread):
except Exception as e:
logger.exception('Exception while handling connection %r with reason %r' % (self.client.conn, e))
finally:
for plugin in self.plugins.values():
plugin.access_log()
self.client.close()
logger.debug(
'Closed client connection with pending client buffer size %d bytes' % self.client.buffer_size())
if self.server:
logger.debug(
'Closed server connection with pending server buffer size %d bytes' % self.server.buffer_size())
if not self.server.closed:
self.server.close()
self.access_log()
logger.debug('Closed proxy for connection %r at address %r' % (self.client.conn, self.client.addr))
logger.debug('Closed client connection with pending '
'client buffer size %d bytes' % self.client.buffer_size())
for plugin in self.plugins.values():
plugin.on_client_connection_close()
logger.debug('Closed proxy for connection %r '
'at address %r' % (self.client.conn, self.client.addr))
class TcpServer(object):
@ -797,7 +957,7 @@ class MultiCoreRequestDispatcher(TcpServer):
self.workers.append(worker)
def handle(self, client):
self.worker_queue.put((Worker.operations.HTTP_PROXY, client))
self.worker_queue.put((Worker.operations.HTTP_PROTOCOL, client))
def shutdown(self):
logger.info('Shutting down %d workers' % self.num_workers)
@ -815,7 +975,7 @@ class Worker(multiprocessing.Process):
"""
operations = namedtuple('WorkerOperations', (
'HTTP_PROXY',
'HTTP_PROTOCOL',
'SHUTDOWN',
))(1, 2)
@ -828,8 +988,8 @@ class Worker(multiprocessing.Process):
while True:
try:
op, payload = self.work_queue.get(True, 1)
if op == Worker.operations.HTTP_PROXY:
proxy = HttpProxy(payload, config=self.config)
if op == Worker.operations.HTTP_PROTOCOL:
proxy = HttpProtocolHandler(payload, config=self.config)
proxy.setDaemon(True)
proxy.start()
elif op == Worker.operations.SHUTDOWN:
@ -852,6 +1012,22 @@ def set_open_file_limit(soft_limit):
logger.debug('Open file descriptor soft limit set to %d' % soft_limit)
def load_plugins(plugins) -> List:
p = []
plugins = plugins.split(',')
for plugin in plugins:
plugin = plugin.strip()
if plugin == '':
continue
logging.debug('Loading plugin %s', plugin)
module_name, klass_name = plugin.rsplit('.', 1)
module = importlib.import_module(module_name)
klass = getattr(module, klass_name)
logging.info('%s initialized' % klass.__name__)
p.append(klass)
return p
def init_parser():
parser = argparse.ArgumentParser(
description='proxy.py v%s' % __version__,
@ -882,12 +1058,14 @@ def init_parser():
parser.add_argument('--open-file-limit', type=int, default=DEFAULT_OPEN_FILE_LIMIT,
help='Default: 1024. Maximum number of files (TCP connections) '
'that proxy.py can open concurrently.')
parser.add_argument('--port', type=int, default=DEFAULT_PORT,
help='Default: 8899. Server port.')
parser.add_argument('--pac-file', type=str, default=DEFAULT_PAC_FILE,
help='A file (Proxy Auto Configuration) or string to serve when '
'the server receives a direct file request.')
parser.add_argument('--plugins', type=str, default=None, help='Comma separated plugins')
parser.add_argument('--pac-file-url-path', type=str, default=DEFAULT_PAC_FILE_URL_PATH,
help='Web server path to serve the PAC file.')
parser.add_argument('--plugins', type=str, default='', help='Comma separated plugins')
parser.add_argument('--port', type=int, default=DEFAULT_PORT,
help='Default: 8899. Server port.')
parser.add_argument('--server-recvbuf-size', type=int, default=DEFAULT_SERVER_RECVBUF_SIZE,
help='Default: 8 KB. Maximum amount of data received from the '
'server in a single recv() operation. Bump this '
@ -898,18 +1076,6 @@ def init_parser():
return parser
def load_plugins(plugins):
p = []
plugins = plugins.split(',')
for plugin in plugins:
module_name, klass_name = plugin.rsplit('.', 1)
module = importlib.import_module(module_name)
klass = getattr(module, klass_name)
logging.info('%s initialized' % klass.__name__)
p.append(klass())
return p
def main():
if not PY3 and not UNDER_TEST:
print(
@ -924,7 +1090,6 @@ def main():
parser = init_parser()
args = parser.parse_args(sys.argv[1:])
logging.basicConfig(level=getattr(logging,
{
'D': 'DEBUG',
@ -934,7 +1099,6 @@ def main():
'C': 'CRITICAL'
}[args.log_level.upper()[0]]),
format=args.log_format)
plugins = load_plugins(args.plugins) if args.plugins else DEFAULT_PLUGINS
try:
set_open_file_limit(args.open_file_limit)
@ -943,16 +1107,18 @@ def main():
if args.basic_auth:
auth_code = b'Basic %s' % base64.b64encode(bytes_(args.basic_auth))
config = HttpProtocolConfig(auth_code=auth_code,
server_recvbuf_size=args.server_recvbuf_size,
client_recvbuf_size=args.client_recvbuf_size,
pac_file=args.pac_file,
pac_file_url_path=args.pac_file_url_path)
config.plugins = load_plugins('proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin,%s' % args.plugins)
server = MultiCoreRequestDispatcher(hostname=args.hostname,
port=args.port,
backlog=args.backlog,
ipv4=args.ipv4,
num_workers=args.num_workers,
config=HttpProxyConfig(auth_code=auth_code,
server_recvbuf_size=args.server_recvbuf_size,
client_recvbuf_size=args.client_recvbuf_size,
pac_file=args.pac_file,
plugins=plugins))
config=config)
server.run()
except KeyboardInterrupt:
pass

View File

@ -1 +0,0 @@
mock==3.0.5

View File

@ -3,9 +3,9 @@
proxy.py
~~~~~~~~
HTTP Proxy Server in Python.
HTTP, HTTPS, HTTP2 and WebSockets Proxy Server in Python.
:copyright: (c) 2013-2018 by Abhinav Singh.
:copyright: (c) 2013-2020 by Abhinav Singh.
:license: BSD, see LICENSE for more details.
"""
import os
@ -21,8 +21,8 @@ from threading import Thread
import proxy
# logging.basicConfig(level=logging.DEBUG,
# format='%(asctime)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s')
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s')
if sys.version_info[0] == 3: # Python3 specific imports
from http.server import HTTPServer, BaseHTTPRequestHandler
@ -123,7 +123,7 @@ class TestMultiCoreRequestDispatcher(unittest.TestCase):
tcp_server = None
tcp_thread = None
@mock.patch.object(proxy, 'HttpProxy', side_effect=mock_tcp_proxy_side_effect)
@mock.patch.object(proxy, 'HttpProtocolHandler', side_effect=mock_tcp_proxy_side_effect)
def testHttpProxyConnection(self, mock_tcp_proxy):
try:
self.tcp_port = get_available_port()
@ -523,6 +523,7 @@ class TestProxy(unittest.TestCase):
http_server = None
http_server_port = None
http_server_thread = None
config = None
@classmethod
def setUpClass(cls):
@ -531,6 +532,8 @@ class TestProxy(unittest.TestCase):
cls.http_server_thread = Thread(target=cls.http_server.serve_forever)
cls.http_server_thread.setDaemon(True)
cls.http_server_thread.start()
cls.config = proxy.HttpProtocolConfig()
cls.config.plugins = proxy.load_plugins('proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin')
@classmethod
def tearDownClass(cls):
@ -541,7 +544,7 @@ class TestProxy(unittest.TestCase):
def setUp(self):
self._conn = MockTcpConnection()
self._addr = ('127.0.0.1', 54382)
self.proxy = proxy.HttpProxy(proxy.TcpClientConnection(self._conn, self._addr))
self.proxy = proxy.HttpProtocolHandler(proxy.TcpClientConnection(self._conn, self._addr), config=self.config)
@mock.patch('select.select')
@mock.patch('proxy.TcpServerConnection')
@ -610,7 +613,7 @@ class TestProxy(unittest.TestCase):
proxy.CRLF
]))
self.proxy.run_once()
self.assertFalse(self.proxy.server is None)
self.assertFalse(self.proxy.plugins['HttpProxyPlugin'].server is None)
self.assertEqual(self.proxy.client.buffer, proxy.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT)
mock_server_connection.assert_called_once()
server.connect.assert_called_once()
@ -659,9 +662,10 @@ class TestProxy(unittest.TestCase):
@mock.patch('select.select')
def test_proxy_authentication_failed(self, mock_select):
mock_select.return_value = ([self._conn], [], [])
self.proxy = proxy.HttpProxy(proxy.TcpClientConnection(self._conn, self._addr),
config=proxy.HttpProxyConfig(
auth_code=b'Basic %s' % base64.b64encode(b'user:pass')))
config = proxy.HttpProtocolConfig(auth_code=b'Basic %s' % base64.b64encode(b'user:pass'))
config.plugins = proxy.load_plugins('proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin')
self.proxy = proxy.HttpProtocolHandler(proxy.TcpClientConnection(self._conn, self._addr),
config=config)
self.proxy.client.conn.queue(proxy.CRLF.join([
b'GET http://abhinavsingh.com HTTP/1.1',
b'Host: abhinavsingh.com',
@ -673,14 +677,15 @@ class TestProxy(unittest.TestCase):
@mock.patch('select.select')
@mock.patch('proxy.TcpServerConnection')
def test_authenticated_proxy_http_get(self, mock_server_connection, mock_select):
client = proxy.TcpClientConnection(self._conn, self._addr)
config = proxy.HttpProxyConfig(auth_code=b'Basic %s' % base64.b64encode(b'user:pass'))
mock_select.return_value = ([self._conn], [], [])
server = mock_server_connection.return_value
server.connect.return_value = True
self.proxy = proxy.HttpProxy(client, config=config)
client = proxy.TcpClientConnection(self._conn, self._addr)
config = proxy.HttpProtocolConfig(auth_code=b'Basic %s' % base64.b64encode(b'user:pass'))
config.plugins = proxy.load_plugins('proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin')
self.proxy = proxy.HttpProtocolHandler(client, config=config)
self.proxy.client.conn.queue(b'GET http://localhost:%d HTTP/1.1' % self.http_server_port)
self.proxy.run_once()
self.assertEqual(self.proxy.request.state, proxy.HttpParser.states.INITIALIZED)
@ -731,9 +736,10 @@ class TestProxy(unittest.TestCase):
server.connect.return_value = True
mock_select.side_effect = [([self._conn], [], []), ([self._conn], [], []), ([], [server.conn], [])]
self.proxy = proxy.HttpProxy(proxy.TcpClientConnection(self._conn, self._addr),
config=proxy.HttpProxyConfig(
auth_code=b'Basic %s' % base64.b64encode(b'user:pass')))
config = proxy.HttpProtocolConfig(auth_code=b'Basic %s' % base64.b64encode(b'user:pass'))
config.plugins = proxy.load_plugins('proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin')
self.proxy = proxy.HttpProtocolHandler(proxy.TcpClientConnection(self._conn, self._addr),
config=config)
self.proxy.client.conn.queue(proxy.CRLF.join([
b'CONNECT localhost:%d HTTP/1.1' % self.http_server_port,
b'Host: localhost:%d' % self.http_server_port,
@ -743,7 +749,7 @@ class TestProxy(unittest.TestCase):
proxy.CRLF
]))
self.proxy.run_once()
self.assertFalse(self.proxy.server is None)
self.assertFalse(self.proxy.plugins['HttpProxyPlugin'].server is None)
self.assertEqual(self.proxy.client.buffer, proxy.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT)
mock_server_connection.assert_called_once()
server.connect.assert_called_once()
@ -780,15 +786,15 @@ class TestWorker(unittest.TestCase):
self.queue = multiprocessing.Queue()
self.worker = proxy.Worker(self.queue)
@mock.patch('proxy.HttpProxy')
@mock.patch('proxy.HttpProtocolHandler')
def test_shutdown_op(self, mock_http_proxy):
self.queue.put((proxy.Worker.operations.SHUTDOWN, None))
self.worker.run() # Worker should consume the prior shutdown operation
self.assertFalse(mock_http_proxy.called)
@mock.patch('proxy.HttpProxy')
@mock.patch('proxy.HttpProtocolHandler')
def test_spawns_http_proxy_threads(self, mock_http_proxy):
self.queue.put((proxy.Worker.operations.HTTP_PROXY, None))
self.queue.put((proxy.Worker.operations.HTTP_PROTOCOL, None))
self.queue.put((proxy.Worker.operations.SHUTDOWN, None))
self.worker.run()
self.assertTrue(mock_http_proxy.called)