proxy.py/tests/test_acceptor.py

148 lines
4.8 KiB
Python

# -*- coding: utf-8 -*-
"""
proxy.py
~~~~~~~~
⚡⚡⚡ Fast, Lightweight, Programmable Proxy Server in a single Python file.
:copyright: (c) 2013-present by Abhinav Singh and contributors.
:license: BSD, see LICENSE for more details.
"""
import unittest
import socket
import selectors
import multiprocessing
from unittest import mock
from proxy.common.flags import Flags
from proxy.core.acceptor import Acceptor, AcceptorPool
class TestAcceptor(unittest.TestCase):
def setUp(self) -> None:
self.acceptor_id = 1
self.mock_protocol_handler = mock.MagicMock()
self.pipe = multiprocessing.Pipe()
self.flags = Flags()
self.acceptor = Acceptor(
idd=self.acceptor_id,
work_queue=self.pipe[1],
flags=self.flags,
work_klass=self.mock_protocol_handler)
@mock.patch('selectors.DefaultSelector')
@mock.patch('socket.fromfd')
@mock.patch('proxy.core.acceptor.recv_handle')
def test_continues_when_no_events(
self,
mock_recv_handle: mock.Mock,
mock_fromfd: mock.Mock,
mock_selector: mock.Mock) -> None:
fileno = 10
conn = mock.MagicMock()
addr = mock.MagicMock()
sock = mock_fromfd.return_value
mock_fromfd.return_value.accept.return_value = (conn, addr)
mock_recv_handle.return_value = fileno
selector = mock_selector.return_value
selector.select.side_effect = [[], KeyboardInterrupt()]
self.acceptor.run()
sock.accept.assert_not_called()
self.mock_protocol_handler.assert_not_called()
@mock.patch('threading.Thread')
@mock.patch('selectors.DefaultSelector')
@mock.patch('socket.fromfd')
@mock.patch('proxy.core.acceptor.recv_handle')
def test_accepts_client_from_server_socket(
self,
mock_recv_handle: mock.Mock,
mock_fromfd: mock.Mock,
mock_selector: mock.Mock,
mock_thread: mock.Mock) -> None:
fileno = 10
conn = mock.MagicMock()
addr = mock.MagicMock()
sock = mock_fromfd.return_value
mock_fromfd.return_value.accept.return_value = (conn, addr)
mock_recv_handle.return_value = fileno
mock_thread.return_value.start.side_effect = KeyboardInterrupt()
selector = mock_selector.return_value
selector.select.return_value = [(None, None)]
self.acceptor.run()
selector.register.assert_called_with(sock, selectors.EVENT_READ)
selector.unregister.assert_called_with(sock)
mock_recv_handle.assert_called_with(self.pipe[1])
mock_fromfd.assert_called_with(
fileno,
family=socket.AF_INET6,
type=socket.SOCK_STREAM
)
self.mock_protocol_handler.assert_called_with(
fileno=conn.fileno(),
addr=addr,
flags=self.flags,
event_queue=None,
)
mock_thread.assert_called_with(
target=self.mock_protocol_handler.return_value.run)
mock_thread.return_value.start.assert_called()
sock.close.assert_called()
class TestAcceptorPool(unittest.TestCase):
@mock.patch('proxy.core.acceptor.send_handle')
@mock.patch('multiprocessing.Pipe')
@mock.patch('socket.socket')
@mock.patch('proxy.core.acceptor.Acceptor')
def test_setup_and_shutdown(
self,
mock_worker: mock.Mock,
mock_socket: mock.Mock,
mock_pipe: mock.Mock,
_mock_send_handle: mock.Mock) -> None:
mock_worker1 = mock.MagicMock()
mock_worker2 = mock.MagicMock()
mock_worker.side_effect = [mock_worker1, mock_worker2]
num_workers = 2
sock = mock_socket.return_value
work_klass = mock.MagicMock()
flags = Flags(num_workers=2)
acceptor = AcceptorPool(flags=flags, work_klass=work_klass)
acceptor.setup()
work_klass.assert_not_called()
mock_socket.assert_called_with(
socket.AF_INET6 if acceptor.flags.hostname.version == 6 else socket.AF_INET,
socket.SOCK_STREAM
)
sock.setsockopt.assert_called_with(
socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind.assert_called_with(
(str(acceptor.flags.hostname), acceptor.flags.port))
sock.listen.assert_called_with(acceptor.flags.backlog)
sock.setblocking.assert_called_with(False)
self.assertTrue(mock_pipe.call_count, num_workers)
self.assertTrue(mock_worker.call_count, num_workers)
mock_worker1.start.assert_called()
mock_worker1.join.assert_not_called()
mock_worker2.start.assert_called()
mock_worker2.join.assert_not_called()
sock.close.assert_called()
acceptor.shutdown()
mock_worker1.join.assert_called()
mock_worker2.join.assert_called()