321 lines
12 KiB
Python
321 lines
12 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""
|
|
proxy.py
|
|
~~~~~~~~
|
|
⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
|
|
Network monitoring, controls & Application development, testing, debugging.
|
|
|
|
:copyright: (c) 2013-present by Abhinav Singh and contributors.
|
|
:license: BSD, see LICENSE for more details.
|
|
"""
|
|
import gzip
|
|
import time
|
|
import logging
|
|
import os
|
|
import mimetypes
|
|
import socket
|
|
from abc import ABC, abstractmethod
|
|
from typing import List, Tuple, Optional, NamedTuple, Dict, Union, Any
|
|
|
|
from .exception import HttpProtocolException
|
|
from .websocket import WebsocketFrame, websocketOpcodes
|
|
from .codes import httpStatusCodes
|
|
from .parser import HttpParser, httpParserStates, httpParserTypes
|
|
from .handler import HttpProtocolHandlerPlugin
|
|
|
|
from ..common.utils import bytes_, text_, build_http_response, build_websocket_handshake_response
|
|
from ..common.flags import Flags
|
|
from ..common.constants import PROXY_AGENT_HEADER_VALUE
|
|
from ..common.types import HasFileno
|
|
from ..core.connection import TcpClientConnection
|
|
from ..core.event import EventQueue
|
|
|
|
logger = logging.getLogger(__name__)


# Identifiers for the protocols a web server route can be registered under.
# They are used as the first element of each (protocol, path) route tuple.
class HttpProtocolTypes(NamedTuple):
    HTTP: int
    HTTPS: int
    WEBSOCKET: int


httpProtocolTypes = HttpProtocolTypes(HTTP=1, HTTPS=2, WEBSOCKET=3)
|
|
|
|
|
|
class HttpWebServerBasePlugin(ABC):
    """Abstract base class for plugins that serve web server routes.

    A subclass announces the (protocol, path) pairs it is responsible for
    through ``routes`` and implements the request and websocket callbacks
    declared below.
    """

    def __init__(
            self,
            uid: str,
            flags: Flags,
            client: TcpClientConnection,
            event_queue: EventQueue):
        # Store every constructor argument on the instance for subclasses.
        self.uid, self.flags = uid, flags
        self.client, self.event_queue = client, event_queue

    @abstractmethod
    def routes(self) -> List[Tuple[int, bytes]]:
        """Return the list of (protocol, path) tuples this plugin handles."""
        raise NotImplementedError     # pragma: no cover

    @abstractmethod
    def handle_request(self, request: HttpParser) -> None:
        """Serve a response for ``request``."""
        raise NotImplementedError     # pragma: no cover

    @abstractmethod
    def on_websocket_open(self) -> None:
        """Callback fired once the websocket handshake has finished."""
        raise NotImplementedError     # pragma: no cover

    @abstractmethod
    def on_websocket_message(self, frame: WebsocketFrame) -> None:
        """Callback fired for a received websocket ``frame``."""
        raise NotImplementedError     # pragma: no cover

    @abstractmethod
    def on_websocket_close(self) -> None:
        """Callback fired when the websocket connection has been closed."""
        raise NotImplementedError     # pragma: no cover
|
|
|
|
|
|
class HttpWebServerPacFilePlugin(HttpWebServerBasePlugin):
    """Web server plugin that serves a proxy auto-config (PAC) response.

    The full HTTP response is built once when the plugin is constructed and
    the same bytes are queued for every request routed to this plugin.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Remains None when no PAC file flag is configured.
        self.pac_file_response: Optional[bytes] = None
        self.cache_pac_file_response()

    def routes(self) -> List[Tuple[int, bytes]]:
        """Register the PAC URL path for both HTTP and HTTPS, when set."""
        url_path = self.flags.pac_file_url_path
        if not url_path:
            return []   # pragma: no cover
        return [
            (protocol, bytes_(url_path))
            for protocol in (httpProtocolTypes.HTTP, httpProtocolTypes.HTTPS)
        ]

    def handle_request(self, request: HttpParser) -> None:
        """Queue the cached PAC response to the client, if one was built."""
        if not self.flags.pac_file:
            return
        if self.pac_file_response:
            self.client.queue(self.pac_file_response)

    def on_websocket_open(self) -> None:
        """Nothing to do: this plugin registers no websocket route."""
        pass    # pragma: no cover

    def on_websocket_message(self, frame: WebsocketFrame) -> None:
        """Nothing to do: this plugin registers no websocket route."""
        pass    # pragma: no cover

    def on_websocket_close(self) -> None:
        """Nothing to do: this plugin registers no websocket route."""
        pass    # pragma: no cover

    def cache_pac_file_response(self) -> None:
        """Build the gzip-compressed PAC response and keep it on the instance.

        The ``pac_file`` flag is first tried as a file path. If it cannot be
        opened, the flag value itself is used as the PAC content.
        """
        pac_file = self.flags.pac_file
        if not pac_file:
            return
        try:
            with open(pac_file, 'rb') as handle:
                payload = handle.read()
        except IOError:
            # Not a readable file: serve the flag value verbatim.
            payload = bytes_(pac_file)
        response_headers = {
            b'Content-Type': b'application/x-ns-proxy-autoconfig',
            b'Content-Encoding': b'gzip',
        }
        self.pac_file_response = build_http_response(
            200,
            reason=b'OK',
            headers=response_headers,
            body=gzip.compress(payload),
        )
|
|
|
|
|
|
class HttpWebServerPlugin(HttpProtocolHandlerPlugin):
    """HttpProtocolHandler plugin which handles incoming requests to local web server.

    Requests that have no upstream server are dispatched, by (protocol, path),
    to the ``HttpWebServerBasePlugin`` instances listed in the flags. Such
    requests may instead be upgraded to websocket, or served from the static
    server directory when that is enabled. Everything else gets a 404.
    """

    # Canned response queued when nothing can serve the request.
    DEFAULT_404_RESPONSE = build_http_response(
        httpStatusCodes.NOT_FOUND,
        reason=b'NOT FOUND',
        headers={b'Server': PROXY_AGENT_HEADER_VALUE,
                 b'Connection': b'close'}
    )

    # Canned response queued for an upgrade to anything but websocket.
    DEFAULT_501_RESPONSE = build_http_response(
        httpStatusCodes.NOT_IMPLEMENTED,
        reason=b'NOT IMPLEMENTED',
        headers={b'Server': PROXY_AGENT_HEADER_VALUE,
                 b'Connection': b'close'}
    )

    def __init__(
            self,
            *args: Any, **kwargs: Any) -> None:
        """Instantiate every configured web server plugin and index its routes."""
        super().__init__(*args, **kwargs)
        # Reference point for the duration reported by access_log().
        self.start_time: float = time.time()
        # Parser for a follow-up request on a keep-alive connection.
        self.pipeline_request: Optional[HttpParser] = None
        # Set to httpProtocolTypes.WEBSOCKET after a successful handshake.
        self.switched_protocol: Optional[int] = None
        # protocol -> request path -> plugin instance serving that path.
        self.routes: Dict[int, Dict[bytes, HttpWebServerBasePlugin]] = {
            httpProtocolTypes.HTTP: {},
            httpProtocolTypes.HTTPS: {},
            httpProtocolTypes.WEBSOCKET: {},
        }
        # Plugin selected for the current connection, if any.
        self.route: Optional[HttpWebServerBasePlugin] = None

        if b'HttpWebServerBasePlugin' in self.flags.plugins:
            for klass in self.flags.plugins[b'HttpWebServerBasePlugin']:
                instance = klass(
                    self.uid,
                    self.flags,
                    self.client,
                    self.event_queue)
                # A later plugin overrides an earlier one on the same route.
                for (protocol, path) in instance.routes():
                    self.routes[protocol][path] = instance

    @staticmethod
    def read_and_build_static_file_response(path: str) -> bytes:
        """Read ``path`` and return a gzip-compressed 200 response for it.

        The content type is guessed from the file name and falls back to
        ``text/plain``. Errors from ``open`` propagate to the caller.
        """
        with open(path, 'rb') as f:
            content = f.read()
        content_type = mimetypes.guess_type(path)[0]
        if content_type is None:
            content_type = 'text/plain'
        return build_http_response(
            httpStatusCodes.OK,
            reason=b'OK',
            headers={
                b'Content-Type': bytes_(content_type),
                b'Cache-Control': b'max-age=86400',
                b'Content-Encoding': b'gzip',
                b'Connection': b'close',
            },
            body=gzip.compress(content))

    def serve_file_or_404(self, path: str) -> bool:
        """Queue the file at ``path`` as a response, or a 404 on IOError.

        Always returns True.

        NOTE(review): an IOError on a file that exists is arguably a server
        error rather than a 404.
        """
        try:
            self.client.queue(
                self.read_and_build_static_file_response(path))
        except IOError:
            self.client.queue(self.DEFAULT_404_RESPONSE)
        return True

    def _static_file_path(self) -> Optional[str]:
        """Return the static file to serve for the current request, if any.

        The query string is dropped and the remaining request path is
        appended to the static server directory. The result is returned only
        when it is an existing file whose resolved location lies inside the
        resolved static directory. Request paths that climb out of it
        (``..`` segments, symlinks pointing elsewhere, or a path that merely
        shares the directory name as a prefix) yield None.
        """
        request_path = text_(self.request.path).split('?')[0]
        candidate = self.flags.static_server_dir + request_path
        root = os.path.realpath(self.flags.static_server_dir)
        resolved = os.path.realpath(candidate)
        try:
            inside_root = os.path.commonpath([root, resolved]) == root
        except ValueError:
            # Raised e.g. for paths on different drives: not under root.
            inside_root = False
        if inside_root and os.path.isfile(candidate):
            return candidate
        return None

    def try_upgrade(self) -> bool:
        """Perform the websocket handshake if the client asked to upgrade.

        Returns True only when an upgrade other than websocket was requested.
        A 501 response has been queued in that case. Returns False otherwise,
        including after a successful handshake, which is recorded in
        ``switched_protocol``.
        """
        if not self.request.has_header(b'connection'):
            return False
        # Connection carries a comma separated token list, so a value such as
        # b'keep-alive, Upgrade' must also count as an upgrade request.
        connection_tokens = [
            token.strip()
            for token in self.request.header(b'connection').lower().split(b',')
        ]
        if b'upgrade' not in connection_tokens:
            return False

        if self.request.has_header(b'upgrade') and \
                self.request.header(b'upgrade').lower() == b'websocket':
            self.client.queue(
                build_websocket_handshake_response(
                    WebsocketFrame.key_to_accept(
                        self.request.header(b'Sec-WebSocket-Key'))))
            self.switched_protocol = httpProtocolTypes.WEBSOCKET
            return False

        self.client.queue(self.DEFAULT_501_RESPONSE)
        return True

    def on_request_complete(self) -> Union[socket.socket, bool]:
        """Route a fully parsed request.

        Returns False when the request has an upstream server, was upgraded
        to websocket, or was handed to a route plugin. Returns True after
        queuing a static file, a 404 or a 501 response.
        NOTE(review): True presumably asks the handler to tear the connection
        down. Confirm against HttpProtocolHandlerPlugin.
        """
        if self.request.has_upstream_server():
            return False

        # If a websocket route exists for the path, try upgrade
        websocket_routes = self.routes[httpProtocolTypes.WEBSOCKET]
        if self.request.path in websocket_routes:
            self.route = websocket_routes[self.request.path]

            # Connection upgrade
            teardown = self.try_upgrade()
            if teardown:
                return True

            # For upgraded connections, nothing more to do
            if self.switched_protocol:
                # Invoke plugin.on_websocket_open
                self.route.on_websocket_open()
                return False

        # Routing for Http(s) requests
        protocol = httpProtocolTypes.HTTPS \
            if self.flags.encryption_enabled() else \
            httpProtocolTypes.HTTP
        matched = self.routes[protocol].get(self.request.path)
        if matched is not None:
            self.route = matched
            self.route.handle_request(self.request)
            return False

        # No-route found, try static serving if enabled
        if self.flags.enable_static_server:
            static_path = self._static_file_path()
            if static_path is not None:
                return self.serve_file_or_404(static_path)

        # Catch all unhandled web server requests, return 404
        self.client.queue(self.DEFAULT_404_RESPONSE)
        return True

    def write_to_descriptors(self, w: List[Union[int, HasFileno]]) -> bool:
        """Nothing to write: get_descriptors() registers no descriptors.

        Returns False explicitly so the declared ``bool`` return type holds.
        """
        return False

    def read_from_descriptors(self, r: List[Union[int, HasFileno]]) -> bool:
        """Nothing to read: get_descriptors() registers no descriptors.

        Returns False explicitly so the declared ``bool`` return type holds.
        """
        return False

    def on_client_data(self, raw: bytes) -> Optional[bytes]:
        """Consume websocket frames or pipelined requests sent by the client.

        After a websocket upgrade, ``raw`` is parsed frame by frame and each
        frame is passed to the route registered for the request path. None is
        returned. Otherwise, once the first request is complete, keep-alive
        and routed, ``raw`` is fed into a pipeline parser and each completed
        request is handed to the current route. ``raw`` is returned.

        Raises:
            HttpProtocolException: on a websocket close frame for a routed
                path, or when a pipelined request is not keep-alive.
        """
        if self.switched_protocol == httpProtocolTypes.WEBSOCKET:
            route = self.routes[httpProtocolTypes.WEBSOCKET].get(
                self.request.path)
            remaining = raw
            frame = WebsocketFrame()
            while remaining != b'':
                # TODO: Teardown if invalid protocol exception
                remaining = frame.parse(remaining)
                if route is not None:
                    if frame.opcode == websocketOpcodes.CONNECTION_CLOSE:
                        logger.warning(
                            'Client sent connection close packet')
                        raise HttpProtocolException()
                    route.on_websocket_message(frame)
                # The same frame object is reused for the next frame.
                frame.reset()
            return None
        # If 1st valid request was completed and it's a HTTP/1.1 keep-alive
        # And only if we have a route, parse pipeline requests
        elif self.request.state == httpParserStates.COMPLETE and \
                self.request.is_http_1_1_keep_alive() and \
                self.route is not None:
            if self.pipeline_request is None:
                self.pipeline_request = HttpParser(
                    httpParserTypes.REQUEST_PARSER)
            self.pipeline_request.parse(raw)
            if self.pipeline_request.state == httpParserStates.COMPLETE:
                self.route.handle_request(self.pipeline_request)
                if not self.pipeline_request.is_http_1_1_keep_alive():
                    logger.error(
                        'Pipelined request is not keep-alive, will teardown request...')
                    raise HttpProtocolException()
                # Ready for the next pipelined request.
                self.pipeline_request = None
        return raw

    def on_response_chunk(self, chunk: bytes) -> bytes:
        """Return ``chunk`` unchanged."""
        return chunk

    def on_client_connection_close(self) -> None:
        """Notify the websocket route, if any, and write the access log.

        Does nothing for requests that have an upstream server.
        """
        if self.request.has_upstream_server():
            return
        if self.switched_protocol:
            # Invoke plugin.on_websocket_close
            route = self.routes[httpProtocolTypes.WEBSOCKET].get(
                self.request.path)
            if route is not None:
                route.on_websocket_close()
        self.access_log()

    def access_log(self) -> None:
        """Log client address, request method and path, and elapsed time."""
        logger.info(
            '%s:%s - %s %s - %.2f ms' %
            (self.client.addr[0],
             self.client.addr[1],
             text_(self.request.method),
             text_(self.request.path),
             (time.time() - self.start_time) * 1000))

    def get_descriptors(
            self) -> Tuple[List[socket.socket], List[socket.socket]]:
        """Return two empty lists: this plugin registers no descriptors."""
        return [], []
|