Decouple SSL wrap logic into connection classes (#394)

* Move wrap functionality within respective connection classes. Also decouple websocket client handshake method

* Add a TCP echo client example that works with TCP echo server example
This commit is contained in:
Abhinav Singh 2020-07-08 13:11:12 +05:30 committed by GitHub
parent c884338f42
commit 682114e9e0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 94 additions and 34 deletions

View File

@ -0,0 +1,21 @@
# -*- coding: utf-8 -*-
"""
proxy.py
~~~~~~~~
Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
Network monitoring, controls & Application development, testing, debugging.
:copyright: (c) 2013-present by Abhinav Singh and contributors.
:license: BSD, see LICENSE for more details.
"""
from proxy.common.utils import socket_connection
from proxy.common.constants import DEFAULT_BUFFER_SIZE
if __name__ == '__main__':
with socket_connection(('::', 12345)) as client:
while True:
client.send(b'hello')
data = client.recv(DEFAULT_BUFFER_SIZE)
if data is None:
break
print(data)

View File

@ -12,7 +12,7 @@ import time
import socket
import selectors
from typing import Dict
from typing import Dict, Any
from proxy.core.acceptor import AcceptorPool, Work
from proxy.common.flags import Flags
@ -29,6 +29,10 @@ class EchoServerHandler(Work):
intialize, is_inactive and shutdown method.
"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
print('Connection accepted from {0}'.format(self.client.addr))
def get_events(self) -> Dict[socket.socket, int]:
# We always want to read from client
# Register for EVENT_READ events
@ -45,12 +49,21 @@ class EchoServerHandler(Work):
writables: Writables) -> bool:
"""Return True to shutdown work."""
if self.client.connection in readables:
data = self.client.recv()
if data is None:
# Client closed connection, signal shutdown
try:
data = self.client.recv()
if data is None:
# Client closed connection, signal shutdown
print(
'Connection closed by client {0}'.format(
self.client.addr))
return True
# Echo data back to client
self.client.queue(data)
except ConnectionResetError:
print(
'Connection reset by client {0}'.format(
self.client.addr))
return True
# Queue data back to client
self.client.queue(data)
if self.client.connection in writables:
self.client.flush()
@ -61,7 +74,7 @@ class EchoServerHandler(Work):
def main() -> None:
# This example requires `threadless=True`
pool = AcceptorPool(
flags=Flags(num_workers=1, threadless=True),
flags=Flags(port=12345, num_workers=1, threadless=True),
work_klass=EchoServerHandler)
try:
pool.setup()

View File

@ -22,8 +22,10 @@ num_echos = 10
def on_message(frame: WebsocketFrame) -> None:
"""WebsocketClient on_message callback."""
global client, num_echos, last_dispatch_time
print('Received %r after %d millisec' % (frame.data, (time.time() - last_dispatch_time) * 1000))
assert(frame.data == b'hello' and frame.opcode == websocketOpcodes.TEXT_FRAME)
print('Received %r after %d millisec' %
(frame.data, (time.time() - last_dispatch_time) * 1000))
assert(frame.data == b'hello' and frame.opcode ==
websocketOpcodes.TEXT_FRAME)
if num_echos > 0:
client.queue(static_frame)
last_dispatch_time = time.time()
@ -34,7 +36,11 @@ def on_message(frame: WebsocketFrame) -> None:
if __name__ == '__main__':
# Constructor establishes socket connection
client = WebsocketClient(b'echo.websocket.org', 80, b'/', on_message=on_message)
client = WebsocketClient(
b'echo.websocket.org',
80,
b'/',
on_message=on_message)
# Perform handshake
client.handshake()
# Queue some data for client

View File

@ -30,3 +30,15 @@ class TcpClientConnection(TcpConnection):
if self._conn is None:
raise TcpConnectionUninitializedException()
return self._conn
def wrap(self, keyfile: str, certfile: str) -> None:
self.connection.setblocking(True)
self.flush()
self._conn = ssl.wrap_socket(
self.connection,
server_side=True,
# ca_certs=self.flags.ca_cert_file,
certfile=certfile,
keyfile=keyfile,
ssl_version=ssl.PROTOCOL_TLS)
self.connection.setblocking(False)

View File

@ -8,8 +8,8 @@
:copyright: (c) 2013-present by Abhinav Singh and contributors.
:license: BSD, see LICENSE for more details.
"""
import socket
import ssl
import socket
from typing import Optional, Union, Tuple
from .connection import TcpConnection, tcpConnectionTypes, TcpConnectionUninitializedException
@ -34,3 +34,14 @@ class TcpServerConnection(TcpConnection):
if self._conn is not None:
return
self._conn = new_socket_connection(self.addr)
def wrap(self, hostname: str, ca_file: Optional[str]) -> None:
ctx = ssl.create_default_context(
ssl.Purpose.SERVER_AUTH, cafile=ca_file)
ctx.options |= ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | ssl.OP_NO_TLSv1
ctx.check_hostname = True
self.connection.setblocking(True)
self._conn = ctx.wrap_socket(
self.connection,
server_hostname=hostname)
self.connection.setblocking(False)

View File

@ -453,31 +453,15 @@ class HttpProxyPlugin(HttpProtocolHandlerPlugin):
def wrap_server(self) -> None:
assert self.server is not None
assert isinstance(self.server.connection, socket.socket)
ctx = ssl.create_default_context(
ssl.Purpose.SERVER_AUTH, cafile=self.flags.ca_file)
ctx.options |= ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | ssl.OP_NO_TLSv1
ctx.check_hostname = True
self.server.connection.setblocking(True)
self.server._conn = ctx.wrap_socket(
self.server.connection,
server_hostname=text_(self.request.host))
self.server.connection.setblocking(False)
self.server.wrap(text_(self.request.host), self.flags.ca_file)
assert isinstance(self.server.connection, ssl.SSLSocket)
def wrap_client(self) -> None:
assert self.server is not None
assert self.server is not None and self.flags.ca_signing_key_file is not None
assert isinstance(self.server.connection, ssl.SSLSocket)
generated_cert = self.generate_upstream_certificate(
cast(Dict[str, Any], self.server.connection.getpeercert()))
self.client.connection.setblocking(True)
self.client.flush()
self.client._conn = ssl.wrap_socket(
self.client.connection,
server_side=True,
# ca_certs=self.flags.ca_cert_file,
certfile=generated_cert,
keyfile=self.flags.ca_signing_key_file,
ssl_version=ssl.PROTOCOL_TLS)
self.client.connection.setblocking(False)
self.client.wrap(self.flags.ca_signing_key_file, generated_cert)
logger.debug(
'TLS interception using %s', generated_cert)

View File

@ -52,7 +52,11 @@ class WebsocketClient(TcpConnection):
def upgrade(self) -> None:
key = base64.b64encode(secrets.token_bytes(16))
self.sock.send(build_websocket_handshake_request(key, url=self.path, host=self.hostname))
self.sock.send(
build_websocket_handshake_request(
key,
url=self.path,
host=self.hostname))
response = HttpParser(httpParserTypes.RESPONSE_PARSER)
response.parse(self.sock.recv(DEFAULT_BUFFER_SIZE))
accept = response.header(b'Sec-Websocket-Accept')

View File

@ -17,7 +17,7 @@ import selectors
from typing import Any
from unittest import mock
from proxy.core.connection import TcpClientConnection
from proxy.core.connection import TcpClientConnection, TcpServerConnection
from proxy.http.handler import HttpProtocolHandler
from proxy.http.proxy import HttpProxyPlugin
from proxy.http.methods import httpMethods
@ -71,6 +71,11 @@ class TestHttpProxyTlsInterception(unittest.TestCase):
return ssl_connection
return plain_connection
# Do not mock the original wrap method
self.mock_server_conn.return_value.wrap.side_effect = \
lambda x, y: TcpServerConnection.wrap(
self.mock_server_conn.return_value, x, y)
type(self.mock_server_conn.return_value).connection = \
mock.PropertyMock(side_effect=mock_connection)

View File

@ -19,7 +19,7 @@ from typing import Any, cast
from proxy.common.utils import bytes_
from proxy.common.flags import Flags
from proxy.common.utils import build_http_request, build_http_response
from proxy.core.connection import TcpClientConnection
from proxy.core.connection import TcpClientConnection, TcpServerConnection
from proxy.http.codes import httpStatusCodes
from proxy.http.methods import httpMethods
from proxy.http.handler import HttpProtocolHandler
@ -98,6 +98,10 @@ class TestHttpProxyPluginExamplesWithTlsInterception(unittest.TestCase):
return self.server_ssl_connection
return self._conn
# Do not mock the original wrap method
self.server.wrap.side_effect = \
lambda x, y: TcpServerConnection.wrap(self.server, x, y)
self.server.has_buffer.side_effect = has_buffer
type(self.server).closed = mock.PropertyMock(side_effect=closed)
type(