130 lines
3.9 KiB
Python
130 lines
3.9 KiB
Python
|
# -*- coding: utf-8 -*-
|
||
|
"""
|
||
|
proxy.py
|
||
|
~~~~~~~~
|
||
|
⚡⚡⚡ Fast, Lightweight, Programmable Proxy Server in a single Python file.
|
||
|
|
||
|
:copyright: (c) 2013-present by Abhinav Singh and contributors.
|
||
|
:license: BSD, see LICENSE for more details.
|
||
|
"""
|
||
|
import socket
|
||
|
import ssl
|
||
|
import logging
|
||
|
from abc import ABC, abstractmethod
|
||
|
from typing import NamedTuple, Optional, Union, Tuple
|
||
|
|
||
|
from ..common.constants import DEFAULT_BUFFER_SIZE
|
||
|
from ..common.utils import new_socket_connection
|
||
|
|
||
|
logger = logging.getLogger(__name__)
|
||
|
|
||
|
|
||
|
TcpConnectionTypes = NamedTuple('TcpConnectionTypes', [
|
||
|
('SERVER', int),
|
||
|
('CLIENT', int),
|
||
|
])
|
||
|
tcpConnectionTypes = TcpConnectionTypes(1, 2)
|
||
|
|
||
|
|
||
|
class TcpConnectionUninitializedException(Exception):
|
||
|
pass
|
||
|
|
||
|
|
||
|
class TcpConnection(ABC):
|
||
|
"""TCP server/client connection abstraction.
|
||
|
|
||
|
Main motivation of this class is to provide a buffer management
|
||
|
when reading and writing into the socket.
|
||
|
|
||
|
Implement the connection property abstract method to return
|
||
|
a socket connection object."""
|
||
|
|
||
|
def __init__(self, tag: int):
|
||
|
self.buffer: bytes = b''
|
||
|
self.closed: bool = False
|
||
|
self.tag: str = 'server' if tag == tcpConnectionTypes.SERVER else 'client'
|
||
|
|
||
|
@property
|
||
|
@abstractmethod
|
||
|
def connection(self) -> Union[ssl.SSLSocket, socket.socket]:
|
||
|
"""Must return the socket connection to use in this class."""
|
||
|
raise TcpConnectionUninitializedException() # pragma: no cover
|
||
|
|
||
|
def send(self, data: bytes) -> int:
|
||
|
"""Users must handle BrokenPipeError exceptions"""
|
||
|
return self.connection.send(data)
|
||
|
|
||
|
def recv(self, buffer_size: int = DEFAULT_BUFFER_SIZE) -> Optional[bytes]:
|
||
|
"""Users must handle socket.error exceptions"""
|
||
|
data: bytes = self.connection.recv(buffer_size)
|
||
|
if len(data) == 0:
|
||
|
return None
|
||
|
logger.debug(
|
||
|
'received %d bytes from %s' %
|
||
|
(len(data), self.tag))
|
||
|
# logger.info(data)
|
||
|
return data
|
||
|
|
||
|
def close(self) -> bool:
|
||
|
if not self.closed:
|
||
|
self.connection.close()
|
||
|
self.closed = True
|
||
|
return self.closed
|
||
|
|
||
|
def buffer_size(self) -> int:
|
||
|
return len(self.buffer)
|
||
|
|
||
|
def has_buffer(self) -> bool:
|
||
|
return self.buffer_size() > 0
|
||
|
|
||
|
def queue(self, data: bytes) -> int:
|
||
|
self.buffer += data
|
||
|
return len(data)
|
||
|
|
||
|
def flush(self) -> int:
|
||
|
"""Users must handle BrokenPipeError exceptions"""
|
||
|
if self.buffer_size() == 0:
|
||
|
return 0
|
||
|
sent: int = self.send(self.buffer)
|
||
|
# logger.info(self.buffer[:sent])
|
||
|
self.buffer = self.buffer[sent:]
|
||
|
logger.debug('flushed %d bytes to %s' % (sent, self.tag))
|
||
|
return sent
|
||
|
|
||
|
|
||
|
class TcpServerConnection(TcpConnection):
|
||
|
"""Establishes connection to upstream server."""
|
||
|
|
||
|
def __init__(self, host: str, port: int):
|
||
|
super().__init__(tcpConnectionTypes.SERVER)
|
||
|
self._conn: Optional[Union[ssl.SSLSocket, socket.socket]] = None
|
||
|
self.addr: Tuple[str, int] = (host, int(port))
|
||
|
|
||
|
@property
|
||
|
def connection(self) -> Union[ssl.SSLSocket, socket.socket]:
|
||
|
if self._conn is None:
|
||
|
raise TcpConnectionUninitializedException()
|
||
|
return self._conn
|
||
|
|
||
|
def connect(self) -> None:
|
||
|
if self._conn is not None:
|
||
|
return
|
||
|
self._conn = new_socket_connection(self.addr)
|
||
|
|
||
|
|
||
|
class TcpClientConnection(TcpConnection):
|
||
|
"""An accepted client connection request."""
|
||
|
|
||
|
def __init__(self,
|
||
|
conn: Union[ssl.SSLSocket, socket.socket],
|
||
|
addr: Tuple[str, int]):
|
||
|
super().__init__(tcpConnectionTypes.CLIENT)
|
||
|
self._conn: Optional[Union[ssl.SSLSocket, socket.socket]] = conn
|
||
|
self.addr: Tuple[str, int] = addr
|
||
|
|
||
|
@property
|
||
|
def connection(self) -> Union[ssl.SSLSocket, socket.socket]:
|
||
|
if self._conn is None:
|
||
|
raise TcpConnectionUninitializedException()
|
||
|
return self._conn
|