proxy.py/tests/test_main.py

207 lines
8.2 KiB
Python

# -*- coding: utf-8 -*-
"""
proxy.py
~~~~~~~~
⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
Network monitoring, controls & Application development, testing, debugging.
:copyright: (c) 2013-present by Abhinav Singh and contributors.
:license: BSD, see LICENSE for more details.
"""
import unittest
import logging
import tempfile
import os
from unittest import mock
from typing import List
from proxy.proxy import main
from proxy.common.flags import Flags
from proxy.common.utils import bytes_
from proxy.http.handler import HttpProtocolHandler
from proxy.common.constants import DEFAULT_LOG_LEVEL, DEFAULT_LOG_FILE, DEFAULT_LOG_FORMAT, DEFAULT_BASIC_AUTH
from proxy.common.constants import DEFAULT_TIMEOUT, DEFAULT_DEVTOOLS_WS_PATH, DEFAULT_DISABLE_HTTP_PROXY
from proxy.common.constants import DEFAULT_ENABLE_STATIC_SERVER, DEFAULT_ENABLE_EVENTS, DEFAULT_ENABLE_DEVTOOLS
from proxy.common.constants import DEFAULT_ENABLE_WEB_SERVER, DEFAULT_THREADLESS, DEFAULT_CERT_FILE, DEFAULT_KEY_FILE
from proxy.common.constants import DEFAULT_CA_CERT_FILE, DEFAULT_CA_KEY_FILE, DEFAULT_CA_SIGNING_KEY_FILE
from proxy.common.constants import DEFAULT_PAC_FILE, DEFAULT_PLUGINS, DEFAULT_PID_FILE, DEFAULT_PORT
from proxy.common.constants import DEFAULT_NUM_WORKERS, DEFAULT_OPEN_FILE_LIMIT, DEFAULT_IPV6_HOSTNAME
from proxy.common.constants import DEFAULT_SERVER_RECVBUF_SIZE, DEFAULT_CLIENT_RECVBUF_SIZE
from proxy.common.version import __version__
def get_temp_file(name: str) -> str:
    """Return the path of ``name`` inside the system temporary directory.

    Only a path string is composed; nothing is created on disk.
    """
    temp_dir = tempfile.gettempdir()
    return os.path.join(temp_dir, name)
class TestMain(unittest.TestCase):
    """Tests for ``proxy.proxy.main``.

    ``AcceptorPool`` is patched so no real workers/sockets are created, and
    ``time.sleep`` is patched to raise ``KeyboardInterrupt`` so the tests do
    not block; the assertions then expect ``main`` to have driven the pool's
    ``setup``/``shutdown`` lifecycle.

    NOTE: stacked ``mock.patch`` decorators are applied bottom-up, so mock
    arguments appear in reverse order of the decorators above each test.
    """

    @staticmethod
    def mock_default_args(mock_args: mock.Mock) -> None:
        """Set every flag attribute on ``mock_args`` to its default value.

        ``mock_args`` stands in for the ``argparse`` namespace returned by the
        (mocked) parser. Attributes not assigned here are auto-created
        MagicMock attributes, so tests should set anything ``main`` reads.
        """
        mock_args.version = False
        mock_args.cert_file = DEFAULT_CERT_FILE
        mock_args.key_file = DEFAULT_KEY_FILE
        mock_args.ca_key_file = DEFAULT_CA_KEY_FILE
        mock_args.ca_cert_file = DEFAULT_CA_CERT_FILE
        mock_args.ca_signing_key_file = DEFAULT_CA_SIGNING_KEY_FILE
        mock_args.pid_file = DEFAULT_PID_FILE
        mock_args.log_file = DEFAULT_LOG_FILE
        mock_args.log_level = DEFAULT_LOG_LEVEL
        mock_args.log_format = DEFAULT_LOG_FORMAT
        mock_args.basic_auth = DEFAULT_BASIC_AUTH
        mock_args.hostname = DEFAULT_IPV6_HOSTNAME
        mock_args.port = DEFAULT_PORT
        mock_args.num_workers = DEFAULT_NUM_WORKERS
        mock_args.disable_http_proxy = DEFAULT_DISABLE_HTTP_PROXY
        mock_args.enable_web_server = DEFAULT_ENABLE_WEB_SERVER
        mock_args.pac_file = DEFAULT_PAC_FILE
        mock_args.plugins = DEFAULT_PLUGINS
        mock_args.server_recvbuf_size = DEFAULT_SERVER_RECVBUF_SIZE
        mock_args.client_recvbuf_size = DEFAULT_CLIENT_RECVBUF_SIZE
        mock_args.open_file_limit = DEFAULT_OPEN_FILE_LIMIT
        mock_args.enable_static_server = DEFAULT_ENABLE_STATIC_SERVER
        mock_args.enable_devtools = DEFAULT_ENABLE_DEVTOOLS
        mock_args.devtools_event_queue = None
        mock_args.devtools_ws_path = DEFAULT_DEVTOOLS_WS_PATH
        mock_args.timeout = DEFAULT_TIMEOUT
        mock_args.threadless = DEFAULT_THREADLESS
        mock_args.enable_events = DEFAULT_ENABLE_EVENTS

    @mock.patch('time.sleep')
    @mock.patch('proxy.proxy.Flags')
    @mock.patch('proxy.proxy.AcceptorPool')
    @mock.patch('logging.basicConfig')
    def test_init_with_no_arguments(
            self,
            mock_logging_config: mock.Mock,
            mock_acceptor_pool: mock.Mock,
            mock_flags: mock.Mock,
            mock_sleep: mock.Mock) -> None:
        """With no CLI arguments, ``main`` logs at INFO with the default
        format and runs an ``AcceptorPool`` of ``HttpProtocolHandler``."""
        # Break out of main's wait after the first sleep.
        mock_sleep.side_effect = KeyboardInterrupt()

        input_args: List[str] = []
        # Build real flags, then make the mocked proxy.proxy.Flags hand this
        # exact instance to main() so it can be compared below.
        flags = Flags.initialize(input_args=input_args)
        mock_flags.initialize = lambda *args, **kwargs: flags

        main()

        mock_logging_config.assert_called_with(
            level=logging.INFO,
            format=DEFAULT_LOG_FORMAT
        )
        mock_acceptor_pool.assert_called_with(
            flags=flags,
            work_klass=HttpProtocolHandler,
        )
        mock_acceptor_pool.return_value.setup.assert_called()
        mock_acceptor_pool.return_value.shutdown.assert_called()
        mock_sleep.assert_called()

    @mock.patch('time.sleep')
    @mock.patch('os.remove')
    @mock.patch('os.path.exists')
    @mock.patch('builtins.open')
    @mock.patch('proxy.proxy.Flags.init_parser')
    @mock.patch('proxy.proxy.AcceptorPool')
    def test_pid_file_is_written_and_removed(
            self,
            mock_acceptor_pool: mock.Mock,
            mock_init_parser: mock.Mock,
            mock_open: mock.Mock,
            mock_exists: mock.Mock,
            mock_remove: mock.Mock,
            mock_sleep: mock.Mock) -> None:
        """``--pid-file`` makes ``main`` write the current PID to that path
        and later remove the file."""
        pid_file = get_temp_file('pid')
        mock_sleep.side_effect = KeyboardInterrupt()

        # The parser is mocked; feed main() a namespace of default flags
        # with only pid_file overridden.
        mock_args = mock_init_parser.return_value.parse_args.return_value
        self.mock_default_args(mock_args)
        mock_args.pid_file = pid_file

        # open / os.path.exists / os.remove are mocked, so the real
        # filesystem is never touched.
        main(['--pid-file', pid_file])

        mock_init_parser.assert_called()
        mock_acceptor_pool.assert_called()
        mock_acceptor_pool.return_value.setup.assert_called()
        mock_open.assert_called_with(pid_file, 'wb')
        mock_open.return_value.__enter__.return_value.write.assert_called_with(
            bytes_(os.getpid()))
        mock_exists.assert_called_with(pid_file)
        mock_remove.assert_called_with(pid_file)

    @mock.patch('time.sleep')
    @mock.patch('proxy.proxy.Flags')
    @mock.patch('proxy.proxy.AcceptorPool')
    def test_basic_auth(
            self,
            mock_acceptor_pool: mock.Mock,
            mock_flags: mock.Mock,
            mock_sleep: mock.Mock) -> None:
        """``--basic-auth user:pass`` yields the HTTP Basic auth code
        ``Basic base64('user:pass')``."""
        mock_sleep.side_effect = KeyboardInterrupt()

        input_args = ['--basic-auth', 'user:pass']
        flags = Flags.initialize(input_args=input_args)
        mock_flags.initialize = lambda *args, **kwargs: flags

        main(input_args=input_args)

        mock_acceptor_pool.assert_called_with(
            flags=flags,
            work_klass=HttpProtocolHandler)
        # base64(b'user:pass') == b'dXNlcjpwYXNz'
        self.assertEqual(
            flags.auth_code,
            b'Basic dXNlcjpwYXNz')

    @mock.patch('time.sleep')
    @mock.patch('builtins.print')
    @mock.patch('proxy.proxy.Flags')
    @mock.patch('proxy.proxy.AcceptorPool')
    @mock.patch('proxy.proxy.Flags.is_py3')
    def test_main_py3_runs(
            self,
            mock_is_py3: mock.Mock,
            mock_acceptor_pool: mock.Mock,
            mock_flags: mock.Mock,
            mock_print: mock.Mock,
            mock_sleep: mock.Mock) -> None:
        """On Python 3, ``main`` starts the acceptor pool and prints no
        deprecation message."""
        mock_sleep.side_effect = KeyboardInterrupt()
        # Configure is_py3 BEFORE anything can call it. The is_py3 patch is
        # entered before proxy.proxy.Flags is replaced, so it patches the
        # real Flags class; calls that reach mock_is_py3 come from the real
        # Flags.initialize below (presumably it performs the py3 check —
        # confirm in proxy/common/flags.py). Previously the return value was
        # set only afterwards, so the check ran against a bare (truthy)
        # MagicMock instead of an explicit True.
        mock_is_py3.return_value = True

        input_args = ['--basic-auth', 'user:pass']
        flags = Flags.initialize(input_args=input_args)
        mock_flags.initialize = lambda *args, **kwargs: flags

        main(num_workers=1)

        mock_is_py3.assert_called()
        mock_print.assert_not_called()
        mock_acceptor_pool.assert_called()
        mock_acceptor_pool.return_value.setup.assert_called()

    @mock.patch('builtins.print')
    @mock.patch('proxy.proxy.Flags.is_py3')
    def test_main_py2_exit(
            self,
            mock_is_py3: mock.Mock,
            mock_print: mock.Mock) -> None:
        """When ``is_py3()`` is False, ``main`` prints the Python 2
        deprecation notice and exits with status 1."""
        mock_is_py3.return_value = False

        with self.assertRaises(SystemExit) as e:
            main(num_workers=1)

        mock_print.assert_called_with(
            'DEPRECATION: "develop" branch no longer supports Python 2.7. Kindly upgrade to Python 3+. '
            'If for some reasons you cannot upgrade, consider using "master" branch or simply '
            '"pip install proxy.py==0.3".'
            '\n\n'
            'DEPRECATION: Python 2.7 will reach the end of its life on January 1st, 2020. '
            'Please upgrade your Python as Python 2.7 won\'t be maintained after that date. '
            'A future version of pip will drop support for Python 2.7.'
        )
        self.assertEqual(e.exception.code, 1)
        mock_is_py3.assert_called()

    @mock.patch('builtins.print')
    def test_main_version(
            self,
            mock_print: mock.Mock) -> None:
        """``--version`` prints the package version and exits with status 0."""
        with self.assertRaises(SystemExit) as e:
            main(['--version'])

        mock_print.assert_called_with(__version__)
        self.assertEqual(e.exception.code, 0)