Decouple SSL wrap logic into connection classes (#394)
* Move wrap functionality within respective connection classes. Also decouple websocket client handshake method * Add a TCP echo client example that works with TCP echo server example
This commit is contained in:
parent
c884338f42
commit
682114e9e0
|
@ -0,0 +1,21 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
proxy.py
|
||||
~~~~~~~~
|
||||
⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
|
||||
Network monitoring, controls & Application development, testing, debugging.
|
||||
|
||||
:copyright: (c) 2013-present by Abhinav Singh and contributors.
|
||||
:license: BSD, see LICENSE for more details.
|
||||
"""
|
||||
from proxy.common.utils import socket_connection
|
||||
from proxy.common.constants import DEFAULT_BUFFER_SIZE
|
||||
|
||||
if __name__ == '__main__':
|
||||
with socket_connection(('::', 12345)) as client:
|
||||
while True:
|
||||
client.send(b'hello')
|
||||
data = client.recv(DEFAULT_BUFFER_SIZE)
|
||||
if data is None:
|
||||
break
|
||||
print(data)
|
|
@ -12,7 +12,7 @@ import time
|
|||
import socket
|
||||
import selectors
|
||||
|
||||
from typing import Dict
|
||||
from typing import Dict, Any
|
||||
|
||||
from proxy.core.acceptor import AcceptorPool, Work
|
||||
from proxy.common.flags import Flags
|
||||
|
@ -29,6 +29,10 @@ class EchoServerHandler(Work):
|
|||
intialize, is_inactive and shutdown method.
|
||||
"""
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
print('Connection accepted from {0}'.format(self.client.addr))
|
||||
|
||||
def get_events(self) -> Dict[socket.socket, int]:
|
||||
# We always want to read from client
|
||||
# Register for EVENT_READ events
|
||||
|
@ -45,12 +49,21 @@ class EchoServerHandler(Work):
|
|||
writables: Writables) -> bool:
|
||||
"""Return True to shutdown work."""
|
||||
if self.client.connection in readables:
|
||||
data = self.client.recv()
|
||||
if data is None:
|
||||
# Client closed connection, signal shutdown
|
||||
try:
|
||||
data = self.client.recv()
|
||||
if data is None:
|
||||
# Client closed connection, signal shutdown
|
||||
print(
|
||||
'Connection closed by client {0}'.format(
|
||||
self.client.addr))
|
||||
return True
|
||||
# Echo data back to client
|
||||
self.client.queue(data)
|
||||
except ConnectionResetError:
|
||||
print(
|
||||
'Connection reset by client {0}'.format(
|
||||
self.client.addr))
|
||||
return True
|
||||
# Queue data back to client
|
||||
self.client.queue(data)
|
||||
|
||||
if self.client.connection in writables:
|
||||
self.client.flush()
|
||||
|
@ -61,7 +74,7 @@ class EchoServerHandler(Work):
|
|||
def main() -> None:
|
||||
# This example requires `threadless=True`
|
||||
pool = AcceptorPool(
|
||||
flags=Flags(num_workers=1, threadless=True),
|
||||
flags=Flags(port=12345, num_workers=1, threadless=True),
|
||||
work_klass=EchoServerHandler)
|
||||
try:
|
||||
pool.setup()
|
||||
|
|
|
@ -22,8 +22,10 @@ num_echos = 10
|
|||
def on_message(frame: WebsocketFrame) -> None:
|
||||
"""WebsocketClient on_message callback."""
|
||||
global client, num_echos, last_dispatch_time
|
||||
print('Received %r after %d millisec' % (frame.data, (time.time() - last_dispatch_time) * 1000))
|
||||
assert(frame.data == b'hello' and frame.opcode == websocketOpcodes.TEXT_FRAME)
|
||||
print('Received %r after %d millisec' %
|
||||
(frame.data, (time.time() - last_dispatch_time) * 1000))
|
||||
assert(frame.data == b'hello' and frame.opcode ==
|
||||
websocketOpcodes.TEXT_FRAME)
|
||||
if num_echos > 0:
|
||||
client.queue(static_frame)
|
||||
last_dispatch_time = time.time()
|
||||
|
@ -34,7 +36,11 @@ def on_message(frame: WebsocketFrame) -> None:
|
|||
|
||||
if __name__ == '__main__':
|
||||
# Constructor establishes socket connection
|
||||
client = WebsocketClient(b'echo.websocket.org', 80, b'/', on_message=on_message)
|
||||
client = WebsocketClient(
|
||||
b'echo.websocket.org',
|
||||
80,
|
||||
b'/',
|
||||
on_message=on_message)
|
||||
# Perform handshake
|
||||
client.handshake()
|
||||
# Queue some data for client
|
||||
|
|
|
@ -30,3 +30,15 @@ class TcpClientConnection(TcpConnection):
|
|||
if self._conn is None:
|
||||
raise TcpConnectionUninitializedException()
|
||||
return self._conn
|
||||
|
||||
def wrap(self, keyfile: str, certfile: str) -> None:
|
||||
self.connection.setblocking(True)
|
||||
self.flush()
|
||||
self._conn = ssl.wrap_socket(
|
||||
self.connection,
|
||||
server_side=True,
|
||||
# ca_certs=self.flags.ca_cert_file,
|
||||
certfile=certfile,
|
||||
keyfile=keyfile,
|
||||
ssl_version=ssl.PROTOCOL_TLS)
|
||||
self.connection.setblocking(False)
|
||||
|
|
|
@ -8,8 +8,8 @@
|
|||
:copyright: (c) 2013-present by Abhinav Singh and contributors.
|
||||
:license: BSD, see LICENSE for more details.
|
||||
"""
|
||||
import socket
|
||||
import ssl
|
||||
import socket
|
||||
from typing import Optional, Union, Tuple
|
||||
|
||||
from .connection import TcpConnection, tcpConnectionTypes, TcpConnectionUninitializedException
|
||||
|
@ -34,3 +34,14 @@ class TcpServerConnection(TcpConnection):
|
|||
if self._conn is not None:
|
||||
return
|
||||
self._conn = new_socket_connection(self.addr)
|
||||
|
||||
def wrap(self, hostname: str, ca_file: Optional[str]) -> None:
|
||||
ctx = ssl.create_default_context(
|
||||
ssl.Purpose.SERVER_AUTH, cafile=ca_file)
|
||||
ctx.options |= ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | ssl.OP_NO_TLSv1
|
||||
ctx.check_hostname = True
|
||||
self.connection.setblocking(True)
|
||||
self._conn = ctx.wrap_socket(
|
||||
self.connection,
|
||||
server_hostname=hostname)
|
||||
self.connection.setblocking(False)
|
||||
|
|
|
@ -453,31 +453,15 @@ class HttpProxyPlugin(HttpProtocolHandlerPlugin):
|
|||
def wrap_server(self) -> None:
|
||||
assert self.server is not None
|
||||
assert isinstance(self.server.connection, socket.socket)
|
||||
ctx = ssl.create_default_context(
|
||||
ssl.Purpose.SERVER_AUTH, cafile=self.flags.ca_file)
|
||||
ctx.options |= ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | ssl.OP_NO_TLSv1
|
||||
ctx.check_hostname = True
|
||||
self.server.connection.setblocking(True)
|
||||
self.server._conn = ctx.wrap_socket(
|
||||
self.server.connection,
|
||||
server_hostname=text_(self.request.host))
|
||||
self.server.connection.setblocking(False)
|
||||
self.server.wrap(text_(self.request.host), self.flags.ca_file)
|
||||
assert isinstance(self.server.connection, ssl.SSLSocket)
|
||||
|
||||
def wrap_client(self) -> None:
|
||||
assert self.server is not None
|
||||
assert self.server is not None and self.flags.ca_signing_key_file is not None
|
||||
assert isinstance(self.server.connection, ssl.SSLSocket)
|
||||
generated_cert = self.generate_upstream_certificate(
|
||||
cast(Dict[str, Any], self.server.connection.getpeercert()))
|
||||
self.client.connection.setblocking(True)
|
||||
self.client.flush()
|
||||
self.client._conn = ssl.wrap_socket(
|
||||
self.client.connection,
|
||||
server_side=True,
|
||||
# ca_certs=self.flags.ca_cert_file,
|
||||
certfile=generated_cert,
|
||||
keyfile=self.flags.ca_signing_key_file,
|
||||
ssl_version=ssl.PROTOCOL_TLS)
|
||||
self.client.connection.setblocking(False)
|
||||
self.client.wrap(self.flags.ca_signing_key_file, generated_cert)
|
||||
logger.debug(
|
||||
'TLS interception using %s', generated_cert)
|
||||
|
||||
|
|
|
@ -52,7 +52,11 @@ class WebsocketClient(TcpConnection):
|
|||
|
||||
def upgrade(self) -> None:
|
||||
key = base64.b64encode(secrets.token_bytes(16))
|
||||
self.sock.send(build_websocket_handshake_request(key, url=self.path, host=self.hostname))
|
||||
self.sock.send(
|
||||
build_websocket_handshake_request(
|
||||
key,
|
||||
url=self.path,
|
||||
host=self.hostname))
|
||||
response = HttpParser(httpParserTypes.RESPONSE_PARSER)
|
||||
response.parse(self.sock.recv(DEFAULT_BUFFER_SIZE))
|
||||
accept = response.header(b'Sec-Websocket-Accept')
|
||||
|
|
|
@ -17,7 +17,7 @@ import selectors
|
|||
from typing import Any
|
||||
from unittest import mock
|
||||
|
||||
from proxy.core.connection import TcpClientConnection
|
||||
from proxy.core.connection import TcpClientConnection, TcpServerConnection
|
||||
from proxy.http.handler import HttpProtocolHandler
|
||||
from proxy.http.proxy import HttpProxyPlugin
|
||||
from proxy.http.methods import httpMethods
|
||||
|
@ -71,6 +71,11 @@ class TestHttpProxyTlsInterception(unittest.TestCase):
|
|||
return ssl_connection
|
||||
return plain_connection
|
||||
|
||||
# Do not mock the original wrap method
|
||||
self.mock_server_conn.return_value.wrap.side_effect = \
|
||||
lambda x, y: TcpServerConnection.wrap(
|
||||
self.mock_server_conn.return_value, x, y)
|
||||
|
||||
type(self.mock_server_conn.return_value).connection = \
|
||||
mock.PropertyMock(side_effect=mock_connection)
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ from typing import Any, cast
|
|||
from proxy.common.utils import bytes_
|
||||
from proxy.common.flags import Flags
|
||||
from proxy.common.utils import build_http_request, build_http_response
|
||||
from proxy.core.connection import TcpClientConnection
|
||||
from proxy.core.connection import TcpClientConnection, TcpServerConnection
|
||||
from proxy.http.codes import httpStatusCodes
|
||||
from proxy.http.methods import httpMethods
|
||||
from proxy.http.handler import HttpProtocolHandler
|
||||
|
@ -98,6 +98,10 @@ class TestHttpProxyPluginExamplesWithTlsInterception(unittest.TestCase):
|
|||
return self.server_ssl_connection
|
||||
return self._conn
|
||||
|
||||
# Do not mock the original wrap method
|
||||
self.server.wrap.side_effect = \
|
||||
lambda x, y: TcpServerConnection.wrap(self.server, x, y)
|
||||
|
||||
self.server.has_buffer.side_effect = has_buffer
|
||||
type(self.server).closed = mock.PropertyMock(side_effect=closed)
|
||||
type(
|
||||
|
|
Loading…
Reference in New Issue