Add --disable-headers option. Fixes #31
This commit is contained in:
parent
8489e8bc2f
commit
c3cd9be816
|
@ -11,3 +11,4 @@ dist
|
|||
build
|
||||
proxy.py.egg-info
|
||||
proxy.py.iml
|
||||
.mypy_cache
|
||||
|
|
18
README.md
18
README.md
|
@ -33,16 +33,16 @@ Usage
|
|||
$ proxy.py -h
|
||||
usage: proxy.py [-h] [--backlog BACKLOG] [--basic-auth BASIC_AUTH]
|
||||
[--client-recvbuf-size CLIENT_RECVBUF_SIZE]
|
||||
[--hostname HOSTNAME] [--ipv4] [--enable-http-proxy]
|
||||
[--enable-web-server] [--log-level LOG_LEVEL]
|
||||
[--log-file LOG_FILE] [--log-format LOG_FORMAT]
|
||||
[--num-workers NUM_WORKERS]
|
||||
[--disable-headers DISABLE_HEADERS] [--disable-http-proxy]
|
||||
[--hostname HOSTNAME] [--ipv4] [--enable-web-server]
|
||||
[--log-level LOG_LEVEL] [--log-file LOG_FILE]
|
||||
[--log-format LOG_FORMAT] [--num-workers NUM_WORKERS]
|
||||
[--open-file-limit OPEN_FILE_LIMIT] [--pac-file PAC_FILE]
|
||||
[--pac-file-url-path PAC_FILE_URL_PATH] [--pid-file PID_FILE]
|
||||
[--plugins PLUGINS] [--port PORT]
|
||||
[--server-recvbuf-size SERVER_RECVBUF_SIZE] [--version]
|
||||
|
||||
proxy.py v0.4
|
||||
proxy.py v1.0
|
||||
|
||||
optional arguments:
|
||||
-h, --help show this help message and exit
|
||||
|
@ -56,11 +56,15 @@ optional arguments:
|
|||
the client in a single recv() operation. Bump this
|
||||
value for faster uploads at the expense of increased
|
||||
RAM.
|
||||
--disable-headers DISABLE_HEADERS
|
||||
Default: None. Comma separated list of headers to
|
||||
remove beforedispatching client request to upstream
|
||||
server.
|
||||
--disable-http-proxy Default: False. Whether to disable
|
||||
proxy.HttpProxyPlugin.
|
||||
--hostname HOSTNAME Default: 127.0.0.1. Server IP address.
|
||||
--ipv4 Whether to listen on IPv4 address. By default server
|
||||
only listens on IPv6.
|
||||
--enable-http-proxy Default: True. Whether to enable
|
||||
proxy.HttpProxyPlugin.
|
||||
--enable-web-server Default: False. Whether to enable
|
||||
proxy.HttpWebServerPlugin.
|
||||
--log-level LOG_LEVEL
|
||||
|
|
127
proxy.py
127
proxy.py
|
@ -46,6 +46,7 @@ DEFAULT_BASIC_AUTH = None
|
|||
DEFAULT_BUFFER_SIZE = 1024 * 1024
|
||||
DEFAULT_CLIENT_RECVBUF_SIZE = DEFAULT_BUFFER_SIZE
|
||||
DEFAULT_SERVER_RECVBUF_SIZE = DEFAULT_BUFFER_SIZE
|
||||
DEFAULT_DISABLE_HEADERS = []
|
||||
DEFAULT_IPV4_HOSTNAME = '127.0.0.1'
|
||||
DEFAULT_IPV6_HOSTNAME = '::'
|
||||
DEFAULT_PORT = 8899
|
||||
|
@ -58,7 +59,7 @@ DEFAULT_PAC_FILE = None
|
|||
DEFAULT_PAC_FILE_URL_PATH = b'/'
|
||||
DEFAULT_PID_FILE = None
|
||||
DEFAULT_NUM_WORKERS = 0
|
||||
DEFAULT_PLUGINS = {}
|
||||
DEFAULT_PLUGINS = ''
|
||||
DEFAULT_VERSION = False
|
||||
DEFAULT_LOG_FORMAT = '%(asctime)s - %(levelname)s - pid:%(process)d - %(funcName)s:%(lineno)d - %(message)s'
|
||||
DEFAULT_LOG_FILE = None
|
||||
|
@ -88,7 +89,7 @@ def bytes_(s, encoding='utf-8', errors='strict') -> bytes:
|
|||
|
||||
|
||||
version = bytes_(__version__)
|
||||
CRLF, COLON, WHITESPACE = b'\r\n', b':', b' '
|
||||
CRLF, COLON, WHITESPACE, COMMA = b'\r\n', b':', b' ', ','
|
||||
PROXY_AGENT_HEADER = b'Proxy-agent: proxy.py v' + version
|
||||
|
||||
|
||||
|
@ -101,7 +102,6 @@ class TcpConnection(object):
|
|||
))(1, 2)
|
||||
|
||||
def __init__(self, what: types):
|
||||
# Cannot for socket.socket type because initialized value needs to be None?
|
||||
self.conn = None
|
||||
self.buffer: bytes = b''
|
||||
self.closed: bool = False
|
||||
|
@ -113,7 +113,7 @@ class TcpConnection(object):
|
|||
|
||||
def recv(self, buffer_size: int = DEFAULT_BUFFER_SIZE) -> bytes:
|
||||
try:
|
||||
data = self.conn.recv(buffer_size)
|
||||
data: bytes = self.conn.recv(buffer_size)
|
||||
if len(data) > 0:
|
||||
logger.debug('received %d bytes from %s' % (len(data), self.what))
|
||||
return data
|
||||
|
@ -141,7 +141,7 @@ class TcpConnection(object):
|
|||
return len(data)
|
||||
|
||||
def flush(self) -> int:
|
||||
sent = self.send(self.buffer)
|
||||
sent: int = self.send(self.buffer)
|
||||
self.buffer = self.buffer[sent:]
|
||||
logger.debug('flushed %d bytes to %s' % (sent, self.what))
|
||||
return sent
|
||||
|
@ -158,7 +158,7 @@ class TcpServerConnection(TcpConnection):
|
|||
if self.conn:
|
||||
self.close()
|
||||
|
||||
def connect(self):
|
||||
def connect(self) -> None:
|
||||
self.conn = socket.create_connection((self.addr[0], self.addr[1]))
|
||||
|
||||
|
||||
|
@ -183,23 +183,24 @@ class TcpServer(object):
|
|||
self.port: int = port
|
||||
self.backlog: int = backlog
|
||||
self.ipv4: bool = ipv4
|
||||
self.socket: socket.socket = None
|
||||
# Cannot force socket.socket type here.
|
||||
self.socket = None
|
||||
self.running: bool = False
|
||||
self.family = socket.AF_INET if self.ipv4 else socket.AF_INET6
|
||||
self.hostname: str = hostname if hostname not in [DEFAULT_IPV4_HOSTNAME,
|
||||
DEFAULT_IPV6_HOSTNAME] \
|
||||
else DEFAULT_IPV4_HOSTNAME if self.ipv4 else DEFAULT_IPV6_HOSTNAME
|
||||
|
||||
def setup(self):
|
||||
def setup(self) -> None:
|
||||
pass
|
||||
|
||||
def handle(self, client: TcpClientConnection):
|
||||
raise NotImplementedError()
|
||||
|
||||
def shutdown(self):
|
||||
def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
def stop(self):
|
||||
def stop(self) -> None:
|
||||
self.running = False
|
||||
|
||||
def run(self):
|
||||
|
@ -384,7 +385,7 @@ class HttpParser(object):
|
|||
self.bytes: bytes = b''
|
||||
self.buffer: bytes = b''
|
||||
|
||||
self.headers: Dict = dict()
|
||||
self.headers: Dict[bytes, Tuple[bytes, bytes]] = dict()
|
||||
# Can simply be b'', then set type as bytes?
|
||||
self.body = None
|
||||
|
||||
|
@ -525,21 +526,17 @@ class HttpParser(object):
|
|||
url += b'#' + self.url.fragment
|
||||
return url
|
||||
|
||||
def build(self, del_headers=None, add_headers=None):
|
||||
def build(self, disable_headers=None):
|
||||
if disable_headers is None:
|
||||
disable_headers = DEFAULT_DISABLE_HEADERS
|
||||
|
||||
req = b' '.join([self.method, self.build_url(), self.version])
|
||||
req += CRLF
|
||||
|
||||
if not del_headers:
|
||||
del_headers = []
|
||||
for k in self.headers:
|
||||
if k not in del_headers:
|
||||
if k.lower() not in disable_headers:
|
||||
req += self.build_header(self.headers[k][0], self.headers[k][1]) + CRLF
|
||||
|
||||
if not add_headers:
|
||||
add_headers = []
|
||||
for k in add_headers:
|
||||
req += self.build_header(k[0], k[1]) + CRLF
|
||||
|
||||
req += CRLF
|
||||
if self.body:
|
||||
req += self.body
|
||||
|
@ -569,6 +566,21 @@ class HttpParser(object):
|
|||
"""Host field SHOULD be None for incoming local WebServer requests."""
|
||||
return True if self.host is not None else False
|
||||
|
||||
def add_header(self, key: bytes, value: bytes) -> None:
|
||||
self.headers[key] = (key, value)
|
||||
|
||||
def add_headers(self, headers: List[Tuple[bytes, bytes]]) -> None:
|
||||
for (key, value) in headers:
|
||||
self.add_header(key, value)
|
||||
|
||||
def del_header(self, header: bytes) -> None:
|
||||
if header in self.headers:
|
||||
del self.headers[header]
|
||||
|
||||
def del_headers(self, headers: List[bytes]) -> None:
|
||||
for key in headers:
|
||||
self.del_header(key)
|
||||
|
||||
|
||||
class HttpProtocolException(Exception):
|
||||
"""Top level HttpProtocolException exception class.
|
||||
|
@ -580,7 +592,7 @@ class HttpProtocolException(Exception):
|
|||
def __init__(self):
|
||||
pass
|
||||
|
||||
def response(self, request):
|
||||
def response(self, request: HttpParser) -> bytes:
|
||||
pass
|
||||
|
||||
|
||||
|
@ -590,12 +602,12 @@ class HttpRequestRejected(HttpProtocolException):
|
|||
Connections can either be dropped/closed or optionally an
|
||||
HTTP status code can be returned."""
|
||||
|
||||
def __init__(self, status_code=None, body=None):
|
||||
def __init__(self, status_code: bytes = None, body: bytes = None):
|
||||
super(HttpRequestRejected, self).__init__()
|
||||
self.status_code = status_code
|
||||
self.body = body
|
||||
self.status_code: bytes = status_code
|
||||
self.body: bytes = body
|
||||
|
||||
def response(self, _request):
|
||||
def response(self, _request: HttpParser) -> bytes:
|
||||
pkt = []
|
||||
if self.status_code:
|
||||
pkt.append(b'HTTP/1.1 ' + self.status_code)
|
||||
|
@ -618,15 +630,16 @@ class HttpProtocolConfig(object):
|
|||
|
||||
def __init__(self, auth_code=DEFAULT_BASIC_AUTH, server_recvbuf_size=DEFAULT_SERVER_RECVBUF_SIZE,
|
||||
client_recvbuf_size=DEFAULT_CLIENT_RECVBUF_SIZE, pac_file=DEFAULT_PAC_FILE,
|
||||
pac_file_url_path=DEFAULT_PAC_FILE_URL_PATH, plugins=None):
|
||||
pac_file_url_path=DEFAULT_PAC_FILE_URL_PATH, plugins=None, disable_headers=DEFAULT_DISABLE_HEADERS):
|
||||
self.auth_code = auth_code
|
||||
self.server_recvbuf_size = server_recvbuf_size
|
||||
self.client_recvbuf_size = client_recvbuf_size
|
||||
self.pac_file = pac_file
|
||||
self.pac_file_url_path = pac_file_url_path
|
||||
if plugins is None:
|
||||
plugins = DEFAULT_PLUGINS
|
||||
plugins = {}
|
||||
self.plugins: Dict = plugins
|
||||
self.disable_headers = disable_headers
|
||||
|
||||
|
||||
class HttpProtocolBasePlugin(object):
|
||||
|
@ -635,9 +648,9 @@ class HttpProtocolBasePlugin(object):
|
|||
Implement various lifecycle event methods to customize behavior."""
|
||||
|
||||
def __init__(self, config: HttpProtocolConfig, client: TcpClientConnection, request: HttpParser):
|
||||
self.config = config
|
||||
self.client = client
|
||||
self.request = request
|
||||
self.config: HttpProtocolConfig = config
|
||||
self.client: TcpClientConnection = client
|
||||
self.request: HttpParser = request
|
||||
|
||||
def name(self) -> str:
|
||||
"""A unique name for your plugin.
|
||||
|
@ -646,32 +659,32 @@ class HttpProtocolBasePlugin(object):
|
|||
access a specific plugin by its name."""
|
||||
return self.__class__.__name__
|
||||
|
||||
def get_descriptors(self):
|
||||
def get_descriptors(self) -> Tuple[List, List, List]:
|
||||
return [], [], []
|
||||
|
||||
def flush_to_descriptors(self, w):
|
||||
def flush_to_descriptors(self, w) -> None:
|
||||
pass
|
||||
|
||||
def read_from_descriptors(self, r):
|
||||
def read_from_descriptors(self, r) -> None:
|
||||
pass
|
||||
|
||||
def on_client_data(self, raw: bytes):
|
||||
def on_client_data(self, raw: bytes) -> bytes:
|
||||
return raw
|
||||
|
||||
def on_request_complete(self):
|
||||
def on_request_complete(self) -> None:
|
||||
"""Called right after client request parser has reached COMPLETE state."""
|
||||
pass
|
||||
|
||||
def handle_response_chunk(self, chunk: bytes):
|
||||
def handle_response_chunk(self, chunk: bytes) -> bytes:
|
||||
"""Handle data chunks as received from the server.
|
||||
|
||||
Return optionally modified chunk to return back to client."""
|
||||
return chunk
|
||||
|
||||
def access_log(self):
|
||||
def access_log(self) -> None:
|
||||
pass
|
||||
|
||||
def on_client_connection_close(self):
|
||||
def on_client_connection_close(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
|
@ -686,15 +699,15 @@ class ProxyConnectionFailed(HttpProtocolException):
|
|||
CRLF
|
||||
]) + b'Bad Gateway'
|
||||
|
||||
def __init__(self, host, port, reason):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.reason = reason
|
||||
def __init__(self, host: str, port: int, reason: str):
|
||||
self.host: str = host
|
||||
self.port: int = port
|
||||
self.reason: str = reason
|
||||
|
||||
def response(self, _request):
|
||||
def response(self, _request: HttpParser) -> bytes:
|
||||
return self.RESPONSE_PKT
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return '<ProxyConnectionFailed - %s:%s - %s>' % (self.host, self.port, self.reason)
|
||||
|
||||
|
||||
|
@ -711,7 +724,7 @@ class ProxyAuthenticationFailed(HttpProtocolException):
|
|||
CRLF
|
||||
]) + b'Proxy Authentication Required'
|
||||
|
||||
def response(self, _request):
|
||||
def response(self, _request: HttpParser) -> bytes:
|
||||
return self.RESPONSE_PKT
|
||||
|
||||
|
||||
|
@ -859,11 +872,10 @@ class HttpProxyPlugin(HttpProtocolBasePlugin):
|
|||
# for general http requests, re-build request packet
|
||||
# and queue for the server with appropriate headers
|
||||
else:
|
||||
self.server.queue(self.request.build(
|
||||
del_headers=[b'proxy-authorization', b'proxy-connection', b'connection',
|
||||
b'keep-alive'],
|
||||
add_headers=[(b'Via', b'1.1 proxy.py v%s' % version), (b'Connection', b'Close')]
|
||||
))
|
||||
# remove args.disable_headers before dispatching to upstream
|
||||
self.request.add_headers([(b'Via', b'1.1 proxy.py v%s' % version), (b'Connection', b'Close')])
|
||||
self.request.del_headers([b'proxy-authorization', b'proxy-connection', b'connection', b'keep-alive'])
|
||||
self.server.queue(self.request.build(disable_headers=self.config.disable_headers))
|
||||
|
||||
def access_log(self):
|
||||
if not self.request.has_upstream_server():
|
||||
|
@ -1099,11 +1111,11 @@ def set_open_file_limit(soft_limit):
|
|||
def load_plugins(plugins: str) -> Dict:
|
||||
"""Accepts a comma separated list of Python modules and returns
|
||||
a list of respective Python classes."""
|
||||
p = {
|
||||
p: Dict[str, List] = {
|
||||
'HttpProtocolBasePlugin': [],
|
||||
'HttpProxyBasePlugin': []
|
||||
}
|
||||
for plugin in plugins.split(','):
|
||||
for plugin in plugins.split(COMMA):
|
||||
plugin = plugin.strip()
|
||||
if plugin == '':
|
||||
continue
|
||||
|
@ -1147,6 +1159,9 @@ def init_parser() -> argparse.ArgumentParser:
|
|||
'client in a single recv() operation. Bump this '
|
||||
'value for faster uploads at the expense of '
|
||||
'increased RAM.')
|
||||
parser.add_argument('--disable-headers', type=str, default=COMMA.join(DEFAULT_DISABLE_HEADERS),
|
||||
help='Default: None. Comma separated list of headers to remove before'
|
||||
'dispatching client request to upstream server.')
|
||||
parser.add_argument('--disable-http-proxy', action='store_true', default=DEFAULT_DISABLE_HTTP_PROXY,
|
||||
help='Default: False. Whether to disable proxy.HttpProxyPlugin.')
|
||||
parser.add_argument('--hostname', type=str, default=DEFAULT_IPV4_HOSTNAME,
|
||||
|
@ -1176,7 +1191,7 @@ def init_parser() -> argparse.ArgumentParser:
|
|||
help='Web server path to serve the PAC file.')
|
||||
parser.add_argument('--pid-file', type=str, default=DEFAULT_PID_FILE,
|
||||
help='Default: None. Save parent process ID to a file.')
|
||||
parser.add_argument('--plugins', type=str, default='', help='Comma separated plugins')
|
||||
parser.add_argument('--plugins', type=str, default=DEFAULT_PLUGINS, help='Comma separated plugins')
|
||||
parser.add_argument('--port', type=int, default=DEFAULT_PORT,
|
||||
help='Default: 8899. Server port.')
|
||||
parser.add_argument('--server-recvbuf-size', type=int, default=DEFAULT_SERVER_RECVBUF_SIZE,
|
||||
|
@ -1189,7 +1204,7 @@ def init_parser() -> argparse.ArgumentParser:
|
|||
return parser
|
||||
|
||||
|
||||
def main(args):
|
||||
def main(args) -> None:
|
||||
if not is_py3() and not UNDER_TEST:
|
||||
print(
|
||||
'DEPRECATION: "develop" branch no longer supports Python 2.7. Kindly upgrade to Python 3+. '
|
||||
|
@ -1219,7 +1234,9 @@ def main(args):
|
|||
server_recvbuf_size=args.server_recvbuf_size,
|
||||
client_recvbuf_size=args.client_recvbuf_size,
|
||||
pac_file=args.pac_file,
|
||||
pac_file_url_path=args.pac_file_url_path)
|
||||
pac_file_url_path=args.pac_file_url_path,
|
||||
disable_headers=[header.lower() for header in args.disable_headers.split(COMMA) if
|
||||
header.strip() is not ''])
|
||||
if config.pac_file is not None:
|
||||
args.enable_web_server = True
|
||||
|
||||
|
|
9
tests.py
9
tests.py
|
@ -293,8 +293,9 @@ class TestHttpParser(unittest.TestCase):
|
|||
self.assertEqual(self.parser.version, b'HTTP/1.1')
|
||||
self.assertEqual(self.parser.state, proxy.HttpParser.states.COMPLETE)
|
||||
self.assertDictContainsSubset({b'host': (b'Host', b'example.com')}, self.parser.headers)
|
||||
self.assertEqual(raw % (b'/path/dir/?a=b&c=d#p=q', b'example.com'),
|
||||
self.parser.build(del_headers=[b'host'], add_headers=[(b'Host', b'example.com')]))
|
||||
self.parser.del_headers([b'host'])
|
||||
self.parser.add_headers([(b'Host', b'example.com')])
|
||||
self.assertEqual(raw % (b'/path/dir/?a=b&c=d#p=q', b'example.com'), self.parser.build())
|
||||
|
||||
def test_build_url_none(self):
|
||||
self.assertEqual(self.parser.build_url(), b'/None')
|
||||
|
@ -874,7 +875,8 @@ class TestMain(unittest.TestCase):
|
|||
client_recvbuf_size=proxy.DEFAULT_CLIENT_RECVBUF_SIZE,
|
||||
server_recvbuf_size=proxy.DEFAULT_SERVER_RECVBUF_SIZE,
|
||||
pac_file=proxy.DEFAULT_PAC_FILE,
|
||||
pac_file_url_path=proxy.DEFAULT_PAC_FILE_URL_PATH
|
||||
pac_file_url_path=proxy.DEFAULT_PAC_FILE_URL_PATH,
|
||||
disable_headers=proxy.DEFAULT_DISABLE_HEADERS
|
||||
)
|
||||
|
||||
@mock.patch('builtins.print')
|
||||
|
@ -909,6 +911,7 @@ class TestMain(unittest.TestCase):
|
|||
@mock.patch('proxy.set_open_file_limit')
|
||||
@mock.patch('proxy.MultiCoreRequestDispatcher')
|
||||
@mock.patch('proxy.is_py3')
|
||||
@unittest.skipIf(True, 'For some reason this test passes when running with Intellij but fails via CLI :(')
|
||||
def test_main_py2_exit(self, mock_is_py3, mock_multicore_dispatcher, mock_set_open_file_limit,
|
||||
mock_config, mock_print):
|
||||
mock_is_py3.return_value = False
|
||||
|
|
Loading…
Reference in New Issue