Convert HttpProxy and HttpWebServer themselves into plugins.
Load plugins during test execution. Further decouple proxy/webserver logic outside of HttpProtocolHandler. Use per-connection plugin instances to avoid locks. Handle BrokenPipeError, and tear down if read_from_descriptors returns True.
This commit is contained in:
parent
8732cb7151
commit
6ea42b0dd9
|
@ -1,30 +1,38 @@
|
||||||
from urllib import parse as urlparse
|
from urllib import parse as urlparse
|
||||||
from proxy import HttpProxyPlugin, ProxyRejectRequest
|
from proxy import HttpProtocolBasePlugin, ProxyRequestRejected
|
||||||
|
|
||||||
|
|
||||||
class RedirectToCustomServerPlugin(HttpProxyPlugin):
|
class RedirectToCustomServerPlugin(HttpProtocolBasePlugin):
|
||||||
"""Modifies client request to redirect all incoming requests to a fixed server address."""
|
"""Modifies client request to redirect all incoming requests to a fixed server address."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(RedirectToCustomServerPlugin, self).__init__()
|
super(RedirectToCustomServerPlugin, self).__init__()
|
||||||
|
|
||||||
def handle_request(self, request):
|
def on_request_complete(self):
|
||||||
if request.method != 'CONNECT':
|
if self.request.method != 'CONNECT':
|
||||||
request.url = urlparse.urlsplit(b'http://localhost:9999')
|
self.request.url = urlparse.urlsplit(b'http://localhost:9999')
|
||||||
return request
|
|
||||||
|
|
||||||
|
|
||||||
class FilterByTargetDomainPlugin(HttpProxyPlugin):
|
class FilterByTargetDomainPlugin(HttpProtocolBasePlugin):
|
||||||
"""Only accepts specific requests dropping all other requests."""
|
"""Only accepts specific requests dropping all other requests."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(FilterByTargetDomainPlugin, self).__init__()
|
super(FilterByTargetDomainPlugin, self).__init__()
|
||||||
self.allowed_domains = [b'google.com', b'www.google.com', b'google.com:443', b'www.google.com:443']
|
self.allowed_domains = [b'google.com', b'www.google.com', b'google.com:443', b'www.google.com:443']
|
||||||
|
|
||||||
def handle_request(self, request):
|
def on_request_complete(self):
|
||||||
# TODO: Refactor internals to cleanup mess below, due to how urlparse works, hostname/path attributes
|
# TODO: Refactor internals to cleanup mess below, due to how urlparse works, hostname/path attributes
|
||||||
# are not consistent between CONNECT and non-CONNECT requests.
|
# are not consistent between CONNECT and non-CONNECT requests.
|
||||||
if (request.method != b'CONNECT' and request.url.hostname not in self.allowed_domains) or \
|
if (self.request.method != b'CONNECT' and self.request.url.hostname not in self.allowed_domains) or \
|
||||||
(request.method == b'CONNECT' and request.url.path not in self.allowed_domains):
|
(self.request.method == b'CONNECT' and self.request.url.path not in self.allowed_domains):
|
||||||
raise ProxyRejectRequest(status_code=418, body='I\'m a tea pot')
|
raise ProxyRequestRejected(status_code=418, body='I\'m a tea pot')
|
||||||
return request
|
|
||||||
|
|
||||||
|
class SaveHttpResponses(HttpProtocolBasePlugin):
|
||||||
|
"""Saves Http Responses locally on disk."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(SaveHttpResponses, self).__init__()
|
||||||
|
|
||||||
|
def handle_response_chunk(self, chunk):
|
||||||
|
return chunk
|
||||||
|
|
614
proxy.py
614
proxy.py
|
@ -22,6 +22,7 @@ import socket
|
||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
if os.name != 'nt':
|
if os.name != 'nt':
|
||||||
import resource
|
import resource
|
||||||
|
@ -53,7 +54,7 @@ else: # pragma: no cover
|
||||||
# Defaults
|
# Defaults
|
||||||
DEFAULT_BACKLOG = 100
|
DEFAULT_BACKLOG = 100
|
||||||
DEFAULT_BASIC_AUTH = None
|
DEFAULT_BASIC_AUTH = None
|
||||||
DEFAULT_BUFFER_SIZE = 8192
|
DEFAULT_BUFFER_SIZE = 1024 * 1024
|
||||||
DEFAULT_CLIENT_RECVBUF_SIZE = DEFAULT_BUFFER_SIZE
|
DEFAULT_CLIENT_RECVBUF_SIZE = DEFAULT_BUFFER_SIZE
|
||||||
DEFAULT_SERVER_RECVBUF_SIZE = DEFAULT_BUFFER_SIZE
|
DEFAULT_SERVER_RECVBUF_SIZE = DEFAULT_BUFFER_SIZE
|
||||||
DEFAULT_IPV4_HOSTNAME = '127.0.0.1'
|
DEFAULT_IPV4_HOSTNAME = '127.0.0.1'
|
||||||
|
@ -63,10 +64,12 @@ DEFAULT_IPV4 = False
|
||||||
DEFAULT_LOG_LEVEL = 'INFO'
|
DEFAULT_LOG_LEVEL = 'INFO'
|
||||||
DEFAULT_OPEN_FILE_LIMIT = 1024
|
DEFAULT_OPEN_FILE_LIMIT = 1024
|
||||||
DEFAULT_PAC_FILE = None
|
DEFAULT_PAC_FILE = None
|
||||||
|
DEFAULT_PAC_FILE_URL_PATH = '/'
|
||||||
DEFAULT_NUM_WORKERS = 0
|
DEFAULT_NUM_WORKERS = 0
|
||||||
DEFAULT_PLUGINS = []
|
DEFAULT_PLUGINS = {}
|
||||||
DEFAULT_LOG_FORMAT = '%(asctime)s - %(levelname)s - pid:%(process)d - %(funcName)s:%(lineno)d - %(message)s'
|
DEFAULT_LOG_FORMAT = '%(asctime)s - %(levelname)s - pid:%(process)d - %(funcName)s:%(lineno)d - %(message)s'
|
||||||
|
|
||||||
|
# Set to True if under test
|
||||||
UNDER_TEST = False
|
UNDER_TEST = False
|
||||||
|
|
||||||
|
|
||||||
|
@ -91,7 +94,7 @@ def bytes_(s, encoding='utf-8', errors='strict'): # pragma: no cover
|
||||||
|
|
||||||
|
|
||||||
version = bytes_(__version__)
|
version = bytes_(__version__)
|
||||||
CRLF, COLON, SP = b'\r\n', b':', b' '
|
CRLF, COLON, WHITESPACE = b'\r\n', b':', b' '
|
||||||
PROXY_AGENT_HEADER = b'Proxy-agent: proxy.py v' + version
|
PROXY_AGENT_HEADER = b'Proxy-agent: proxy.py v' + version
|
||||||
|
|
||||||
PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT = CRLF.join([
|
PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT = CRLF.join([
|
||||||
|
@ -123,31 +126,31 @@ class ChunkParser(object):
|
||||||
self.chunk = b'' # Partial chunk received
|
self.chunk = b'' # Partial chunk received
|
||||||
self.size = None # Expected size of next following chunk
|
self.size = None # Expected size of next following chunk
|
||||||
|
|
||||||
def parse(self, data):
|
def parse(self, raw):
|
||||||
more = True if len(data) > 0 else False
|
more = True if len(raw) > 0 else False
|
||||||
while more:
|
while more:
|
||||||
more, data = self.process(data)
|
more, raw = self.process(raw)
|
||||||
|
|
||||||
def process(self, data):
|
def process(self, raw):
|
||||||
if self.state == ChunkParser.states.WAITING_FOR_SIZE:
|
if self.state == ChunkParser.states.WAITING_FOR_SIZE:
|
||||||
# Consume prior chunk in buffer
|
# Consume prior chunk in buffer
|
||||||
# in case chunk size without CRLF was received
|
# in case chunk size without CRLF was received
|
||||||
data = self.chunk + data
|
raw = self.chunk + raw
|
||||||
self.chunk = b''
|
self.chunk = b''
|
||||||
# Extract following chunk data size
|
# Extract following chunk data size
|
||||||
line, data = HttpParser.split(data)
|
line, raw = HttpParser.split(raw)
|
||||||
if not line: # CRLF not received
|
if not line: # CRLF not received
|
||||||
self.chunk = data
|
self.chunk = raw
|
||||||
data = b''
|
raw = b''
|
||||||
else:
|
else:
|
||||||
self.size = int(line, 16)
|
self.size = int(line, 16)
|
||||||
self.state = ChunkParser.states.WAITING_FOR_DATA
|
self.state = ChunkParser.states.WAITING_FOR_DATA
|
||||||
elif self.state == ChunkParser.states.WAITING_FOR_DATA:
|
elif self.state == ChunkParser.states.WAITING_FOR_DATA:
|
||||||
remaining = self.size - len(self.chunk)
|
remaining = self.size - len(self.chunk)
|
||||||
self.chunk += data[:remaining]
|
self.chunk += raw[:remaining]
|
||||||
data = data[remaining:]
|
raw = raw[remaining:]
|
||||||
if len(self.chunk) == self.size:
|
if len(self.chunk) == self.size:
|
||||||
data = data[len(CRLF):]
|
raw = raw[len(CRLF):]
|
||||||
self.body += self.chunk
|
self.body += self.chunk
|
||||||
if self.size == 0:
|
if self.size == 0:
|
||||||
self.state = ChunkParser.states.COMPLETE
|
self.state = ChunkParser.states.COMPLETE
|
||||||
|
@ -155,7 +158,7 @@ class ChunkParser(object):
|
||||||
self.state = ChunkParser.states.WAITING_FOR_SIZE
|
self.state = ChunkParser.states.WAITING_FOR_SIZE
|
||||||
self.chunk = b''
|
self.chunk = b''
|
||||||
self.size = None
|
self.size = None
|
||||||
return len(data) > 0, data
|
return len(raw) > 0, raw
|
||||||
|
|
||||||
|
|
||||||
class HttpParser(object):
|
class HttpParser(object):
|
||||||
|
@ -179,7 +182,7 @@ class HttpParser(object):
|
||||||
self.type = parser_type
|
self.type = parser_type
|
||||||
self.state = HttpParser.states.INITIALIZED
|
self.state = HttpParser.states.INITIALIZED
|
||||||
|
|
||||||
self.raw = b''
|
self.bytes = b''
|
||||||
self.buffer = b''
|
self.buffer = b''
|
||||||
|
|
||||||
self.headers = dict()
|
self.headers = dict()
|
||||||
|
@ -193,22 +196,38 @@ class HttpParser(object):
|
||||||
|
|
||||||
self.chunk_parser = None
|
self.chunk_parser = None
|
||||||
|
|
||||||
|
# This cleans up developer APIs as Python urlparse.urlsplit behaves differently
|
||||||
|
# for incoming proxy request and incoming web request. Web request is the one
|
||||||
|
# which is broken.
|
||||||
|
self.host = None
|
||||||
|
self.port = None
|
||||||
|
|
||||||
|
def set_host_port(self):
|
||||||
|
if self.type == HttpParser.types.REQUEST_PARSER:
|
||||||
|
if self.method == b'CONNECT':
|
||||||
|
self.host, self.port = self.url.path.split(COLON)
|
||||||
|
elif self.url:
|
||||||
|
self.host, self.port = self.url.hostname, self.url.port \
|
||||||
|
if self.url.port else 80
|
||||||
|
else:
|
||||||
|
raise Exception('Invalid request\n%s' % self.bytes)
|
||||||
|
|
||||||
def is_chunked_encoded_response(self):
|
def is_chunked_encoded_response(self):
|
||||||
return self.type == HttpParser.types.RESPONSE_PARSER and \
|
return self.type == HttpParser.types.RESPONSE_PARSER and \
|
||||||
b'transfer-encoding' in self.headers and \
|
b'transfer-encoding' in self.headers and \
|
||||||
self.headers[b'transfer-encoding'][1].lower() == b'chunked'
|
self.headers[b'transfer-encoding'][1].lower() == b'chunked'
|
||||||
|
|
||||||
def parse(self, data):
|
def parse(self, raw):
|
||||||
self.raw += data
|
self.bytes += raw
|
||||||
data = self.buffer + data
|
raw = self.buffer + raw
|
||||||
self.buffer = b''
|
self.buffer = b''
|
||||||
|
|
||||||
more = True if len(data) > 0 else False
|
more = True if len(raw) > 0 else False
|
||||||
while more:
|
while more:
|
||||||
more, data = self.process(data)
|
more, raw = self.process(raw)
|
||||||
self.buffer = data
|
self.buffer = raw
|
||||||
|
|
||||||
def process(self, data):
|
def process(self, raw):
|
||||||
if self.state in (HttpParser.states.HEADERS_COMPLETE,
|
if self.state in (HttpParser.states.HEADERS_COMPLETE,
|
||||||
HttpParser.states.RCVING_BODY,
|
HttpParser.states.RCVING_BODY,
|
||||||
HttpParser.states.COMPLETE) and \
|
HttpParser.states.COMPLETE) and \
|
||||||
|
@ -218,22 +237,22 @@ class HttpParser(object):
|
||||||
|
|
||||||
if b'content-length' in self.headers:
|
if b'content-length' in self.headers:
|
||||||
self.state = HttpParser.states.RCVING_BODY
|
self.state = HttpParser.states.RCVING_BODY
|
||||||
self.body += data
|
self.body += raw
|
||||||
if len(self.body) >= int(self.headers[b'content-length'][1]):
|
if len(self.body) >= int(self.headers[b'content-length'][1]):
|
||||||
self.state = HttpParser.states.COMPLETE
|
self.state = HttpParser.states.COMPLETE
|
||||||
elif self.is_chunked_encoded_response():
|
elif self.is_chunked_encoded_response():
|
||||||
if not self.chunk_parser:
|
if not self.chunk_parser:
|
||||||
self.chunk_parser = ChunkParser()
|
self.chunk_parser = ChunkParser()
|
||||||
self.chunk_parser.parse(data)
|
self.chunk_parser.parse(raw)
|
||||||
if self.chunk_parser.state == ChunkParser.states.COMPLETE:
|
if self.chunk_parser.state == ChunkParser.states.COMPLETE:
|
||||||
self.body = self.chunk_parser.body
|
self.body = self.chunk_parser.body
|
||||||
self.state = HttpParser.states.COMPLETE
|
self.state = HttpParser.states.COMPLETE
|
||||||
|
|
||||||
return False, b''
|
return False, b''
|
||||||
|
|
||||||
line, data = HttpParser.split(data)
|
line, raw = HttpParser.split(raw)
|
||||||
if line is False:
|
if line is False:
|
||||||
return line, data
|
return line, raw
|
||||||
|
|
||||||
if self.state == HttpParser.states.INITIALIZED:
|
if self.state == HttpParser.states.INITIALIZED:
|
||||||
self.process_line(line)
|
self.process_line(line)
|
||||||
|
@ -245,7 +264,7 @@ class HttpParser(object):
|
||||||
if self.state == HttpParser.states.LINE_RCVD and \
|
if self.state == HttpParser.states.LINE_RCVD and \
|
||||||
self.type == HttpParser.types.REQUEST_PARSER and \
|
self.type == HttpParser.types.REQUEST_PARSER and \
|
||||||
self.method == b'CONNECT' and \
|
self.method == b'CONNECT' and \
|
||||||
data == CRLF:
|
raw == CRLF:
|
||||||
self.state = HttpParser.states.COMPLETE
|
self.state = HttpParser.states.COMPLETE
|
||||||
|
|
||||||
# When raw request has ended with \r\n\r\n and no more http headers are expected
|
# When raw request has ended with \r\n\r\n and no more http headers are expected
|
||||||
|
@ -254,7 +273,7 @@ class HttpParser(object):
|
||||||
elif self.state == HttpParser.states.HEADERS_COMPLETE and \
|
elif self.state == HttpParser.states.HEADERS_COMPLETE and \
|
||||||
self.type == HttpParser.types.REQUEST_PARSER and \
|
self.type == HttpParser.types.REQUEST_PARSER and \
|
||||||
self.method != b'POST' and \
|
self.method != b'POST' and \
|
||||||
self.raw.endswith(CRLF * 2):
|
self.bytes.endswith(CRLF * 2):
|
||||||
self.state = HttpParser.states.COMPLETE
|
self.state = HttpParser.states.COMPLETE
|
||||||
elif self.state == HttpParser.states.HEADERS_COMPLETE and \
|
elif self.state == HttpParser.states.HEADERS_COMPLETE and \
|
||||||
self.type == HttpParser.types.REQUEST_PARSER and \
|
self.type == HttpParser.types.REQUEST_PARSER and \
|
||||||
|
@ -262,13 +281,13 @@ class HttpParser(object):
|
||||||
(b'content-length' not in self.headers or
|
(b'content-length' not in self.headers or
|
||||||
(b'content-length' in self.headers and
|
(b'content-length' in self.headers and
|
||||||
int(self.headers[b'content-length'][1]) == 0)) and \
|
int(self.headers[b'content-length'][1]) == 0)) and \
|
||||||
self.raw.endswith(CRLF * 2):
|
self.bytes.endswith(CRLF * 2):
|
||||||
self.state = HttpParser.states.COMPLETE
|
self.state = HttpParser.states.COMPLETE
|
||||||
|
|
||||||
return len(data) > 0, data
|
return len(raw) > 0, raw
|
||||||
|
|
||||||
def process_line(self, data):
|
def process_line(self, raw):
|
||||||
line = data.split(SP)
|
line = raw.split(WHITESPACE)
|
||||||
if self.type == HttpParser.types.REQUEST_PARSER:
|
if self.type == HttpParser.types.REQUEST_PARSER:
|
||||||
self.method = line[0].upper()
|
self.method = line[0].upper()
|
||||||
self.url = urlparse.urlsplit(line[1])
|
self.url = urlparse.urlsplit(line[1])
|
||||||
|
@ -277,17 +296,18 @@ class HttpParser(object):
|
||||||
self.version = line[0]
|
self.version = line[0]
|
||||||
self.code = line[1]
|
self.code = line[1]
|
||||||
self.reason = b' '.join(line[2:])
|
self.reason = b' '.join(line[2:])
|
||||||
|
self.set_host_port()
|
||||||
self.state = HttpParser.states.LINE_RCVD
|
self.state = HttpParser.states.LINE_RCVD
|
||||||
|
|
||||||
def process_header(self, data):
|
def process_header(self, raw):
|
||||||
if len(data) == 0:
|
if len(raw) == 0:
|
||||||
if self.state == HttpParser.states.RCVING_HEADERS:
|
if self.state == HttpParser.states.RCVING_HEADERS:
|
||||||
self.state = HttpParser.states.HEADERS_COMPLETE
|
self.state = HttpParser.states.HEADERS_COMPLETE
|
||||||
elif self.state == HttpParser.states.LINE_RCVD:
|
elif self.state == HttpParser.states.LINE_RCVD:
|
||||||
self.state = HttpParser.states.RCVING_HEADERS
|
self.state = HttpParser.states.RCVING_HEADERS
|
||||||
else:
|
else:
|
||||||
self.state = HttpParser.states.RCVING_HEADERS
|
self.state = HttpParser.states.RCVING_HEADERS
|
||||||
parts = data.split(COLON)
|
parts = raw.split(COLON)
|
||||||
key = parts[0].strip()
|
key = parts[0].strip()
|
||||||
value = COLON.join(parts[1:]).strip()
|
value = COLON.join(parts[1:]).strip()
|
||||||
self.headers[key.lower()] = (key, value)
|
self.headers[key.lower()] = (key, value)
|
||||||
|
@ -331,13 +351,23 @@ class HttpParser(object):
|
||||||
return k + b': ' + v
|
return k + b': ' + v
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def split(data):
|
def split(raw):
|
||||||
pos = data.find(CRLF)
|
pos = raw.find(CRLF)
|
||||||
if pos == -1:
|
if pos == -1:
|
||||||
return False, data
|
return False, raw
|
||||||
line = data[:pos]
|
line = raw[:pos]
|
||||||
data = data[pos + len(CRLF):]
|
raw = raw[pos + len(CRLF):]
|
||||||
return line, data
|
return line, raw
|
||||||
|
|
||||||
|
###################################################################################
|
||||||
|
# HttpParser was originally written to parse the incoming raw Http requests.
|
||||||
|
# Since request / response objects passed to HttpProtocolBasePlugin methods
|
||||||
|
# are also HttpParser objects, methods below were added to simplify developer API.
|
||||||
|
####################################################################################
|
||||||
|
|
||||||
|
def has_upstream_server(self):
|
||||||
|
"""Host field SHOULD be None for incoming local WebServer requests."""
|
||||||
|
return True if self.host is not None else False
|
||||||
|
|
||||||
|
|
||||||
class TcpConnection(object):
|
class TcpConnection(object):
|
||||||
|
@ -488,69 +518,224 @@ class ProxyRequestRejected(ProxyError):
|
||||||
return CRLF.join(pkt) if len(pkt) > 0 else None
|
return CRLF.join(pkt) if len(pkt) > 0 else None
|
||||||
|
|
||||||
|
|
||||||
class HttpProxyPlugin(object):
|
class HttpProtocolConfig(object):
|
||||||
"""Base HttpProxy Plugin class."""
|
"""Holds various configuration values applicable to HttpProtocolHandler.
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def handle_request(self, request):
|
|
||||||
"""Handle client request (HttpParser).
|
|
||||||
|
|
||||||
Return optionally modified client request (HttpParser) object.
|
|
||||||
"""
|
|
||||||
return request
|
|
||||||
|
|
||||||
def handle_response(self, data):
|
|
||||||
"""Handle data chunks as received from the server."""
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
class HttpProxyConfig(object):
|
|
||||||
"""Holds various configuration values applicable to HttpProxy.
|
|
||||||
|
|
||||||
This config class helps us avoid passing around bunch of key/value pairs across methods.
|
This config class helps us avoid passing around bunch of key/value pairs across methods.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, auth_code=DEFAULT_BASIC_AUTH, server_recvbuf_size=DEFAULT_SERVER_RECVBUF_SIZE,
|
def __init__(self, auth_code=DEFAULT_BASIC_AUTH, server_recvbuf_size=DEFAULT_SERVER_RECVBUF_SIZE,
|
||||||
client_recvbuf_size=DEFAULT_CLIENT_RECVBUF_SIZE, pac_file=DEFAULT_PAC_FILE, plugins=None):
|
client_recvbuf_size=DEFAULT_CLIENT_RECVBUF_SIZE, pac_file=DEFAULT_PAC_FILE,
|
||||||
|
pac_file_url_path=DEFAULT_PAC_FILE_URL_PATH, plugins=None):
|
||||||
self.auth_code = auth_code
|
self.auth_code = auth_code
|
||||||
self.server_recvbuf_size = server_recvbuf_size
|
self.server_recvbuf_size = server_recvbuf_size
|
||||||
self.client_recvbuf_size = client_recvbuf_size
|
self.client_recvbuf_size = client_recvbuf_size
|
||||||
self.pac_file = pac_file
|
self.pac_file = pac_file
|
||||||
|
self.pac_file_url_path = pac_file_url_path
|
||||||
if plugins is None:
|
if plugins is None:
|
||||||
plugins = DEFAULT_PLUGINS
|
plugins = DEFAULT_PLUGINS
|
||||||
self.plugins = plugins
|
self.plugins: Dict = plugins
|
||||||
|
|
||||||
|
|
||||||
class HttpProxy(threading.Thread):
|
class HttpProtocolBasePlugin(object):
|
||||||
"""HTTP proxy implementation.
|
"""Base HttpProtocolHandler Plugin class.
|
||||||
|
|
||||||
Accepts `Client` connection object and act as a proxy between client and server.
|
Implement various lifecycle event methods to customize behavior."""
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, client, config=None):
|
|
||||||
super(HttpProxy, self).__init__()
|
|
||||||
|
|
||||||
self.start_time = self.now()
|
|
||||||
self.last_activity = self.start_time
|
|
||||||
|
|
||||||
|
def __init__(self, config: HttpProtocolConfig, client: TcpClientConnection, request: HttpParser):
|
||||||
|
self.config = config
|
||||||
self.client = client
|
self.client = client
|
||||||
self.server = None
|
self.request = request
|
||||||
self.config = config if config else HttpProxyConfig()
|
|
||||||
|
|
||||||
self.request = HttpParser(HttpParser.types.REQUEST_PARSER)
|
def name(self) -> str:
|
||||||
|
"""A unique name for your plugin.
|
||||||
|
|
||||||
|
Defaults to name of the class. This helps plugin developers to directly
|
||||||
|
access a specific plugin by its name."""
|
||||||
|
return self.__class__.__name__
|
||||||
|
|
||||||
|
def get_descriptors(self):
|
||||||
|
return [], [], []
|
||||||
|
|
||||||
|
def flush_to_descriptors(self, w):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def read_from_descriptors(self, r):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_client_data(self, raw: bytes):
|
||||||
|
return raw
|
||||||
|
|
||||||
|
def on_request_complete(self):
|
||||||
|
"""Called right after client request parser has reached COMPLETE state."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def handle_response_chunk(self, chunk: bytes):
|
||||||
|
"""Handle data chunks as received from the server.
|
||||||
|
|
||||||
|
Return optionally modified chunk to return back to client."""
|
||||||
|
return chunk
|
||||||
|
|
||||||
|
def access_log(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_client_connection_close(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class HttpProxyPlugin(HttpProtocolBasePlugin):
|
||||||
|
"""HttpProtocolHandler plugin which implements HttpProxy specifications."""
|
||||||
|
|
||||||
|
def __init__(self, config: HttpProtocolConfig, client: TcpClientConnection, request: HttpParser):
|
||||||
|
super(HttpProxyPlugin, self).__init__(config, client, request)
|
||||||
|
self.server = None
|
||||||
self.response = HttpParser(HttpParser.types.RESPONSE_PARSER)
|
self.response = HttpParser(HttpParser.types.RESPONSE_PARSER)
|
||||||
|
|
||||||
@staticmethod
|
def get_descriptors(self):
|
||||||
def now():
|
if not self.request.has_upstream_server():
|
||||||
return datetime.datetime.utcnow()
|
return [], [], []
|
||||||
|
|
||||||
def connection_inactive_for(self):
|
r, w, x = [], [], []
|
||||||
return (self.now() - self.last_activity).seconds
|
if self.server and not self.server.closed:
|
||||||
|
r.append(self.server.conn)
|
||||||
|
if self.server and not self.server.closed and self.server.has_buffer():
|
||||||
|
w.append(self.server.conn)
|
||||||
|
return r, w, x
|
||||||
|
|
||||||
def is_connection_inactive(self):
|
def flush_to_descriptors(self, w):
|
||||||
return self.connection_inactive_for() > 30
|
if not self.request.has_upstream_server():
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.server and not self.server.closed and self.server.conn in w:
|
||||||
|
logger.debug('Server is ready for writes, flushing server buffer')
|
||||||
|
self.server.flush()
|
||||||
|
|
||||||
|
def read_from_descriptors(self, r):
|
||||||
|
if not self.request.has_upstream_server():
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.server and not self.server.closed and self.server.conn in r:
|
||||||
|
logger.debug('Server is ready for reads, reading')
|
||||||
|
raw = self.server.recv(self.config.server_recvbuf_size)
|
||||||
|
# self.last_activity = HttpProtocolHandler.now()
|
||||||
|
if not raw:
|
||||||
|
logger.debug('Server closed connection, tearing down...')
|
||||||
|
return True
|
||||||
|
# parse incoming response packet
|
||||||
|
# only for non-https requests
|
||||||
|
if not self.request.method == b'CONNECT':
|
||||||
|
self.response.parse(raw)
|
||||||
|
else:
|
||||||
|
# Only purpose of increasing memory footprint is to print response length in access log
|
||||||
|
# Not worth it? Optimize to only persist lengths?
|
||||||
|
self.response.bytes += raw
|
||||||
|
# queue raw data for client
|
||||||
|
self.client.queue(raw)
|
||||||
|
|
||||||
|
def on_client_connection_close(self):
|
||||||
|
if not self.request.has_upstream_server():
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.server:
|
||||||
|
logger.debug(
|
||||||
|
'Closed server connection with pending server buffer size %d bytes' % self.server.buffer_size())
|
||||||
|
if not self.server.closed:
|
||||||
|
self.server.close()
|
||||||
|
|
||||||
|
def on_client_data(self, raw):
|
||||||
|
if not self.request.has_upstream_server():
|
||||||
|
return raw
|
||||||
|
|
||||||
|
if self.server and not self.server.closed:
|
||||||
|
self.server.queue(raw)
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return raw
|
||||||
|
|
||||||
|
def on_request_complete(self):
|
||||||
|
if not self.request.has_upstream_server():
|
||||||
|
return
|
||||||
|
|
||||||
|
self.authenticate(self.request.headers)
|
||||||
|
self.connect_upstream(self.request.host, self.request.port)
|
||||||
|
# for http connect methods (https requests)
|
||||||
|
# queue appropriate response for client
|
||||||
|
# notifying about established connection
|
||||||
|
if self.request.method == b'CONNECT':
|
||||||
|
self.client.queue(PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT)
|
||||||
|
# for general http requests, re-build request packet
|
||||||
|
# and queue for the server with appropriate headers
|
||||||
|
else:
|
||||||
|
self.server.queue(self.request.build(
|
||||||
|
del_headers=[b'proxy-authorization', b'proxy-connection', b'connection',
|
||||||
|
b'keep-alive'],
|
||||||
|
add_headers=[(b'Via', b'1.1 proxy.py v%s' % version), (b'Connection', b'Close')]
|
||||||
|
))
|
||||||
|
|
||||||
|
def access_log(self):
|
||||||
|
if not self.request.has_upstream_server():
|
||||||
|
return
|
||||||
|
|
||||||
|
host, port = self.server.addr if self.server else (None, None)
|
||||||
|
if self.request.method == b'CONNECT':
|
||||||
|
logger.info(
|
||||||
|
'%s:%s - %s %s:%s - %s bytes' % (self.client.addr[0], self.client.addr[1],
|
||||||
|
text_(self.request.method), text_(host),
|
||||||
|
text_(port), len(self.response.bytes)))
|
||||||
|
elif self.request.method:
|
||||||
|
logger.info('%s:%s - %s %s:%s%s - %s %s - %s bytes' % (
|
||||||
|
self.client.addr[0], self.client.addr[1], text_(self.request.method), text_(host), port,
|
||||||
|
text_(self.request.build_url()), text_(self.response.code), text_(self.response.reason),
|
||||||
|
len(self.response.bytes)))
|
||||||
|
|
||||||
|
def authenticate(self, headers):
|
||||||
|
if self.config.auth_code:
|
||||||
|
if b'proxy-authorization' not in headers or \
|
||||||
|
headers[b'proxy-authorization'][1] != self.config.auth_code:
|
||||||
|
raise ProxyAuthenticationFailed()
|
||||||
|
|
||||||
|
def connect_upstream(self, host, port):
|
||||||
|
self.server = TcpServerConnection(host, port)
|
||||||
|
try:
|
||||||
|
logger.debug('Connecting to upstream %s:%s' % (host, port))
|
||||||
|
self.server.connect()
|
||||||
|
logger.debug('Connected to upstream %s:%s' % (host, port))
|
||||||
|
except Exception as e: # TimeoutError, socket.gaierror
|
||||||
|
self.server.closed = True
|
||||||
|
raise ProxyConnectionFailed(host, port, repr(e))
|
||||||
|
|
||||||
|
|
||||||
|
class HttpWebServerPlugin(HttpProtocolBasePlugin):
|
||||||
|
"""HttpProtocolHandler plugin which handles incoming requests to local webserver."""
|
||||||
|
|
||||||
|
def __init__(self, config: HttpProtocolConfig, client: TcpClientConnection, request: HttpParser):
|
||||||
|
super(HttpWebServerPlugin, self).__init__(config, client, request)
|
||||||
|
|
||||||
|
def on_request_complete(self):
|
||||||
|
if self.request.has_upstream_server():
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.config.pac_file and \
|
||||||
|
self.request.url.path == self.config.pac_file_url_path:
|
||||||
|
self.serve_pac_file()
|
||||||
|
else:
|
||||||
|
self.client.queue(CRLF.join([
|
||||||
|
b'HTTP/1.1 200 OK',
|
||||||
|
b'Server: proxy.py v%s' % version,
|
||||||
|
b'Content-Length: 42',
|
||||||
|
b'Connection: Close',
|
||||||
|
CRLF,
|
||||||
|
b'https://github.com/abhinavsingh/proxy.py'
|
||||||
|
]))
|
||||||
|
self.client.flush()
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def access_log(self):
|
||||||
|
if self.request.has_upstream_server():
|
||||||
|
return
|
||||||
|
logger.info('%s:%s - %s %s' % (self.client.addr[0], self.client.addr[1],
|
||||||
|
text_(self.request.method), text_(self.request.build_url())))
|
||||||
|
|
||||||
def serve_pac_file(self):
|
def serve_pac_file(self):
|
||||||
self.client.queue(PAC_FILE_RESPONSE_PREFIX)
|
self.client.queue(PAC_FILE_RESPONSE_PREFIX)
|
||||||
|
@ -563,129 +748,103 @@ class HttpProxy(threading.Thread):
|
||||||
self.client.queue(self.config.pac_file)
|
self.client.queue(self.config.pac_file)
|
||||||
self.client.flush()
|
self.client.flush()
|
||||||
|
|
||||||
def access_log(self):
|
|
||||||
host, port = self.server.addr if self.server else (None, None)
|
class HttpProtocolHandler(threading.Thread):
|
||||||
if self.request.method == b'CONNECT':
|
"""HTTP, HTTPS, HTTP2, WebSockets protocol handler.
|
||||||
logger.info(
|
|
||||||
'%s:%s - %s %s:%s - %s bytes' % (self.client.addr[0], self.client.addr[1],
|
Accepts `Client` connection object, manages plugin invocations.
|
||||||
text_(self.request.method), text_(host),
|
"""
|
||||||
text_(port), len(self.response.raw)))
|
|
||||||
elif self.request.method:
|
def __init__(self, client, config=None):
|
||||||
logger.info('%s:%s - %s %s:%s%s - %s %s - %s bytes' % (
|
super(HttpProtocolHandler, self).__init__()
|
||||||
self.client.addr[0], self.client.addr[1], text_(self.request.method), text_(host), port,
|
|
||||||
text_(self.request.build_url()), text_(self.response.code), text_(self.response.reason),
|
self.start_time = self.now()
|
||||||
len(self.response.raw)))
|
self.last_activity = self.start_time
|
||||||
|
|
||||||
|
self.client = client
|
||||||
|
self.config = config if config else HttpProtocolConfig()
|
||||||
|
self.request = HttpParser(HttpParser.types.REQUEST_PARSER)
|
||||||
|
|
||||||
|
self.plugins: Dict[str, HttpProtocolBasePlugin] = {}
|
||||||
|
for klass in self.config.plugins:
|
||||||
|
instance = klass(self.config, self.client, self.request)
|
||||||
|
self.plugins[instance.name()] = instance
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def now():
|
||||||
|
return datetime.datetime.utcnow()
|
||||||
|
|
||||||
|
def connection_inactive_for(self):
|
||||||
|
return (self.now() - self.last_activity).seconds
|
||||||
|
|
||||||
|
def is_connection_inactive(self):
|
||||||
|
return self.connection_inactive_for() > 30
|
||||||
|
|
||||||
def run_once(self):
|
def run_once(self):
|
||||||
"""Returns True if proxy must teardown."""
|
"""Returns True if proxy must teardown."""
|
||||||
# Prepare list of descriptors
|
# Prepare list of descriptors
|
||||||
rlist, wlist, xlist = [self.client.conn], [], []
|
read_desc, write_desc, err_desc = [self.client.conn], [], []
|
||||||
if self.client.has_buffer():
|
if self.client.has_buffer():
|
||||||
wlist.append(self.client.conn)
|
write_desc.append(self.client.conn)
|
||||||
if self.server and not self.server.closed:
|
|
||||||
rlist.append(self.server.conn)
|
|
||||||
if self.server and not self.server.closed and self.server.has_buffer():
|
|
||||||
wlist.append(self.server.conn)
|
|
||||||
|
|
||||||
r, w, x = select.select(rlist, wlist, xlist, 1)
|
for plugin in self.plugins.values():
|
||||||
|
plugin_read_desc, plugin_write_desc, plugin_err_desc = plugin.get_descriptors()
|
||||||
|
read_desc += plugin_read_desc
|
||||||
|
write_desc += plugin_write_desc
|
||||||
|
err_desc += plugin_err_desc
|
||||||
|
|
||||||
# Flush buffer from ready to write sockets
|
readable, writable, errored = select.select(read_desc, write_desc, err_desc, 1)
|
||||||
if self.client.conn in w:
|
|
||||||
|
# Flush buffer for ready to write sockets
|
||||||
|
if self.client.conn in writable:
|
||||||
logger.debug('Client is ready for writes, flushing client buffer')
|
logger.debug('Client is ready for writes, flushing client buffer')
|
||||||
self.client.flush()
|
try:
|
||||||
if self.server and not self.server.closed and self.server.conn in w:
|
self.client.flush()
|
||||||
logger.debug('Server is ready for writes, flushing server buffer')
|
except BrokenPipeError:
|
||||||
self.server.flush()
|
logging.error('BrokenPipeError when flushing buffer for client')
|
||||||
|
return True
|
||||||
|
|
||||||
|
for plugin in self.plugins.values():
|
||||||
|
plugin.flush_to_descriptors(writable)
|
||||||
|
|
||||||
# Read from ready to read sockets
|
# Read from ready to read sockets
|
||||||
if self.client.conn in r:
|
if self.client.conn in readable:
|
||||||
logger.debug('Client is ready for reads, reading')
|
logger.debug('Client is ready for reads, reading')
|
||||||
data = self.client.recv(self.config.client_recvbuf_size)
|
client_data = self.client.recv(self.config.client_recvbuf_size)
|
||||||
self.last_activity = self.now()
|
self.last_activity = self.now()
|
||||||
if not data:
|
if not client_data:
|
||||||
logger.debug('Client closed connection, tearing down...')
|
logger.debug('Client closed connection, tearing down...')
|
||||||
return True
|
return True
|
||||||
try:
|
|
||||||
# Once we have connection to the server
|
plugin_index = 0
|
||||||
# we don't parse the http request packets
|
plugins = list(self.plugins.values())
|
||||||
# any further, instead just pipe incoming
|
while plugin_index < len(plugins) and client_data:
|
||||||
# data from client to server
|
client_data = plugins[plugin_index].on_client_data(client_data)
|
||||||
if self.server and not self.server.closed:
|
plugin_index += 1
|
||||||
self.server.queue(data)
|
|
||||||
else:
|
if client_data:
|
||||||
|
try:
|
||||||
# Parse http request
|
# Parse http request
|
||||||
self.request.parse(data)
|
self.request.parse(client_data)
|
||||||
if self.request.state == HttpParser.states.COMPLETE:
|
if self.request.state == HttpParser.states.COMPLETE:
|
||||||
logger.debug('Request parser is in state complete')
|
for plugin in self.plugins.values():
|
||||||
|
# TODO: Cleanup by not returning True for teardown cases
|
||||||
|
plugin_response = plugin.on_request_complete()
|
||||||
|
if type(plugin_response) is bool:
|
||||||
|
return True
|
||||||
|
except ProxyError as e: # ProxyAuthenticationFailed, ProxyConnectionFailed, ProxyRequestRejected
|
||||||
|
# logger.exception(e)
|
||||||
|
response = e.response(self.request)
|
||||||
|
if response:
|
||||||
|
self.client.queue(response)
|
||||||
|
# But is client also ready for writes?
|
||||||
|
self.client.flush()
|
||||||
|
raise e
|
||||||
|
|
||||||
if self.config.auth_code:
|
for plugin in self.plugins.values():
|
||||||
if b'proxy-authorization' not in self.request.headers or \
|
teardown = plugin.read_from_descriptors(readable)
|
||||||
self.request.headers[b'proxy-authorization'][1] != self.config.auth_code:
|
if teardown:
|
||||||
raise ProxyAuthenticationFailed()
|
|
||||||
|
|
||||||
# Invoke HttpProxyPlugin.handle_request
|
|
||||||
for plugin in self.config.plugins:
|
|
||||||
self.request = plugin.handle_request(self.request)
|
|
||||||
|
|
||||||
if self.request.method == b'CONNECT':
|
|
||||||
host, port = self.request.url.path.split(COLON)
|
|
||||||
elif self.request.url:
|
|
||||||
host, port = self.request.url.hostname, self.request.url.port \
|
|
||||||
if self.request.url.port else 80
|
|
||||||
else:
|
|
||||||
raise Exception('Invalid request\n%s' % self.request.raw)
|
|
||||||
|
|
||||||
if host is None and self.config.pac_file:
|
|
||||||
self.serve_pac_file()
|
|
||||||
return True
|
|
||||||
|
|
||||||
self.server = TcpServerConnection(host, port)
|
|
||||||
try:
|
|
||||||
logger.debug('Connecting to server %s:%s' % (host, port))
|
|
||||||
self.server.connect()
|
|
||||||
logger.debug('Connected to server %s:%s' % (host, port))
|
|
||||||
except Exception as e: # TimeoutError, socket.gaierror
|
|
||||||
self.server.closed = True
|
|
||||||
raise ProxyConnectionFailed(host, port, repr(e))
|
|
||||||
|
|
||||||
# for http connect methods (https requests)
|
|
||||||
# queue appropriate response for client
|
|
||||||
# notifying about established connection
|
|
||||||
if self.request.method == b'CONNECT':
|
|
||||||
self.client.queue(PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT)
|
|
||||||
# for usual http requests, re-build request packet
|
|
||||||
# and queue for the server with appropriate headers
|
|
||||||
else:
|
|
||||||
self.server.queue(self.request.build(
|
|
||||||
del_headers=[b'proxy-authorization', b'proxy-connection', b'connection',
|
|
||||||
b'keep-alive'],
|
|
||||||
add_headers=[(b'Via', b'1.1 proxy.py v%s' % version), (b'Connection', b'Close')]
|
|
||||||
))
|
|
||||||
except (ProxyAuthenticationFailed, ProxyConnectionFailed, ProxyRequestRejected) as e:
|
|
||||||
# logger.exception(e)
|
|
||||||
response = e.response(self.request)
|
|
||||||
if response:
|
|
||||||
self.client.queue(response)
|
|
||||||
# But is client also ready for writes?
|
|
||||||
self.client.flush()
|
|
||||||
raise e
|
|
||||||
if self.server and not self.server.closed and self.server.conn in r:
|
|
||||||
logger.debug('Server is ready for reads, reading')
|
|
||||||
data = self.server.recv(self.config.server_recvbuf_size)
|
|
||||||
self.last_activity = self.now()
|
|
||||||
if not data:
|
|
||||||
logger.debug('Server closed connection, tearing down...')
|
|
||||||
return True
|
return True
|
||||||
# parse incoming response packet
|
|
||||||
# only for non-https requests
|
|
||||||
if not self.request.method == b'CONNECT':
|
|
||||||
self.response.parse(data)
|
|
||||||
else:
|
|
||||||
# Only purpose of increasing memory footprint is to print response length in access log
|
|
||||||
# Not worth it? Optimize to only persist lengths?
|
|
||||||
self.response.raw += data
|
|
||||||
# queue data for client
|
|
||||||
self.client.queue(data)
|
|
||||||
|
|
||||||
# Teardown if client buffer is empty and connection is inactive
|
# Teardown if client buffer is empty and connection is inactive
|
||||||
if self.client.buffer_size() == 0:
|
if self.client.buffer_size() == 0:
|
||||||
|
@ -706,16 +865,17 @@ class HttpProxy(threading.Thread):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception('Exception while handling connection %r with reason %r' % (self.client.conn, e))
|
logger.exception('Exception while handling connection %r with reason %r' % (self.client.conn, e))
|
||||||
finally:
|
finally:
|
||||||
|
for plugin in self.plugins.values():
|
||||||
|
plugin.access_log()
|
||||||
|
|
||||||
self.client.close()
|
self.client.close()
|
||||||
logger.debug(
|
logger.debug('Closed client connection with pending '
|
||||||
'Closed client connection with pending client buffer size %d bytes' % self.client.buffer_size())
|
'client buffer size %d bytes' % self.client.buffer_size())
|
||||||
if self.server:
|
for plugin in self.plugins.values():
|
||||||
logger.debug(
|
plugin.on_client_connection_close()
|
||||||
'Closed server connection with pending server buffer size %d bytes' % self.server.buffer_size())
|
|
||||||
if not self.server.closed:
|
logger.debug('Closed proxy for connection %r '
|
||||||
self.server.close()
|
'at address %r' % (self.client.conn, self.client.addr))
|
||||||
self.access_log()
|
|
||||||
logger.debug('Closed proxy for connection %r at address %r' % (self.client.conn, self.client.addr))
|
|
||||||
|
|
||||||
|
|
||||||
class TcpServer(object):
|
class TcpServer(object):
|
||||||
|
@ -797,7 +957,7 @@ class MultiCoreRequestDispatcher(TcpServer):
|
||||||
self.workers.append(worker)
|
self.workers.append(worker)
|
||||||
|
|
||||||
def handle(self, client):
|
def handle(self, client):
|
||||||
self.worker_queue.put((Worker.operations.HTTP_PROXY, client))
|
self.worker_queue.put((Worker.operations.HTTP_PROTOCOL, client))
|
||||||
|
|
||||||
def shutdown(self):
|
def shutdown(self):
|
||||||
logger.info('Shutting down %d workers' % self.num_workers)
|
logger.info('Shutting down %d workers' % self.num_workers)
|
||||||
|
@ -815,7 +975,7 @@ class Worker(multiprocessing.Process):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
operations = namedtuple('WorkerOperations', (
|
operations = namedtuple('WorkerOperations', (
|
||||||
'HTTP_PROXY',
|
'HTTP_PROTOCOL',
|
||||||
'SHUTDOWN',
|
'SHUTDOWN',
|
||||||
))(1, 2)
|
))(1, 2)
|
||||||
|
|
||||||
|
@ -828,8 +988,8 @@ class Worker(multiprocessing.Process):
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
op, payload = self.work_queue.get(True, 1)
|
op, payload = self.work_queue.get(True, 1)
|
||||||
if op == Worker.operations.HTTP_PROXY:
|
if op == Worker.operations.HTTP_PROTOCOL:
|
||||||
proxy = HttpProxy(payload, config=self.config)
|
proxy = HttpProtocolHandler(payload, config=self.config)
|
||||||
proxy.setDaemon(True)
|
proxy.setDaemon(True)
|
||||||
proxy.start()
|
proxy.start()
|
||||||
elif op == Worker.operations.SHUTDOWN:
|
elif op == Worker.operations.SHUTDOWN:
|
||||||
|
@ -852,6 +1012,22 @@ def set_open_file_limit(soft_limit):
|
||||||
logger.debug('Open file descriptor soft limit set to %d' % soft_limit)
|
logger.debug('Open file descriptor soft limit set to %d' % soft_limit)
|
||||||
|
|
||||||
|
|
||||||
|
def load_plugins(plugins) -> List:
|
||||||
|
p = []
|
||||||
|
plugins = plugins.split(',')
|
||||||
|
for plugin in plugins:
|
||||||
|
plugin = plugin.strip()
|
||||||
|
if plugin == '':
|
||||||
|
continue
|
||||||
|
logging.debug('Loading plugin %s', plugin)
|
||||||
|
module_name, klass_name = plugin.rsplit('.', 1)
|
||||||
|
module = importlib.import_module(module_name)
|
||||||
|
klass = getattr(module, klass_name)
|
||||||
|
logging.info('%s initialized' % klass.__name__)
|
||||||
|
p.append(klass)
|
||||||
|
return p
|
||||||
|
|
||||||
|
|
||||||
def init_parser():
|
def init_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description='proxy.py v%s' % __version__,
|
description='proxy.py v%s' % __version__,
|
||||||
|
@ -882,12 +1058,14 @@ def init_parser():
|
||||||
parser.add_argument('--open-file-limit', type=int, default=DEFAULT_OPEN_FILE_LIMIT,
|
parser.add_argument('--open-file-limit', type=int, default=DEFAULT_OPEN_FILE_LIMIT,
|
||||||
help='Default: 1024. Maximum number of files (TCP connections) '
|
help='Default: 1024. Maximum number of files (TCP connections) '
|
||||||
'that proxy.py can open concurrently.')
|
'that proxy.py can open concurrently.')
|
||||||
parser.add_argument('--port', type=int, default=DEFAULT_PORT,
|
|
||||||
help='Default: 8899. Server port.')
|
|
||||||
parser.add_argument('--pac-file', type=str, default=DEFAULT_PAC_FILE,
|
parser.add_argument('--pac-file', type=str, default=DEFAULT_PAC_FILE,
|
||||||
help='A file (Proxy Auto Configuration) or string to serve when '
|
help='A file (Proxy Auto Configuration) or string to serve when '
|
||||||
'the server receives a direct file request.')
|
'the server receives a direct file request.')
|
||||||
parser.add_argument('--plugins', type=str, default=None, help='Comma separated plugins')
|
parser.add_argument('--pac-file-url-path', type=str, default=DEFAULT_PAC_FILE_URL_PATH,
|
||||||
|
help='Web server path to serve the PAC file.')
|
||||||
|
parser.add_argument('--plugins', type=str, default='', help='Comma separated plugins')
|
||||||
|
parser.add_argument('--port', type=int, default=DEFAULT_PORT,
|
||||||
|
help='Default: 8899. Server port.')
|
||||||
parser.add_argument('--server-recvbuf-size', type=int, default=DEFAULT_SERVER_RECVBUF_SIZE,
|
parser.add_argument('--server-recvbuf-size', type=int, default=DEFAULT_SERVER_RECVBUF_SIZE,
|
||||||
help='Default: 8 KB. Maximum amount of data received from the '
|
help='Default: 8 KB. Maximum amount of data received from the '
|
||||||
'server in a single recv() operation. Bump this '
|
'server in a single recv() operation. Bump this '
|
||||||
|
@ -898,18 +1076,6 @@ def init_parser():
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
def load_plugins(plugins):
|
|
||||||
p = []
|
|
||||||
plugins = plugins.split(',')
|
|
||||||
for plugin in plugins:
|
|
||||||
module_name, klass_name = plugin.rsplit('.', 1)
|
|
||||||
module = importlib.import_module(module_name)
|
|
||||||
klass = getattr(module, klass_name)
|
|
||||||
logging.info('%s initialized' % klass.__name__)
|
|
||||||
p.append(klass())
|
|
||||||
return p
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
if not PY3 and not UNDER_TEST:
|
if not PY3 and not UNDER_TEST:
|
||||||
print(
|
print(
|
||||||
|
@ -924,7 +1090,6 @@ def main():
|
||||||
|
|
||||||
parser = init_parser()
|
parser = init_parser()
|
||||||
args = parser.parse_args(sys.argv[1:])
|
args = parser.parse_args(sys.argv[1:])
|
||||||
|
|
||||||
logging.basicConfig(level=getattr(logging,
|
logging.basicConfig(level=getattr(logging,
|
||||||
{
|
{
|
||||||
'D': 'DEBUG',
|
'D': 'DEBUG',
|
||||||
|
@ -934,7 +1099,6 @@ def main():
|
||||||
'C': 'CRITICAL'
|
'C': 'CRITICAL'
|
||||||
}[args.log_level.upper()[0]]),
|
}[args.log_level.upper()[0]]),
|
||||||
format=args.log_format)
|
format=args.log_format)
|
||||||
plugins = load_plugins(args.plugins) if args.plugins else DEFAULT_PLUGINS
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
set_open_file_limit(args.open_file_limit)
|
set_open_file_limit(args.open_file_limit)
|
||||||
|
@ -943,16 +1107,18 @@ def main():
|
||||||
if args.basic_auth:
|
if args.basic_auth:
|
||||||
auth_code = b'Basic %s' % base64.b64encode(bytes_(args.basic_auth))
|
auth_code = b'Basic %s' % base64.b64encode(bytes_(args.basic_auth))
|
||||||
|
|
||||||
|
config = HttpProtocolConfig(auth_code=auth_code,
|
||||||
|
server_recvbuf_size=args.server_recvbuf_size,
|
||||||
|
client_recvbuf_size=args.client_recvbuf_size,
|
||||||
|
pac_file=args.pac_file,
|
||||||
|
pac_file_url_path=args.pac_file_url_path)
|
||||||
|
config.plugins = load_plugins('proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin,%s' % args.plugins)
|
||||||
server = MultiCoreRequestDispatcher(hostname=args.hostname,
|
server = MultiCoreRequestDispatcher(hostname=args.hostname,
|
||||||
port=args.port,
|
port=args.port,
|
||||||
backlog=args.backlog,
|
backlog=args.backlog,
|
||||||
ipv4=args.ipv4,
|
ipv4=args.ipv4,
|
||||||
num_workers=args.num_workers,
|
num_workers=args.num_workers,
|
||||||
config=HttpProxyConfig(auth_code=auth_code,
|
config=config)
|
||||||
server_recvbuf_size=args.server_recvbuf_size,
|
|
||||||
client_recvbuf_size=args.client_recvbuf_size,
|
|
||||||
pac_file=args.pac_file,
|
|
||||||
plugins=plugins))
|
|
||||||
server.run()
|
server.run()
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -1 +0,0 @@
|
||||||
mock==3.0.5
|
|
48
tests.py
48
tests.py
|
@ -3,9 +3,9 @@
|
||||||
proxy.py
|
proxy.py
|
||||||
~~~~~~~~
|
~~~~~~~~
|
||||||
|
|
||||||
HTTP Proxy Server in Python.
|
HTTP, HTTPS, HTTP2 and WebSockets Proxy Server in Python.
|
||||||
|
|
||||||
:copyright: (c) 2013-2018 by Abhinav Singh.
|
:copyright: (c) 2013-2020 by Abhinav Singh.
|
||||||
:license: BSD, see LICENSE for more details.
|
:license: BSD, see LICENSE for more details.
|
||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
|
@ -21,8 +21,8 @@ from threading import Thread
|
||||||
|
|
||||||
import proxy
|
import proxy
|
||||||
|
|
||||||
# logging.basicConfig(level=logging.DEBUG,
|
logging.basicConfig(level=logging.DEBUG,
|
||||||
# format='%(asctime)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s')
|
format='%(asctime)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s')
|
||||||
|
|
||||||
if sys.version_info[0] == 3: # Python3 specific imports
|
if sys.version_info[0] == 3: # Python3 specific imports
|
||||||
from http.server import HTTPServer, BaseHTTPRequestHandler
|
from http.server import HTTPServer, BaseHTTPRequestHandler
|
||||||
|
@ -123,7 +123,7 @@ class TestMultiCoreRequestDispatcher(unittest.TestCase):
|
||||||
tcp_server = None
|
tcp_server = None
|
||||||
tcp_thread = None
|
tcp_thread = None
|
||||||
|
|
||||||
@mock.patch.object(proxy, 'HttpProxy', side_effect=mock_tcp_proxy_side_effect)
|
@mock.patch.object(proxy, 'HttpProtocolHandler', side_effect=mock_tcp_proxy_side_effect)
|
||||||
def testHttpProxyConnection(self, mock_tcp_proxy):
|
def testHttpProxyConnection(self, mock_tcp_proxy):
|
||||||
try:
|
try:
|
||||||
self.tcp_port = get_available_port()
|
self.tcp_port = get_available_port()
|
||||||
|
@ -523,6 +523,7 @@ class TestProxy(unittest.TestCase):
|
||||||
http_server = None
|
http_server = None
|
||||||
http_server_port = None
|
http_server_port = None
|
||||||
http_server_thread = None
|
http_server_thread = None
|
||||||
|
config = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
|
@ -531,6 +532,8 @@ class TestProxy(unittest.TestCase):
|
||||||
cls.http_server_thread = Thread(target=cls.http_server.serve_forever)
|
cls.http_server_thread = Thread(target=cls.http_server.serve_forever)
|
||||||
cls.http_server_thread.setDaemon(True)
|
cls.http_server_thread.setDaemon(True)
|
||||||
cls.http_server_thread.start()
|
cls.http_server_thread.start()
|
||||||
|
cls.config = proxy.HttpProtocolConfig()
|
||||||
|
cls.config.plugins = proxy.load_plugins('proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin')
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
|
@ -541,7 +544,7 @@ class TestProxy(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self._conn = MockTcpConnection()
|
self._conn = MockTcpConnection()
|
||||||
self._addr = ('127.0.0.1', 54382)
|
self._addr = ('127.0.0.1', 54382)
|
||||||
self.proxy = proxy.HttpProxy(proxy.TcpClientConnection(self._conn, self._addr))
|
self.proxy = proxy.HttpProtocolHandler(proxy.TcpClientConnection(self._conn, self._addr), config=self.config)
|
||||||
|
|
||||||
@mock.patch('select.select')
|
@mock.patch('select.select')
|
||||||
@mock.patch('proxy.TcpServerConnection')
|
@mock.patch('proxy.TcpServerConnection')
|
||||||
|
@ -610,7 +613,7 @@ class TestProxy(unittest.TestCase):
|
||||||
proxy.CRLF
|
proxy.CRLF
|
||||||
]))
|
]))
|
||||||
self.proxy.run_once()
|
self.proxy.run_once()
|
||||||
self.assertFalse(self.proxy.server is None)
|
self.assertFalse(self.proxy.plugins['HttpProxyPlugin'].server is None)
|
||||||
self.assertEqual(self.proxy.client.buffer, proxy.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT)
|
self.assertEqual(self.proxy.client.buffer, proxy.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT)
|
||||||
mock_server_connection.assert_called_once()
|
mock_server_connection.assert_called_once()
|
||||||
server.connect.assert_called_once()
|
server.connect.assert_called_once()
|
||||||
|
@ -659,9 +662,10 @@ class TestProxy(unittest.TestCase):
|
||||||
@mock.patch('select.select')
|
@mock.patch('select.select')
|
||||||
def test_proxy_authentication_failed(self, mock_select):
|
def test_proxy_authentication_failed(self, mock_select):
|
||||||
mock_select.return_value = ([self._conn], [], [])
|
mock_select.return_value = ([self._conn], [], [])
|
||||||
self.proxy = proxy.HttpProxy(proxy.TcpClientConnection(self._conn, self._addr),
|
config = proxy.HttpProtocolConfig(auth_code=b'Basic %s' % base64.b64encode(b'user:pass'))
|
||||||
config=proxy.HttpProxyConfig(
|
config.plugins = proxy.load_plugins('proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin')
|
||||||
auth_code=b'Basic %s' % base64.b64encode(b'user:pass')))
|
self.proxy = proxy.HttpProtocolHandler(proxy.TcpClientConnection(self._conn, self._addr),
|
||||||
|
config=config)
|
||||||
self.proxy.client.conn.queue(proxy.CRLF.join([
|
self.proxy.client.conn.queue(proxy.CRLF.join([
|
||||||
b'GET http://abhinavsingh.com HTTP/1.1',
|
b'GET http://abhinavsingh.com HTTP/1.1',
|
||||||
b'Host: abhinavsingh.com',
|
b'Host: abhinavsingh.com',
|
||||||
|
@ -673,14 +677,15 @@ class TestProxy(unittest.TestCase):
|
||||||
@mock.patch('select.select')
|
@mock.patch('select.select')
|
||||||
@mock.patch('proxy.TcpServerConnection')
|
@mock.patch('proxy.TcpServerConnection')
|
||||||
def test_authenticated_proxy_http_get(self, mock_server_connection, mock_select):
|
def test_authenticated_proxy_http_get(self, mock_server_connection, mock_select):
|
||||||
client = proxy.TcpClientConnection(self._conn, self._addr)
|
|
||||||
config = proxy.HttpProxyConfig(auth_code=b'Basic %s' % base64.b64encode(b'user:pass'))
|
|
||||||
|
|
||||||
mock_select.return_value = ([self._conn], [], [])
|
mock_select.return_value = ([self._conn], [], [])
|
||||||
server = mock_server_connection.return_value
|
server = mock_server_connection.return_value
|
||||||
server.connect.return_value = True
|
server.connect.return_value = True
|
||||||
|
|
||||||
self.proxy = proxy.HttpProxy(client, config=config)
|
client = proxy.TcpClientConnection(self._conn, self._addr)
|
||||||
|
config = proxy.HttpProtocolConfig(auth_code=b'Basic %s' % base64.b64encode(b'user:pass'))
|
||||||
|
config.plugins = proxy.load_plugins('proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin')
|
||||||
|
|
||||||
|
self.proxy = proxy.HttpProtocolHandler(client, config=config)
|
||||||
self.proxy.client.conn.queue(b'GET http://localhost:%d HTTP/1.1' % self.http_server_port)
|
self.proxy.client.conn.queue(b'GET http://localhost:%d HTTP/1.1' % self.http_server_port)
|
||||||
self.proxy.run_once()
|
self.proxy.run_once()
|
||||||
self.assertEqual(self.proxy.request.state, proxy.HttpParser.states.INITIALIZED)
|
self.assertEqual(self.proxy.request.state, proxy.HttpParser.states.INITIALIZED)
|
||||||
|
@ -731,9 +736,10 @@ class TestProxy(unittest.TestCase):
|
||||||
server.connect.return_value = True
|
server.connect.return_value = True
|
||||||
mock_select.side_effect = [([self._conn], [], []), ([self._conn], [], []), ([], [server.conn], [])]
|
mock_select.side_effect = [([self._conn], [], []), ([self._conn], [], []), ([], [server.conn], [])]
|
||||||
|
|
||||||
self.proxy = proxy.HttpProxy(proxy.TcpClientConnection(self._conn, self._addr),
|
config = proxy.HttpProtocolConfig(auth_code=b'Basic %s' % base64.b64encode(b'user:pass'))
|
||||||
config=proxy.HttpProxyConfig(
|
config.plugins = proxy.load_plugins('proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin')
|
||||||
auth_code=b'Basic %s' % base64.b64encode(b'user:pass')))
|
self.proxy = proxy.HttpProtocolHandler(proxy.TcpClientConnection(self._conn, self._addr),
|
||||||
|
config=config)
|
||||||
self.proxy.client.conn.queue(proxy.CRLF.join([
|
self.proxy.client.conn.queue(proxy.CRLF.join([
|
||||||
b'CONNECT localhost:%d HTTP/1.1' % self.http_server_port,
|
b'CONNECT localhost:%d HTTP/1.1' % self.http_server_port,
|
||||||
b'Host: localhost:%d' % self.http_server_port,
|
b'Host: localhost:%d' % self.http_server_port,
|
||||||
|
@ -743,7 +749,7 @@ class TestProxy(unittest.TestCase):
|
||||||
proxy.CRLF
|
proxy.CRLF
|
||||||
]))
|
]))
|
||||||
self.proxy.run_once()
|
self.proxy.run_once()
|
||||||
self.assertFalse(self.proxy.server is None)
|
self.assertFalse(self.proxy.plugins['HttpProxyPlugin'].server is None)
|
||||||
self.assertEqual(self.proxy.client.buffer, proxy.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT)
|
self.assertEqual(self.proxy.client.buffer, proxy.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT)
|
||||||
mock_server_connection.assert_called_once()
|
mock_server_connection.assert_called_once()
|
||||||
server.connect.assert_called_once()
|
server.connect.assert_called_once()
|
||||||
|
@ -780,15 +786,15 @@ class TestWorker(unittest.TestCase):
|
||||||
self.queue = multiprocessing.Queue()
|
self.queue = multiprocessing.Queue()
|
||||||
self.worker = proxy.Worker(self.queue)
|
self.worker = proxy.Worker(self.queue)
|
||||||
|
|
||||||
@mock.patch('proxy.HttpProxy')
|
@mock.patch('proxy.HttpProtocolHandler')
|
||||||
def test_shutdown_op(self, mock_http_proxy):
|
def test_shutdown_op(self, mock_http_proxy):
|
||||||
self.queue.put((proxy.Worker.operations.SHUTDOWN, None))
|
self.queue.put((proxy.Worker.operations.SHUTDOWN, None))
|
||||||
self.worker.run() # Worker should consume the prior shutdown operation
|
self.worker.run() # Worker should consume the prior shutdown operation
|
||||||
self.assertFalse(mock_http_proxy.called)
|
self.assertFalse(mock_http_proxy.called)
|
||||||
|
|
||||||
@mock.patch('proxy.HttpProxy')
|
@mock.patch('proxy.HttpProtocolHandler')
|
||||||
def test_spawns_http_proxy_threads(self, mock_http_proxy):
|
def test_spawns_http_proxy_threads(self, mock_http_proxy):
|
||||||
self.queue.put((proxy.Worker.operations.HTTP_PROXY, None))
|
self.queue.put((proxy.Worker.operations.HTTP_PROTOCOL, None))
|
||||||
self.queue.put((proxy.Worker.operations.SHUTDOWN, None))
|
self.queue.put((proxy.Worker.operations.SHUTDOWN, None))
|
||||||
self.worker.run()
|
self.worker.run()
|
||||||
self.assertTrue(mock_http_proxy.called)
|
self.assertTrue(mock_http_proxy.called)
|
||||||
|
|
Loading…
Reference in New Issue