proxy.py/proxy/http/websocket.py

267 lines
8.4 KiB
Python

# -*- coding: utf-8 -*-
"""
proxy.py
~~~~~~~~
⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
Network monitoring, controls & Application development, testing, debugging.
:copyright: (c) 2013-present by Abhinav Singh and contributors.
:license: BSD, see LICENSE for more details.
"""
import hashlib
import base64
import selectors
import struct
import socket
import secrets
import ssl
import ipaddress
import logging
import io
from typing import TypeVar, Type, Optional, NamedTuple, Union, Callable
from .parser import httpParserTypes, HttpParser
from ..common.constants import DEFAULT_BUFFER_SIZE
from ..common.utils import new_socket_connection, build_websocket_handshake_request
from ..core.connection import tcpConnectionTypes, TcpConnection
class WebsocketOpcodes(NamedTuple):
    """Named container for the websocket frame opcodes (RFC 6455, section 5.2)."""
    CONTINUATION_FRAME: int
    TEXT_FRAME: int
    BINARY_FRAME: int
    CONNECTION_CLOSE: int
    PING: int
    PONG: int


# Module-wide opcode values; e.g. WebsocketFrame.text() reads TEXT_FRAME.
websocketOpcodes = WebsocketOpcodes(
    CONTINUATION_FRAME=0x0,
    TEXT_FRAME=0x1,
    BINARY_FRAME=0x2,
    CONNECTION_CLOSE=0x8,
    PING=0x9,
    PONG=0xA,
)

# Type variable for classmethods that instantiate WebsocketFrame (or subclasses).
V = TypeVar('V', bound='WebsocketFrame')

logger = logging.getLogger(__name__)
class WebsocketFrame:
    """Websocket frames parser and constructor.

    Implements the RFC 6455 base framing: a FIN/RSV/opcode byte, a
    MASK/payload-length byte with an optional 16- or 64-bit extended length,
    an optional 4-byte masking key, and the payload itself.
    """

    # Fixed GUID appended to the client key when computing the accept value.
    GUID = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11'

    def __init__(self) -> None:
        self.fin: bool = False
        self.rsv1: bool = False
        self.rsv2: bool = False
        self.rsv3: bool = False
        self.opcode: int = 0
        self.masked: bool = False
        # None until set explicitly, parsed, or derived from ``data`` by build().
        self.payload_length: Optional[int] = None
        self.mask: Optional[bytes] = None
        self.data: Optional[bytes] = None

    @classmethod
    def text(cls: 'Type[V]', data: bytes) -> bytes:
        """Return the wire bytes of a single final, unmasked TEXT frame carrying *data*."""
        frame = cls()
        frame.fin = True
        frame.opcode = websocketOpcodes.TEXT_FRAME
        frame.data = data
        return frame.build()

    def reset(self) -> None:
        """Restore every field to its freshly-constructed value."""
        self.fin = False
        self.rsv1 = False
        self.rsv2 = False
        self.rsv3 = False
        self.opcode = 0
        self.masked = False
        self.payload_length = None
        self.mask = None
        self.data = None

    def parse_fin_and_rsv(self, byte: int) -> None:
        """Decode the first header byte: FIN, RSV1-3 and the 4-bit opcode."""
        self.fin = bool(byte & 1 << 7)
        self.rsv1 = bool(byte & 1 << 6)
        self.rsv2 = bool(byte & 1 << 5)
        self.rsv3 = bool(byte & 1 << 4)
        self.opcode = byte & 0b00001111

    def parse_mask_and_payload(self, byte: int) -> None:
        """Decode the second header byte: MASK bit and 7-bit payload length.

        A length of 126 or 127 is a marker for an extended length field,
        which parse() reads afterwards.
        """
        self.masked = bool(byte & 0b10000000)
        self.payload_length = byte & 0b01111111

    def build(self) -> bytes:
        """Serialise this frame to its wire representation.

        If ``payload_length`` is unset it is derived from ``data`` (0 when
        there is no data) and stored on the frame. When ``masked`` is set,
        the 4-byte masking key (``mask``, or a fresh random key if ``mask``
        is None) is always written, even for an empty payload, and the
        payload is XOR-ed with it.

        Raises:
            ValueError: if ``payload_length`` does not fit in the 63 bits
                allowed for the extended length field.
        """
        if self.payload_length is None:
            self.payload_length = len(self.data) if self.data else 0
        mask_bit = 1 << 7 if self.masked else 0
        raw = io.BytesIO()
        raw.write(
            struct.pack(
                '!B',
                (1 << 7 if self.fin else 0) |
                (1 << 6 if self.rsv1 else 0) |
                (1 << 5 if self.rsv2 else 0) |
                (1 << 4 if self.rsv3 else 0) |
                self.opcode,
            ),
        )
        if self.payload_length < 126:
            raw.write(struct.pack('!B', mask_bit | self.payload_length))
        elif self.payload_length < 1 << 16:
            # Marker 126 followed by an unsigned 16-bit length.
            raw.write(struct.pack('!BH', mask_bit | 126, self.payload_length))
        elif self.payload_length < 1 << 63:
            # Marker 127 followed by an unsigned 64-bit length whose most
            # significant bit must be 0 (RFC 6455, section 5.2).
            raw.write(struct.pack('!BQ', mask_bit | 127, self.payload_length))
        else:
            raise ValueError(
                f'Invalid payload_length {self.payload_length}, '
                f'maximum allowed {(1 << 63) - 1}',
            )
        if self.masked:
            mask = secrets.token_bytes(4) if self.mask is None else self.mask
            # The key is part of the header whenever the MASK bit is set.
            raw.write(mask)
            if self.data:
                raw.write(self.apply_mask(self.data, mask))
        elif self.data:
            raw.write(self.data)
        return raw.getvalue()

    def parse(self, raw: bytes) -> bytes:
        """Parse one frame from the start of *raw* into this object.

        Returns the bytes of *raw* that follow the parsed frame. A payload
        shorter than the declared length is not detected; it simply yields
        a shorter ``data``.
        """
        cur = 0
        self.parse_fin_and_rsv(raw[cur])
        cur += 1
        self.parse_mask_and_payload(raw[cur])
        cur += 1
        if self.payload_length == 126:
            self.payload_length, = struct.unpack('!H', raw[cur:cur + 2])
            cur += 2
        elif self.payload_length == 127:
            self.payload_length, = struct.unpack('!Q', raw[cur:cur + 8])
            cur += 8
        if self.masked:
            self.mask = raw[cur:cur + 4]
            cur += 4
        # Zero is a valid length (e.g. an empty PING or CLOSE), so only
        # guard against None here.
        assert self.payload_length is not None
        self.data = raw[cur:cur + self.payload_length]
        cur += self.payload_length
        if self.masked:
            assert self.mask is not None
            self.data = self.apply_mask(self.data, self.mask)
        return raw[cur:]

    @staticmethod
    def apply_mask(data: bytes, mask: bytes) -> bytes:
        """XOR *data* with the repeating 4-byte *mask*.

        Applying the same mask twice restores the original bytes.
        """
        return bytes(byte ^ mask[i % 4] for i, byte in enumerate(data))

    @staticmethod
    def key_to_accept(key: bytes) -> bytes:
        """Return base64(SHA-1(key + GUID)), the expected Sec-WebSocket-Accept value."""
        sha1 = hashlib.sha1()
        sha1.update(key + WebsocketFrame.GUID)
        return base64.b64encode(sha1.digest())
class WebsocketClient(TcpConnection):
    """Client-side websocket connection built on ``TcpConnection``.

    The constructor connects to ``hostname:port`` and performs the upgrade
    handshake for ``path``. ``run()`` then services the socket until the
    server closes it, passing each received frame to ``on_message``.
    """

    def __init__(self,
                 hostname: Union[ipaddress.IPv4Address, ipaddress.IPv6Address],
                 port: int,
                 path: bytes = b'/',
                 on_message: Optional[Callable[[WebsocketFrame], None]] = None) -> None:
        super().__init__(tcpConnectionTypes.CLIENT)
        self.hostname: Union[ipaddress.IPv4Address,
                             ipaddress.IPv6Address] = hostname
        self.port: int = port
        self.path: bytes = path
        self.sock: socket.socket = new_socket_connection(
            (str(self.hostname), self.port))
        self.on_message: Optional[Callable[[
            WebsocketFrame], None]] = on_message
        try:
            # The handshake uses blocking send/recv, so it must complete
            # before the socket is switched to non-blocking mode.
            self.upgrade()
        except BaseException:
            # Do not leak the freshly opened socket if the handshake fails.
            self.sock.close()
            raise
        self.sock.setblocking(False)
        self.selector: selectors.DefaultSelector = selectors.DefaultSelector()

    @property
    def connection(self) -> Union[ssl.SSLSocket, socket.socket]:
        """Return the underlying client socket."""
        return self.sock

    def upgrade(self) -> None:
        """Perform the client side of the websocket opening handshake.

        Sends an upgrade request with a random base64 key. It then checks the
        server's ``Sec-Websocket-Accept`` header against the value derived
        from that key.

        Raises:
            AssertionError: if the accept value does not match.
        """
        key = base64.b64encode(secrets.token_bytes(16))
        # sendall: a single send() may transmit only part of the request.
        self.sock.sendall(build_websocket_handshake_request(key, url=self.path))
        response = HttpParser(httpParserTypes.RESPONSE_PARSER)
        # NOTE(review): assumes the whole response head arrives in one recv().
        response.parse(self.sock.recv(DEFAULT_BUFFER_SIZE))
        accept = response.header(b'Sec-Websocket-Accept')
        if WebsocketFrame.key_to_accept(key) != accept:
            # Raised explicitly rather than via ``assert`` so the check is
            # not stripped under ``python -O``; the exception type is kept
            # for backward compatibility.
            raise AssertionError(
                'Websocket handshake failed: unexpected Sec-Websocket-Accept')

    def ping(self, data: Optional[bytes] = None) -> None:
        """Currently a no-op; *data* is ignored."""

    def pong(self, data: Optional[bytes] = None) -> None:
        """Currently a no-op; *data* is ignored."""

    def shutdown(self, _data: Optional[bytes] = None) -> None:
        """Closes connection with the server."""
        super().close()

    def run_once(self) -> bool:
        """Wait up to one second for socket readiness and service it.

        Readable data is received; an empty/None read marks the connection
        closed. Otherwise the first frame in the data is parsed and passed to
        ``on_message`` if one is set. When the socket is writable, buffered
        data is flushed.

        Returns:
            True if the server closed the connection, False otherwise.
        """
        interest = selectors.EVENT_READ
        if self.has_buffer():
            interest |= selectors.EVENT_WRITE
        self.selector.register(self.sock.fileno(), interest)
        try:
            ready = self.selector.select(timeout=1)
        finally:
            # Unregister even if select() raised, so the next call can
            # register the socket again.
            self.selector.unregister(self.sock)
        for _key, mask in ready:
            if mask & selectors.EVENT_READ:
                # Always consume readable data, even without a handler.
                # Otherwise select() keeps reporting it (a busy loop) and a
                # server close is never noticed.
                raw = self.recv()
                if raw is None or raw == b'':
                    self.closed = True
                    logger.debug('Websocket connection closed by server')
                    return True
                if self.on_message:
                    # NOTE(review): only the first frame in ``raw`` is parsed
                    # and dispatched. Bytes after it are discarded, and a
                    # frame split across reads is not reassembled.
                    frame = WebsocketFrame()
                    frame.parse(raw)
                    self.on_message(frame)
            if mask & selectors.EVENT_WRITE:
                logger.debug(self.buffer)
                self.flush()
        return False

    def run(self) -> None:
        """Service the connection until it closes or KeyboardInterrupt, then clean up."""
        logger.debug('running')
        try:
            while not self.closed:
                if self.run_once():
                    break
        except KeyboardInterrupt:
            pass
        finally:
            # run_once() unregisters after every select(), so the socket is
            # normally not registered here. Tolerate that (KeyError) and an
            # already-closed socket (ValueError) instead of letting it skip
            # the shutdown below.
            try:
                self.selector.unregister(self.sock)
            except (KeyError, ValueError):
                pass
            try:
                self.sock.shutdown(socket.SHUT_WR)
            except Exception as e:
                logger.exception(
                    'Exception while shutdown of websocket client', exc_info=e)
            self.selector.close()
            self.sock.close()
        logger.info('done')