# -*- coding: utf-8 -*-
"""
    proxy.py
    ~~~~~~~~
    ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
    Network monitoring, controls & Application development, testing, debugging.

    :copyright: (c) 2013-present by Abhinav Singh and contributors.
    :license: BSD, see LICENSE for more details.
"""
|
|
import re
|
|
import gzip
|
|
import time
|
|
import socket
|
|
import logging
|
|
import mimetypes
|
|
|
|
from typing import List, Tuple, Optional, Dict, Union, Any, Pattern
|
|
|
|
from ...common.constants import DEFAULT_STATIC_SERVER_DIR, PROXY_AGENT_HEADER_VALUE
|
|
from ...common.constants import DEFAULT_ENABLE_STATIC_SERVER, DEFAULT_ENABLE_WEB_SERVER
|
|
from ...common.constants import DEFAULT_MIN_COMPRESSION_LIMIT, DEFAULT_WEB_ACCESS_LOG_FORMAT
|
|
from ...common.utils import bytes_, text_, build_http_response, build_websocket_handshake_response
|
|
from ...common.types import Readables, Writables
|
|
from ...common.flag import flags
|
|
|
|
from ..exception import HttpProtocolException
|
|
from ..websocket import WebsocketFrame, websocketOpcodes
|
|
from ..parser import HttpParser, httpParserStates, httpParserTypes, httpStatusCodes
|
|
from ..plugin import HttpProtocolHandlerPlugin
|
|
|
|
from .plugin import HttpWebServerBasePlugin
|
|
from .protocols import httpProtocolTypes
|
|
|
|
# Module-level logger shared by the web server plugin defined below.
logger = logging.getLogger(__name__)


# Command line options controlling the built-in web server and static
# file server.  Each call registers one option on the shared ``flags``
# parser imported from ``common.flag``.
flags.add_argument(
    '--enable-web-server',
    action='store_true',
    default=DEFAULT_ENABLE_WEB_SERVER,
    help='Default: False.  Whether to enable proxy.HttpWebServerPlugin.',
)

flags.add_argument(
    '--enable-static-server',
    action='store_true',
    default=DEFAULT_ENABLE_STATIC_SERVER,
    help='Default: False.  Enable inbuilt static file server. '
    'Optionally, also use --static-server-dir to serve static content '
    'from custom directory.  By default, static file server serves '
    'out of installed proxy.py python module folder.',
)

flags.add_argument(
    '--static-server-dir',
    type=str,
    default=DEFAULT_STATIC_SERVER_DIR,
    help='Default: "public" folder in directory where proxy.py is placed. '
    'This option is only applicable when static server is also enabled. '
    'See --enable-static-server.',
)

flags.add_argument(
    '--min-compression-length',
    type=int,
    # The web server plugin in this module reads the value as
    # ``flags.min_compression_limit``.  Without an explicit ``dest`` an
    # argparse-style parser would store it as ``min_compression_length``
    # and the option would never reach the static file server.
    # NOTE(review): assumes ``flags.add_argument`` forwards keyword
    # arguments to ``argparse`` -- confirm against ``common.flag``.
    dest='min_compression_limit',
    default=DEFAULT_MIN_COMPRESSION_LIMIT,
    help='Default: ' + str(DEFAULT_MIN_COMPRESSION_LIMIT) + ' bytes.  ' +
    'Sets the minimum length of a response that will be compressed (gzipped).',
)
|
|
|
|
|
|
class HttpWebServerPlugin(HttpProtocolHandlerPlugin):
    """HttpProtocolHandler plugin which handles incoming requests to local web server.

    Requests are dispatched, in order, to:

    1. a registered websocket route (when the client asks for an upgrade),
    2. a registered HTTP/HTTPS route,
    3. the static file server (when ``--enable-static-server`` is set),
    4. a canned ``404`` response.

    Routes are collected from every ``HttpWebServerBasePlugin`` listed in
    ``self.flags.plugins``.
    """

    # Canned responses, pre-built once at class creation time.
    DEFAULT_404_RESPONSE = memoryview(
        build_http_response(
            httpStatusCodes.NOT_FOUND,
            reason=b'NOT FOUND',
            headers={
                b'Server': PROXY_AGENT_HEADER_VALUE,
                b'Content-Length': b'0',
                b'Connection': b'close',
            },
        ),
    )

    DEFAULT_501_RESPONSE = memoryview(
        build_http_response(
            httpStatusCodes.NOT_IMPLEMENTED,
            reason=b'NOT IMPLEMENTED',
            headers={
                b'Server': PROXY_AGENT_HEADER_VALUE,
                b'Content-Length': b'0',
                b'Connection': b'close',
            },
        ),
    )

    def __init__(
            self,
            *args: Any, **kwargs: Any,
    ) -> None:
        """Instantiate registered web server plugins and index their routes."""
        super().__init__(*args, **kwargs)
        # Used to compute ``connection_time_ms`` for the access log.
        self.start_time: float = time.time()
        # Parser for a follow-up (pipelined) request on a keep-alive connection.
        self.pipeline_request: Optional[HttpParser] = None
        # Set to ``httpProtocolTypes.WEBSOCKET`` after a successful upgrade.
        self.switched_protocol: Optional[int] = None
        # protocol -> compiled route pattern -> plugin handling that route.
        self.routes: Dict[int, Dict[Pattern[str], HttpWebServerBasePlugin]] = {
            httpProtocolTypes.HTTP: {},
            httpProtocolTypes.HTTPS: {},
            httpProtocolTypes.WEBSOCKET: {},
        }
        # Plugin selected for the current connection, if any route matched.
        self.route: Optional[HttpWebServerBasePlugin] = None

        self.plugins: Dict[str, HttpWebServerBasePlugin] = {}
        if b'HttpWebServerBasePlugin' in self.flags.plugins:
            for klass in self.flags.plugins[b'HttpWebServerBasePlugin']:
                instance: HttpWebServerBasePlugin = klass(
                    self.uid,
                    self.flags,
                    self.client,
                    self.event_queue,
                )
                self.plugins[instance.name()] = instance
                for (protocol, route) in instance.routes():
                    self.routes[protocol][re.compile(route)] = instance

    def encryption_enabled(self) -> bool:
        """Return True when both a key file and a certificate file are configured."""
        return self.flags.keyfile is not None and \
            self.flags.certfile is not None

    @staticmethod
    def _resolve_static_path(static_dir: str, request_path: str) -> Optional[str]:
        """Map a request path onto a file path inside ``static_dir``.

        Returns the normalized absolute path, or ``None`` when the request
        path would resolve outside of ``static_dir`` (e.g. via ``..``
        segments).  The request path comes straight from the client, so it
        must never be trusted to stay within the served directory.
        """
        # Local import: this module does not import ``os`` at file level.
        import os

        root = os.path.abspath(static_dir)
        candidate = os.path.abspath(static_dir + request_path)
        # ``rstrip`` keeps the prefix correct when the root is the
        # filesystem root itself.
        prefix = root.rstrip(os.sep) + os.sep
        if candidate == root or candidate.startswith(prefix):
            return candidate
        return None

    @staticmethod
    def read_and_build_static_file_response(path: str, min_compression_limit: int) -> memoryview:
        """Build a ``200 OK`` response carrying the content of the file at ``path``.

        The content type is guessed from the file name and falls back to
        ``text/plain``.  Bodies longer than ``min_compression_limit`` bytes
        are gzip compressed and flagged with ``Content-Encoding: gzip``.

        Returns ``DEFAULT_404_RESPONSE`` when the file cannot be read, for
        example because it is missing, is a directory, or is not readable.
        """
        try:
            with open(path, 'rb') as f:
                content = f.read()
        except (OSError, ValueError):
            # OSError covers FileNotFoundError, IsADirectoryError and
            # PermissionError; ValueError is raised by ``open`` for paths
            # with embedded null bytes.  None of them should escape the
            # request handler.
            return HttpWebServerPlugin.DEFAULT_404_RESPONSE
        content_type = mimetypes.guess_type(path)[0]
        if content_type is None:
            content_type = 'text/plain'
        headers = {
            b'Content-Type': bytes_(content_type),
            b'Cache-Control': b'max-age=86400',
            b'Connection': b'close',
        }
        do_compress = len(content) > min_compression_limit
        if do_compress:
            headers.update({
                b'Content-Encoding': b'gzip',
            })
        return memoryview(
            build_http_response(
                httpStatusCodes.OK,
                reason=b'OK',
                headers=headers,
                body=gzip.compress(content) if do_compress else content,
            ),
        )

    def _connection_header_has(self, token: bytes) -> bool:
        """Return True when the ``Connection`` request header lists ``token``.

        The header is treated as a comma separated, case-insensitive list so
        that values such as ``keep-alive, Upgrade`` are recognized.
        ``token`` must be given in lower case.
        """
        if not self.request.has_header(b'connection'):
            return False
        value = self.request.header(b'connection').lower()
        return token in (part.strip() for part in value.split(b','))

    def try_upgrade(self) -> bool:
        """Attempt a websocket upgrade for the current request.

        When the client asks for a websocket upgrade, the handshake response
        is queued and ``switched_protocol`` is set.  When the client asks to
        upgrade to anything else, a ``501`` response is queued.

        Returns True only in the ``501`` case, which callers treat as a
        request to tear the connection down; False otherwise (including when
        no upgrade was requested at all).
        """
        if self._connection_header_has(b'upgrade'):
            if self.request.has_header(b'upgrade') and \
                    self.request.header(b'upgrade').lower() == b'websocket':
                self.client.queue(
                    memoryview(
                        build_websocket_handshake_response(
                            WebsocketFrame.key_to_accept(
                                self.request.header(b'Sec-WebSocket-Key'),
                            ),
                        ),
                    ),
                )
                self.switched_protocol = httpProtocolTypes.WEBSOCKET
            else:
                self.client.queue(self.DEFAULT_501_RESPONSE)
                return True
        return False

    def on_request_complete(self) -> Union[socket.socket, bool]:
        """Route a fully parsed request.

        Requests that carry a host (``self.request.has_host()``) are ignored
        here and False is returned.  Otherwise the request is dispatched as
        described in the class docstring.

        Returns True when the connection should be torn down after the
        queued response, False when it should be kept open.
        """
        if self.request.has_host():
            return False

        path = self.request.path or b'/'

        # If a websocket route exists for the path, try upgrade
        for route in self.routes[httpProtocolTypes.WEBSOCKET]:
            match = route.match(text_(path))
            if match:
                self.route = self.routes[httpProtocolTypes.WEBSOCKET][route]

                # Connection upgrade
                teardown = self.try_upgrade()
                if teardown:
                    return True

                # For upgraded connections, nothing more to do.
                # Compare against None explicitly so the check does not
                # depend on the numeric value of the protocol constant.
                if self.switched_protocol is not None:
                    # Invoke plugin.on_websocket_open
                    self.route.on_websocket_open()
                    return False

                break

        # Routing for Http(s) requests
        protocol = httpProtocolTypes.HTTPS \
            if self.encryption_enabled() else \
            httpProtocolTypes.HTTP
        for route in self.routes[protocol]:
            match = route.match(text_(path))
            if match:
                self.route = self.routes[protocol][route]
                self.route.handle_request(self.request)
                # Honor an explicit ``Connection: close`` from the client.
                return self._connection_header_has(b'close')

        # No-route found, try static serving if enabled
        if self.flags.enable_static_server:
            # Drop the query string; only the path selects the file.
            request_path = text_(path).split('?')[0]
            file_path = self._resolve_static_path(
                self.flags.static_server_dir, request_path,
            )
            if file_path is None:
                # Path escapes the static directory, refuse to serve it.
                self.client.queue(self.DEFAULT_404_RESPONSE)
            else:
                # NOTE(review): the CLI option is spelled
                # --min-compression-length; verify the flags object really
                # exposes it as ``min_compression_limit``.
                self.client.queue(
                    self.read_and_build_static_file_response(
                        file_path,
                        self.flags.min_compression_limit,
                    ),
                )
            return True

        # Catch all unhandled web server requests, return 404
        self.client.queue(self.DEFAULT_404_RESPONSE)
        return True

    def get_descriptors(
            self,
    ) -> Tuple[List[socket.socket], List[socket.socket]]:
        """Collect the readable and writable descriptors of every plugin."""
        r, w = [], []
        for plugin in self.plugins.values():
            r1, w1 = plugin.get_descriptors()
            r.extend(r1)
            w.extend(w1)
        return r, w

    def write_to_descriptors(self, w: Writables) -> bool:
        """Let every plugin write; return True as soon as one asks for teardown."""
        for plugin in self.plugins.values():
            teardown = plugin.write_to_descriptors(w)
            if teardown:
                return True
        return False

    def read_from_descriptors(self, r: Readables) -> bool:
        """Let every plugin read; return True as soon as one asks for teardown."""
        for plugin in self.plugins.values():
            teardown = plugin.read_from_descriptors(r)
            if teardown:
                return True
        return False

    def on_client_data(self, raw: memoryview) -> Optional[memoryview]:
        """Handle raw client data received after the first request.

        On an upgraded websocket connection every frame in ``raw`` is parsed
        and handed to the active route, and None is returned.  Otherwise, on
        a completed HTTP/1.1 keep-alive request with an active route, the
        data is parsed as a pipelined request and ``raw`` is returned.

        Raises:
            HttpProtocolException: when the client sends a websocket close
                frame, or when a pipelined request is not keep-alive.
        """
        if self.switched_protocol == httpProtocolTypes.WEBSOCKET:
            # TODO(abhinavsingh): Remove .tobytes after websocket frame parser
            # is memoryview compliant
            remaining = raw.tobytes()
            frame = WebsocketFrame()
            while remaining != b'':
                # TODO: Teardown if invalid protocol exception
                remaining = frame.parse(remaining)
                if frame.opcode == websocketOpcodes.CONNECTION_CLOSE:
                    logger.warning(
                        'Client sent connection close packet',
                    )
                    raise HttpProtocolException()
                else:
                    assert self.route
                    self.route.on_websocket_message(frame)
                # Reuse the same frame object for the next frame in ``raw``.
                frame.reset()
            return None
        # If 1st valid request was completed and it's a HTTP/1.1 keep-alive
        # And only if we have a route, parse pipeline requests
        if self.request.state == httpParserStates.COMPLETE and \
                self.request.is_http_1_1_keep_alive() and \
                self.route is not None:
            if self.pipeline_request is None:
                self.pipeline_request = HttpParser(
                    httpParserTypes.REQUEST_PARSER,
                )
            # TODO(abhinavsingh): Remove .tobytes after parser is memoryview
            # compliant
            self.pipeline_request.parse(raw.tobytes())
            if self.pipeline_request.state == httpParserStates.COMPLETE:
                self.route.handle_request(self.pipeline_request)
                if not self.pipeline_request.is_http_1_1_keep_alive():
                    logger.error(
                        'Pipelined request is not keep-alive, will teardown request...',
                    )
                    raise HttpProtocolException()
                # Ready for the next pipelined request.
                self.pipeline_request = None
        return raw

    def on_response_chunk(self, chunk: List[memoryview]) -> List[memoryview]:
        """Return response chunks unchanged."""
        return chunk

    def on_client_connection_close(self) -> None:
        """Notify the active route and emit an access log entry.

        Nothing is done for requests that carry a host.  The active route may
        replace the log context via ``on_access_log`` or suppress the default
        log line entirely by returning None from it.
        """
        if self.request.has_host():
            return
        context = {
            'client_addr': self.client.address,
            'request_method': text_(self.request.method),
            'request_path': text_(self.request.path),
            'connection_time_ms': '%.2f' % ((time.time() - self.start_time) * 1000),
        }
        log_handled = False
        if self.route:
            # May be merge on_client_connection_close and on_access_log???
            # probably by simply deprecating on_client_connection_close in future.
            self.route.on_client_connection_close()
            ctx = self.route.on_access_log(context)
            if ctx is None:
                log_handled = True
            else:
                context = ctx
        if not log_handled:
            self.access_log(context)

    # TODO: Allow plugins to customize access_log, similar
    # to how proxy server plugins are able to do it.
    def access_log(self, context: Dict[str, Any]) -> None:
        """Log ``context`` using ``DEFAULT_WEB_ACCESS_LOG_FORMAT``."""
        logger.info(DEFAULT_WEB_ACCESS_LOG_FORMAT.format_map(context))
|