proxy.py/proxy/core/connection.py

130 lines
3.9 KiB
Python

# -*- coding: utf-8 -*-
"""
proxy.py
~~~~~~~~
⚡⚡⚡ Fast, Lightweight, Programmable Proxy Server in a single Python file.
:copyright: (c) 2013-present by Abhinav Singh and contributors.
:license: BSD, see LICENSE for more details.
"""
import socket
import ssl
import logging
from abc import ABC, abstractmethod
from typing import NamedTuple, Optional, Union, Tuple
from ..common.constants import DEFAULT_BUFFER_SIZE
from ..common.utils import new_socket_connection
logger = logging.getLogger(__name__)


class TcpConnectionTypes(NamedTuple):
    """Integer identifiers for the two kinds of TCP connection."""
    SERVER: int
    CLIENT: int


# The single instance used throughout the module: SERVER == 1, CLIENT == 2.
tcpConnectionTypes = TcpConnectionTypes(SERVER=1, CLIENT=2)
class TcpConnectionUninitializedException(Exception):
    """Raised when a connection's underlying socket is requested but none is set.

    For example, a server connection raises this when its socket is accessed
    before ``connect()`` has been called.
    """
class TcpConnection(ABC):
    """TCP server/client connection abstraction.

    The main motivation for this class is to provide buffer management
    when reading from and writing to the socket.

    Subclasses must implement the ``connection`` property and return the
    underlying socket connection object from it.
    """

    def __init__(self, tag: int):
        # Outgoing bytes appended by queue() and drained by flush().
        self.buffer: bytes = b''
        # Becomes True once close() has closed the underlying connection.
        self.closed: bool = False
        # Label ('server' or 'client') used in debug log messages.
        self.tag: str = 'server' if tag == tcpConnectionTypes.SERVER else 'client'

    @property
    @abstractmethod
    def connection(self) -> Union[ssl.SSLSocket, socket.socket]:
        """Must return the socket connection to use in this class."""
        raise TcpConnectionUninitializedException()  # pragma: no cover

    def send(self, data: bytes) -> int:
        """Send ``data`` on the connection and return the number of bytes sent.

        The socket may accept fewer bytes than given. Users must handle
        BrokenPipeError exceptions.
        """
        return self.connection.send(data)

    def recv(self, buffer_size: int = DEFAULT_BUFFER_SIZE) -> Optional[bytes]:
        """Read up to ``buffer_size`` bytes from the connection.

        Returns the bytes read, or None when the socket returned zero bytes
        (the peer closed its side). Users must handle socket.error exceptions.
        """
        data: bytes = self.connection.recv(buffer_size)
        if not data:
            return None
        # Lazy %-style args: formatting only happens if DEBUG is enabled.
        logger.debug('received %d bytes from %s', len(data), self.tag)
        return data

    def close(self) -> bool:
        """Close the underlying connection (at most once) and return ``self.closed``."""
        if not self.closed:
            self.connection.close()
            self.closed = True
        return self.closed

    def buffer_size(self) -> int:
        """Return the number of bytes currently queued for sending."""
        return len(self.buffer)

    def has_buffer(self) -> bool:
        """Return True if any bytes are queued for sending."""
        return self.buffer_size() > 0

    def queue(self, data: bytes) -> int:
        """Append ``data`` to the outgoing buffer and return its length.

        Nothing is sent until flush() is called.
        """
        self.buffer += data
        return len(data)

    def flush(self) -> int:
        """Send as much of the queued buffer as the socket accepts.

        Returns the number of bytes sent (0 if nothing was queued). Any bytes
        the socket did not accept remain in ``self.buffer``.
        Users must handle BrokenPipeError exceptions.
        """
        if self.buffer_size() == 0:
            return 0
        sent: int = self.send(self.buffer)
        # send() may be partial; keep only the unsent tail.
        self.buffer = self.buffer[sent:]
        logger.debug('flushed %d bytes to %s', sent, self.tag)
        return sent
class TcpServerConnection(TcpConnection):
    """Connection to an upstream server.

    The socket is not opened at construction time; call ``connect()`` first.
    """

    def __init__(self, host: str, port: int):
        super().__init__(tcpConnectionTypes.SERVER)
        # No socket until connect() runs.
        self._conn: Optional[Union[ssl.SSLSocket, socket.socket]] = None
        # The port is normalised with int() before it is stored.
        self.addr: Tuple[str, int] = (host, int(port))

    @property
    def connection(self) -> Union[ssl.SSLSocket, socket.socket]:
        """Return the open socket.

        Raises TcpConnectionUninitializedException if connect() has not run.
        """
        conn = self._conn
        if conn is not None:
            return conn
        raise TcpConnectionUninitializedException()

    def connect(self) -> None:
        """Open the upstream socket to ``self.addr``; no-op if already open."""
        if self._conn is None:
            self._conn = new_socket_connection(self.addr)
class TcpClientConnection(TcpConnection):
    """Wrapper around an already-accepted client socket and its address."""

    def __init__(
            self,
            conn: Union[ssl.SSLSocket, socket.socket],
            addr: Tuple[str, int],
    ):
        super().__init__(tcpConnectionTypes.CLIENT)
        # Annotated Optional even though a socket is always passed in here.
        # The connection property still guards against it being None.
        self._conn: Optional[Union[ssl.SSLSocket, socket.socket]] = conn
        self.addr: Tuple[str, int] = addr

    @property
    def connection(self) -> Union[ssl.SSLSocket, socket.socket]:
        """Return the wrapped socket.

        Raises TcpConnectionUninitializedException if it is None.
        """
        conn = self._conn
        if conn is None:
            raise TcpConnectionUninitializedException()
        return conn