proxy.py/tests.py

963 lines
40 KiB
Python

# -*- coding: utf-8 -*-
"""
proxy.py
~~~~~~~~
HTTP, HTTPS, HTTP2 and WebSockets Proxy Server in Python.
:copyright: (c) 2013-2020 by Abhinav Singh.
:license: BSD, see LICENSE for more details.
"""
import base64
import logging
import multiprocessing
import os
import socket
import time
import unittest
import errno
import proxy
from contextlib import closing
from http.server import HTTPServer, BaseHTTPRequestHandler
from threading import Thread
from unittest import mock
if os.name != 'nt':
import resource
# Configure the root logger for the whole test run: DEBUG level, and every
# record is prefixed with timestamp, level, emitting function and line number.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s',
    level=logging.DEBUG,
)
def get_available_port():
    """Return a TCP port number the OS currently considers free.

    A throwaway IPv4 stream socket is bound to port 0 on all interfaces so
    the kernel picks an ephemeral port. The port number is read back and the
    socket is closed. The port may be taken by someone else before the
    caller binds it, so this is a best-effort helper for tests.
    """
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as probe:
        probe.bind(('', 0))
        return probe.getsockname()[1]
class TestTcpConnection(unittest.TestCase):
    """Unit tests for ``proxy.TcpConnection`` and ``proxy.TcpServerConnection``.

    The wrapped socket is replaced by a ``MagicMock`` (or
    ``socket.create_connection`` is patched), so no real network I/O happens.
    """

    def testHandlesIOError(self):
        """A generic ``IOError`` from ``recv`` is reported via ``logger.exception``."""
        self.conn = proxy.TcpConnection(proxy.TcpConnection.types.CLIENT)
        _conn = mock.MagicMock()
        _conn.recv.side_effect = IOError()
        self.conn.conn = _conn
        with mock.patch('proxy.logger') as mock_logger:
            self.conn.recv()
            mock_logger.exception.assert_called()
            # This check used to be handed to logging.info(), so a wrong log
            # message could never fail the test. Assert it instead.
            self.assertTrue(mock_logger.exception.call_args[0][0].startswith(
                'Exception while receiving from connection'))

    def testHandlesConnReset(self):
        """``ECONNRESET`` is logged at debug level only, never via ``logger.exception``."""
        self.conn = proxy.TcpConnection(proxy.TcpConnection.types.CLIENT)
        _conn = mock.MagicMock()
        e = IOError()
        e.errno = errno.ECONNRESET
        _conn.recv.side_effect = e
        self.conn.conn = _conn
        with mock.patch('proxy.logger') as mock_logger:
            self.conn.recv()
            mock_logger.exception.assert_not_called()
            mock_logger.debug.assert_called()
            self.assertEqual(mock_logger.debug.call_args[0][0], '%r' % e)

    def testClosesIfNotClosed(self):
        """``close`` closes the wrapped socket and flags the connection as closed."""
        self.conn = proxy.TcpConnection(proxy.TcpConnection.types.CLIENT)
        _conn = mock.MagicMock()
        self.conn.conn = _conn
        self.conn.close()
        _conn.close.assert_called()
        self.assertTrue(self.conn.closed)

    def testNoOpIfAlreadyClosed(self):
        """``close`` on an already-closed connection leaves the socket alone."""
        self.conn = proxy.TcpConnection(proxy.TcpConnection.types.CLIENT)
        _conn = mock.MagicMock()
        self.conn.conn = _conn
        self.conn.closed = True
        self.conn.close()
        _conn.close.assert_not_called()
        self.assertTrue(self.conn.closed)

    @mock.patch('socket.create_connection')
    def testTcpServerClosesConnOnGC(self, mock_create_connection):
        """Dropping the last reference to a connected server connection closes its socket.

        Relies on CPython finalizing the object as soon as ``del`` removes the
        last reference.
        """
        conn = mock.MagicMock()
        mock_create_connection.return_value = conn
        self.conn = proxy.TcpServerConnection(proxy.DEFAULT_IPV4_HOSTNAME, proxy.DEFAULT_PORT)
        self.conn.connect()
        del self.conn
        conn.close.assert_called()
@unittest.skipIf(os.getenv('TESTING_ON_TRAVIS', 0), 'Opening sockets not allowed on Travis')
class TestTcpServer(unittest.TestCase):
    """Round-trip a HELLO/WORLD exchange through ``proxy.TcpServer`` over IPv4 and IPv6.

    One server per address family is started in a daemon thread for the whole
    class and stopped again in ``tearDownClass``.
    """

    ipv4_port = None
    ipv6_port = None
    ipv4_server = None
    ipv6_server = None
    ipv4_thread = None
    ipv6_thread = None

    # Maximum connection attempts (0.1s apart) while waiting for a server
    # thread to start listening, so a dead server fails instead of hanging.
    MAX_CONNECT_ATTEMPTS = 50

    class _TestTcpServer(proxy.TcpServer):
        """Server that expects ``HELLO`` and answers ``WORLD`` once per client."""

        def handle(self, client):
            data = client.recv(proxy.DEFAULT_BUFFER_SIZE)
            assert data == b'HELLO'
            client.conn.sendall(b'WORLD')
            client.close()

    @classmethod
    def setUpClass(cls):
        cls.ipv4_port = get_available_port()
        cls.ipv6_port = get_available_port()
        cls.ipv4_server = TestTcpServer._TestTcpServer(port=cls.ipv4_port, ipv4=True)
        cls.ipv6_server = TestTcpServer._TestTcpServer(hostname=proxy.DEFAULT_IPV6_HOSTNAME, port=cls.ipv6_port,
                                                       ipv4=False)
        # daemon=True replaces Thread.setDaemon(), deprecated since Python 3.10.
        cls.ipv4_thread = Thread(target=cls.ipv4_server.run, daemon=True)
        cls.ipv6_thread = Thread(target=cls.ipv6_server.run, daemon=True)
        cls.ipv4_thread.start()
        cls.ipv6_thread.start()

    @classmethod
    def tearDownClass(cls):
        # Stop and join BOTH servers. Previously only the IPv4 one was stopped,
        # which leaked the IPv6 server and its thread.
        for server, thread in ((cls.ipv4_server, cls.ipv4_thread),
                               (cls.ipv6_server, cls.ipv6_thread)):
            server.stop()
            thread.join()

    def baseTestCase(self, ipv4=True):
        """Send ``HELLO`` to the IPv4 (default) or IPv6 server and expect ``WORLD``.

        Retries on ``ConnectionRefusedError`` because the server thread may not
        be listening yet. Gives up after ``MAX_CONNECT_ATTEMPTS`` tries.
        """
        family = socket.AF_INET if ipv4 else socket.AF_INET6
        address = (proxy.DEFAULT_IPV4_HOSTNAME if ipv4 else proxy.DEFAULT_IPV6_HOSTNAME,
                   self.ipv4_port if ipv4 else self.ipv6_port)
        for _ in range(self.MAX_CONNECT_ATTEMPTS):
            sock = None
            try:
                sock = socket.socket(family, socket.SOCK_STREAM, 0)
                sock.connect(address)
                sock.sendall(b'HELLO')
                data = sock.recv(proxy.DEFAULT_BUFFER_SIZE)
                self.assertEqual(data, b'WORLD')
                return
            except ConnectionRefusedError:
                time.sleep(0.1)
            finally:
                # socket.socket() itself may raise, leaving sock as None.
                if sock is not None:
                    sock.close()
        self.fail('Could not connect to %r after %d attempts' % (address, self.MAX_CONNECT_ATTEMPTS))

    def testIpv4ClientConnection(self):
        self.baseTestCase()

    def testIpv6ClientConnection(self):
        self.baseTestCase(ipv4=False)
class MockHttpProxy(object):
    """Fake replacement for ``proxy.HttpProtocolHandler`` in dispatcher tests.

    It keeps the client connection and constructor keyword arguments. Instead
    of handling the protocol, ``start`` writes a bare ``HTTP/1.1 200 OK``
    status line to the client socket and closes it.
    """

    def __init__(self, client, **kwargs):
        self.client, self.kwargs = client, kwargs

    def setDaemon(self, _val):
        """Mirror ``Thread.setDaemon``; the flag is ignored."""
        return None

    def start(self):
        """Answer the client with a 200 status line, then close its socket."""
        response = proxy.CRLF.join([b'HTTP/1.1 200 OK', proxy.CRLF])
        self.client.conn.sendall(response)
        self.client.conn.close()
def mock_tcp_proxy_side_effect(client, **kwargs):
    """``side_effect`` for a patched ``HttpProtocolHandler``.

    Builds and returns a ``MockHttpProxy`` from the same arguments.
    """
    handler = MockHttpProxy(client, **kwargs)
    return handler
@unittest.skipIf(os.getenv('TESTING_ON_TRAVIS', 0), 'Opening sockets not allowed on Travis')
class TestMultiCoreRequestDispatcher(unittest.TestCase):
    """Check that ``MultiCoreRequestDispatcher`` hands accepted sockets to ``HttpProtocolHandler``.

    The handler is patched with ``MockHttpProxy``, which answers every client
    with ``HTTP/1.1 200 OK``.
    """

    tcp_port = None
    tcp_server = None
    tcp_thread = None

    # Maximum connection attempts (0.1s apart) while the dispatcher starts.
    MAX_CONNECT_ATTEMPTS = 50

    @mock.patch.object(proxy, 'HttpProtocolHandler', side_effect=mock_tcp_proxy_side_effect)
    def testHttpProxyConnection(self, mock_tcp_proxy):
        try:
            self.tcp_port = get_available_port()
            self.tcp_server = proxy.MultiCoreRequestDispatcher(hostname=proxy.DEFAULT_IPV4_HOSTNAME, port=self.tcp_port,
                                                               ipv4=True, num_workers=1)
            # daemon=True replaces Thread.setDaemon(), deprecated since Python 3.10.
            self.tcp_thread = Thread(target=self.tcp_server.run, daemon=True)
            self.tcp_thread.start()
            for _ in range(self.MAX_CONNECT_ATTEMPTS):
                sock = None
                try:
                    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
                    sock.connect((proxy.DEFAULT_IPV4_HOSTNAME, self.tcp_port))
                    # sendall: plain send() may write only part of the request.
                    sock.sendall(proxy.CRLF.join([
                        b'GET http://httpbin.org/get HTTP/1.1',
                        b'Host: httpbin.org',
                        proxy.CRLF
                    ]))
                    data = sock.recv(proxy.DEFAULT_BUFFER_SIZE)
                    self.assertEqual(data, proxy.CRLF.join([b'HTTP/1.1 200 OK', proxy.CRLF]))
                    self.tcp_server.shutdown()  # explicit early call worker shutdown to avoid resource leak warnings
                    break
                except ConnectionRefusedError:
                    time.sleep(0.1)
                finally:
                    # socket.socket() itself may raise, leaving sock as None.
                    if sock is not None:
                        sock.close()
            else:
                self.fail('Dispatcher never accepted a connection on port %d' % self.tcp_port)
        finally:
            # Setup may have failed before the server/thread existed; avoid
            # masking that error with an AttributeError on None.
            if self.tcp_server is not None:
                self.tcp_server.stop()
            if self.tcp_thread is not None:
                self.tcp_thread.join()
class TestChunkParser(unittest.TestCase):
    """Tests for ``proxy.ChunkParser``, the chunked transfer-encoding decoder.

    After each ``parse`` call the tests inspect four attributes: the pending
    unparsed ``chunk`` bytes, the current chunk ``size``, the accumulated
    ``body`` and the parser ``state``.
    """

    def setUp(self):
        self.parser = proxy.ChunkParser()

    def _assertParserState(self, chunk, size, body, state):
        """Check the parser's chunk, size, body and state, in that order."""
        self.assertEqual(self.parser.chunk, chunk)
        self.assertEqual(self.parser.size, size)
        self.assertEqual(self.parser.body, body)
        self.assertEqual(self.parser.state, state)

    def test_chunk_parse_basic(self):
        encoded = b''.join([
            b'4\r\n',
            b'Wiki\r\n',
            b'5\r\n',
            b'pedia\r\n',
            b'E\r\n',
            b' in\r\n\r\nchunks.\r\n',
            b'0\r\n',
            b'\r\n'
        ])
        self.parser.parse(encoded)
        self._assertParserState(b'', None, b'Wikipedia in\r\n\r\nchunks.',
                                proxy.ChunkParser.states.COMPLETE)

    def test_chunk_parse_issue_27(self):
        """Case when data ends with the chunk size but without ending CRLF."""
        states = proxy.ChunkParser.states
        # (bytes fed, expected chunk, expected size, expected body, expected state)
        steps = [
            (b'3', b'3', None, b'', states.WAITING_FOR_SIZE),
            (b'\r\n', b'', 3, b'', states.WAITING_FOR_DATA),
            (b'abc', b'', None, b'abc', states.WAITING_FOR_SIZE),
            (b'\r\n', b'', None, b'abc', states.WAITING_FOR_SIZE),
            (b'4\r\n', b'', 4, b'abc', states.WAITING_FOR_DATA),
            (b'defg\r\n0', b'0', None, b'abcdefg', states.WAITING_FOR_SIZE),
            (b'\r\n\r\n', b'', None, b'abcdefg', states.COMPLETE),
        ]
        for fed, chunk, size, body, state in steps:
            self.parser.parse(fed)
            self._assertParserState(chunk, size, body, state)
class TestHttpParser(unittest.TestCase):
    """Tests for ``proxy.HttpParser`` in request and response mode.

    Covers full and incremental parsing, state transitions, header storage
    (lower-cased key -> ``(original-case name, value)``) and request
    re-serialisation via ``build``.
    """

    def setUp(self):
        self.parser = proxy.HttpParser(proxy.HttpParser.types.REQUEST_PARSER)

    def _assertHeadersContain(self, expected):
        """Assert every ``key -> value`` pair in *expected* is in ``self.parser.headers``.

        Replaces ``unittest.TestCase.assertDictContainsSubset``, which was
        deprecated in Python 3.2 and removed in Python 3.12.
        """
        for key, value in expected.items():
            self.assertIn(key, self.parser.headers)
            self.assertEqual(self.parser.headers[key], value)

    def test_build_header(self):
        self.assertEqual(proxy.HttpParser.build_header(b'key', b'value'), b'key: value')

    def test_split(self):
        self.assertEqual(proxy.HttpParser.split(b'CONNECT python.org:443 HTTP/1.0\r\n\r\n'),
                         (b'CONNECT python.org:443 HTTP/1.0', b'\r\n'))

    def test_split_false_line(self):
        self.assertEqual(proxy.HttpParser.split(b'CONNECT python.org:443 HTTP/1.0'),
                         (False, b'CONNECT python.org:443 HTTP/1.0'))

    def test_get_full_parse(self):
        raw = proxy.CRLF.join([
            b'GET %s HTTP/1.1',
            b'Host: %s',
            proxy.CRLF
        ])
        self.parser.parse(raw % (b'https://example.com/path/dir/?a=b&c=d#p=q', b'example.com'))
        self.assertEqual(self.parser.build_url(), b'/path/dir/?a=b&c=d#p=q')
        self.assertEqual(self.parser.method, b'GET')
        self.assertEqual(self.parser.url.hostname, b'example.com')
        self.assertEqual(self.parser.url.port, None)
        self.assertEqual(self.parser.version, b'HTTP/1.1')
        self.assertEqual(self.parser.state, proxy.HttpParser.states.COMPLETE)
        self._assertHeadersContain({b'host': (b'Host', b'example.com')})
        self.assertEqual(raw % (b'/path/dir/?a=b&c=d#p=q', b'example.com'),
                         self.parser.build(del_headers=[b'host'], add_headers=[(b'Host', b'example.com')]))

    def test_build_url_none(self):
        self.assertEqual(self.parser.build_url(), b'/None')

    def test_line_rcvd_to_rcving_headers_state_change(self):
        self.parser.parse(b'GET http://localhost HTTP/1.1')
        self.assertEqual(self.parser.state, proxy.HttpParser.states.INITIALIZED)
        self.parser.parse(proxy.CRLF)
        self.assertEqual(self.parser.state, proxy.HttpParser.states.LINE_RCVD)
        self.parser.parse(proxy.CRLF)
        self.assertEqual(self.parser.state, proxy.HttpParser.states.RCVING_HEADERS)

    def test_get_partial_parse1(self):
        self.parser.parse(proxy.CRLF.join([
            b'GET http://localhost:8080 HTTP/1.1'
        ]))
        self.assertEqual(self.parser.method, None)
        self.assertEqual(self.parser.url, None)
        self.assertEqual(self.parser.version, None)
        self.assertEqual(self.parser.state, proxy.HttpParser.states.INITIALIZED)

        self.parser.parse(proxy.CRLF)
        self.assertEqual(self.parser.method, b'GET')
        self.assertEqual(self.parser.url.hostname, b'localhost')
        self.assertEqual(self.parser.url.port, 8080)
        self.assertEqual(self.parser.version, b'HTTP/1.1')
        self.assertEqual(self.parser.state, proxy.HttpParser.states.LINE_RCVD)

        self.parser.parse(b'Host: localhost:8080')
        self.assertDictEqual(self.parser.headers, dict())
        self.assertEqual(self.parser.buffer, b'Host: localhost:8080')
        self.assertEqual(self.parser.state, proxy.HttpParser.states.LINE_RCVD)

        self.parser.parse(proxy.CRLF * 2)
        self._assertHeadersContain({b'host': (b'Host', b'localhost:8080')})
        self.assertEqual(self.parser.state, proxy.HttpParser.states.COMPLETE)

    def test_get_partial_parse2(self):
        self.parser.parse(proxy.CRLF.join([
            b'GET http://localhost:8080 HTTP/1.1',
            b'Host: '
        ]))
        self.assertEqual(self.parser.method, b'GET')
        self.assertEqual(self.parser.url.hostname, b'localhost')
        self.assertEqual(self.parser.url.port, 8080)
        self.assertEqual(self.parser.version, b'HTTP/1.1')
        self.assertEqual(self.parser.buffer, b'Host: ')
        self.assertEqual(self.parser.state, proxy.HttpParser.states.LINE_RCVD)

        self.parser.parse(b'localhost:8080' + proxy.CRLF)
        self._assertHeadersContain({b'host': (b'Host', b'localhost:8080')})
        self.assertEqual(self.parser.buffer, b'')
        self.assertEqual(self.parser.state, proxy.HttpParser.states.RCVING_HEADERS)

        self.parser.parse(b'Content-Type: text/plain' + proxy.CRLF)
        self.assertEqual(self.parser.buffer, b'')
        self._assertHeadersContain({b'content-type': (b'Content-Type', b'text/plain')})
        self.assertEqual(self.parser.state, proxy.HttpParser.states.RCVING_HEADERS)

        self.parser.parse(proxy.CRLF)
        self.assertEqual(self.parser.state, proxy.HttpParser.states.COMPLETE)

    def test_post_full_parse(self):
        raw = proxy.CRLF.join([
            b'POST %s HTTP/1.1',
            b'Host: localhost',
            b'Content-Length: 7',
            b'Content-Type: application/x-www-form-urlencoded' + proxy.CRLF,
            b'a=b&c=d'
        ])
        self.parser.parse(raw % b'http://localhost')
        self.assertEqual(self.parser.method, b'POST')
        self.assertEqual(self.parser.url.hostname, b'localhost')
        self.assertEqual(self.parser.url.port, None)
        self.assertEqual(self.parser.version, b'HTTP/1.1')
        self._assertHeadersContain({b'content-type': (b'Content-Type', b'application/x-www-form-urlencoded')})
        self._assertHeadersContain({b'content-length': (b'Content-Length', b'7')})
        self.assertEqual(self.parser.body, b'a=b&c=d')
        self.assertEqual(self.parser.buffer, b'')
        self.assertEqual(self.parser.state, proxy.HttpParser.states.COMPLETE)
        self.assertEqual(len(self.parser.build()), len(raw % b'/'))

    def test_post_partial_parse(self):
        self.parser.parse(proxy.CRLF.join([
            b'POST http://localhost HTTP/1.1',
            b'Host: localhost',
            b'Content-Length: 7',
            b'Content-Type: application/x-www-form-urlencoded'
        ]))
        self.assertEqual(self.parser.method, b'POST')
        self.assertEqual(self.parser.url.hostname, b'localhost')
        self.assertEqual(self.parser.url.port, None)
        self.assertEqual(self.parser.version, b'HTTP/1.1')
        self.assertEqual(self.parser.state, proxy.HttpParser.states.RCVING_HEADERS)

        self.parser.parse(proxy.CRLF)
        self.assertEqual(self.parser.state, proxy.HttpParser.states.RCVING_HEADERS)

        self.parser.parse(proxy.CRLF)
        self.assertEqual(self.parser.state, proxy.HttpParser.states.HEADERS_COMPLETE)

        self.parser.parse(b'a=b')
        self.assertEqual(self.parser.state, proxy.HttpParser.states.RCVING_BODY)
        self.assertEqual(self.parser.body, b'a=b')
        self.assertEqual(self.parser.buffer, b'')

        self.parser.parse(b'&c=d')
        self.assertEqual(self.parser.state, proxy.HttpParser.states.COMPLETE)
        self.assertEqual(self.parser.body, b'a=b&c=d')
        self.assertEqual(self.parser.buffer, b'')

    def test_connect_request_without_host_header_request_parse(self):
        """Case where clients can send CONNECT request without a Host header field.

        Example:
            1. pip3 --proxy http://localhost:8899 install <package name>
                Uses HTTP/1.0, Host header missing with CONNECT requests
            2. Android Emulator
                Uses HTTP/1.1, Host header missing with CONNECT requests

        See https://github.com/abhinavsingh/proxy.py/issues/5 for details.
        """
        self.parser.parse(b'CONNECT pypi.org:443 HTTP/1.0\r\n\r\n')
        self.assertEqual(self.parser.method, b'CONNECT')
        self.assertEqual(self.parser.version, b'HTTP/1.0')
        self.assertEqual(self.parser.state, proxy.HttpParser.states.COMPLETE)

    def test_request_parse_without_content_length(self):
        """Case when incoming request doesn't contain a content-length header.

        From http://w3-org.9356.n7.nabble.com/POST-with-empty-body-td103965.html
        'A POST with no content-length and no body is equivalent to a POST with Content-Length: 0
        and nothing following, as could perfectly happen when you upload an empty file for instance.'

        See https://github.com/abhinavsingh/proxy.py/issues/20 for details.
        """
        self.parser.parse(proxy.CRLF.join([
            b'POST http://localhost HTTP/1.1',
            b'Host: localhost',
            b'Content-Type: application/x-www-form-urlencoded',
            proxy.CRLF
        ]))
        self.assertEqual(self.parser.method, b'POST')
        self.assertEqual(self.parser.state, proxy.HttpParser.states.COMPLETE)

    def test_response_parse_without_content_length(self):
        """Case when server response doesn't contain a content-length header for non-chunk response types.

        HttpParser by itself has no way to know if more data should be expected.
        In example below, parser reaches state HttpParser.states.HEADERS_COMPLETE
        and it is responsibility of callee to change state to HttpParser.states.COMPLETE
        when server stream closes.

        See https://github.com/abhinavsingh/proxy.py/issues/20 for details.
        """
        self.parser.type = proxy.HttpParser.types.RESPONSE_PARSER
        self.parser.parse(b'HTTP/1.0 200 OK' + proxy.CRLF)
        self.assertEqual(self.parser.code, b'200')
        self.assertEqual(self.parser.version, b'HTTP/1.0')
        self.assertEqual(self.parser.state, proxy.HttpParser.states.LINE_RCVD)
        self.parser.parse(proxy.CRLF.join([
            b'Server: BaseHTTP/0.3 Python/2.7.10',
            b'Date: Thu, 13 Dec 2018 16:24:09 GMT',
            proxy.CRLF
        ]))
        self.assertEqual(self.parser.state, proxy.HttpParser.states.HEADERS_COMPLETE)

    def test_response_parse(self):
        self.parser.type = proxy.HttpParser.types.RESPONSE_PARSER
        self.parser.parse(b''.join([
            b'HTTP/1.1 301 Moved Permanently\r\n',
            b'Location: http://www.google.com/\r\n',
            b'Content-Type: text/html; charset=UTF-8\r\n',
            b'Date: Wed, 22 May 2013 14:07:29 GMT\r\n',
            b'Expires: Fri, 21 Jun 2013 14:07:29 GMT\r\n',
            b'Cache-Control: public, max-age=2592000\r\n',
            b'Server: gws\r\n',
            b'Content-Length: 219\r\n',
            b'X-XSS-Protection: 1; mode=block\r\n',
            b'X-Frame-Options: SAMEORIGIN\r\n\r\n',
            b'<HTML><HEAD><meta http-equiv="content-type" content="text/html;charset=utf-8">\n' +
            b'<TITLE>301 Moved</TITLE></HEAD>',
            b'<BODY>\n<H1>301 Moved</H1>\nThe document has moved\n' +
            b'<A HREF="http://www.google.com/">here</A>.\r\n</BODY></HTML>\r\n'
        ]))
        self.assertEqual(self.parser.code, b'301')
        self.assertEqual(self.parser.reason, b'Moved Permanently')
        self.assertEqual(self.parser.version, b'HTTP/1.1')
        self.assertEqual(self.parser.body,
                         b'<HTML><HEAD><meta http-equiv="content-type" content="text/html;charset=utf-8">\n' +
                         b'<TITLE>301 Moved</TITLE></HEAD><BODY>\n<H1>301 Moved</H1>\nThe document has moved\n' +
                         b'<A HREF="http://www.google.com/">here</A>.\r\n</BODY></HTML>\r\n')
        self._assertHeadersContain({b'content-length': (b'Content-Length', b'219')})
        self.assertEqual(self.parser.state, proxy.HttpParser.states.COMPLETE)

    def test_response_partial_parse(self):
        self.parser.type = proxy.HttpParser.types.RESPONSE_PARSER
        self.parser.parse(b''.join([
            b'HTTP/1.1 301 Moved Permanently\r\n',
            b'Location: http://www.google.com/\r\n',
            b'Content-Type: text/html; charset=UTF-8\r\n',
            b'Date: Wed, 22 May 2013 14:07:29 GMT\r\n',
            b'Expires: Fri, 21 Jun 2013 14:07:29 GMT\r\n',
            b'Cache-Control: public, max-age=2592000\r\n',
            b'Server: gws\r\n',
            b'Content-Length: 219\r\n',
            b'X-XSS-Protection: 1; mode=block\r\n',
            b'X-Frame-Options: SAMEORIGIN\r\n'
        ]))
        self._assertHeadersContain({b'x-frame-options': (b'X-Frame-Options', b'SAMEORIGIN')})
        self.assertEqual(self.parser.state, proxy.HttpParser.states.RCVING_HEADERS)
        self.parser.parse(b'\r\n')
        self.assertEqual(self.parser.state, proxy.HttpParser.states.HEADERS_COMPLETE)
        self.parser.parse(
            b'<HTML><HEAD><meta http-equiv="content-type" content="text/html;charset=utf-8">\n' +
            b'<TITLE>301 Moved</TITLE></HEAD>')
        self.assertEqual(self.parser.state, proxy.HttpParser.states.RCVING_BODY)
        self.parser.parse(
            b'<BODY>\n<H1>301 Moved</H1>\nThe document has moved\n' +
            b'<A HREF="http://www.google.com/">here</A>.\r\n</BODY></HTML>\r\n')
        self.assertEqual(self.parser.state, proxy.HttpParser.states.COMPLETE)

    def test_chunked_response_parse(self):
        self.parser.type = proxy.HttpParser.types.RESPONSE_PARSER
        self.parser.parse(b''.join([
            b'HTTP/1.1 200 OK\r\n',
            b'Content-Type: application/json\r\n',
            b'Date: Wed, 22 May 2013 15:08:15 GMT\r\n',
            b'Server: gunicorn/0.16.1\r\n',
            b'transfer-encoding: chunked\r\n',
            b'Connection: keep-alive\r\n\r\n',
            b'4\r\n',
            b'Wiki\r\n',
            b'5\r\n',
            b'pedia\r\n',
            b'E\r\n',
            b' in\r\n\r\nchunks.\r\n',
            b'0\r\n',
            b'\r\n'
        ]))
        self.assertEqual(self.parser.body, b'Wikipedia in\r\n\r\nchunks.')
        self.assertEqual(self.parser.state, proxy.HttpParser.states.COMPLETE)
class MockTcpConnection(object):
    """In-memory stand-in for a socket.

    ``recv`` drains bytes from an internal buffer, ``queue`` appends to that
    buffer, and ``send`` discards the data while reporting it as fully
    written.
    """

    def __init__(self, b=b''):
        self.buffer = b

    def recv(self, b=8192):
        """Remove and return at most *b* bytes from the front of the buffer."""
        head, self.buffer = self.buffer[:b], self.buffer[b:]
        return head

    def send(self, data):
        """Pretend every byte of *data* was sent and return that count."""
        return len(data)

    def queue(self, data):
        """Append *data* so subsequent ``recv`` calls return it."""
        self.buffer += data
class HTTPRequestHandler(BaseHTTPRequestHandler):
    """Tiny origin server handler: every GET receives status 200 and body ``OK``."""

    def do_GET(self):
        body = b'OK'
        self.send_response(200)
        # TODO(abhinavsingh): Proxy should work just fine even without content-length header
        self.send_header('content-length', len(body))
        self.end_headers()
        self.wfile.write(body)
class TestProxy(unittest.TestCase):
    """Drive ``proxy.HttpProtocolHandler`` one ``run_once`` step at a time.

    The client side is a ``MockTcpConnection``. ``select.select`` and, where
    needed, ``proxy.TcpServerConnection`` are patched, so each step's socket
    readiness and the upstream connection are scripted. A local
    ``HTTPServer`` is started once per class, and its port is used in the
    request lines.
    """

    http_server = None
    http_server_port = None
    http_server_thread = None
    config = None

    @classmethod
    def setUpClass(cls):
        cls.http_server_port = get_available_port()
        cls.http_server = HTTPServer(('127.0.0.1', cls.http_server_port), HTTPRequestHandler)
        # daemon=True replaces Thread.setDaemon(), deprecated since Python 3.10.
        cls.http_server_thread = Thread(target=cls.http_server.serve_forever, daemon=True)
        cls.http_server_thread.start()
        cls.config = proxy.HttpProtocolConfig()
        cls.config.plugins = proxy.load_plugins('proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin')

    @classmethod
    def tearDownClass(cls):
        cls.http_server.shutdown()
        cls.http_server.server_close()
        cls.http_server_thread.join()

    def setUp(self):
        self._conn = MockTcpConnection()
        self._addr = ('127.0.0.1', 54382)
        self.proxy = proxy.HttpProtocolHandler(proxy.TcpClientConnection(self._conn, self._addr), config=self.config)

    @mock.patch('select.select')
    @mock.patch('proxy.TcpServerConnection')
    def test_http_get(self, mock_server_connection, mock_select):
        """Plain-HTTP GET arriving in two pieces is rewritten and queued upstream."""
        server = mock_server_connection.return_value
        server.connect.return_value = True
        mock_select.side_effect = [([self._conn], [], []), ([self._conn], [], []), ([], [server.conn], [])]

        # Send request line
        self.proxy.client.conn.queue((b'GET http://localhost:%d HTTP/1.1' % self.http_server_port) + proxy.CRLF)
        self.proxy.run_once()
        self.assertEqual(self.proxy.request.state, proxy.HttpParser.states.LINE_RCVD)
        self.assertNotEqual(self.proxy.request.state, proxy.HttpParser.states.COMPLETE)

        # Send headers and blank line, thus completing HTTP request
        self.proxy.client.conn.queue(proxy.CRLF.join([
            b'User-Agent: proxy.py/%s' % proxy.version,
            b'Host: localhost:%d' % self.http_server_port,
            b'Accept: */*',
            b'Proxy-Connection: Keep-Alive',
            proxy.CRLF
        ]))
        self.proxy.run_once()
        self.assertEqual(self.proxy.request.state, proxy.HttpParser.states.COMPLETE)
        mock_server_connection.assert_called_once()
        server.connect.assert_called_once()
        server.closed = False
        # Upstream copy uses the origin-form path, drops Proxy-Connection,
        # and adds Via plus Connection: Close.
        server.queue.assert_called_once_with(proxy.CRLF.join([
            b'GET / HTTP/1.1',
            b'User-Agent: proxy.py/%s' % proxy.version,
            b'Host: localhost:%d' % self.http_server_port,
            b'Accept: */*',
            b'Via: 1.1 proxy.py v%s' % proxy.version,
            b'Connection: Close',
            proxy.CRLF
        ]))

        self.proxy.run_once()
        server.flush.assert_called_once()

    @mock.patch('select.select')
    @mock.patch('proxy.TcpServerConnection')
    def test_http_tunnel(self, mock_server_connection, mock_select):
        """CONNECT opens a tunnel; later client bytes are forwarded upstream unchanged."""
        server = mock_server_connection.return_value
        server.connect.return_value = True
        server.has_buffer.side_effect = [False, False, False, True]
        mock_select.side_effect = [([self._conn], [], []), ([], [self._conn], []),
                                   ([self._conn], [], []), ([], [server.conn], [])]

        self.proxy.client.conn.queue(proxy.CRLF.join([
            b'CONNECT localhost:%d HTTP/1.1' % self.http_server_port,
            b'Host: localhost:%d' % self.http_server_port,
            b'User-Agent: proxy.py/%s' % proxy.version,
            b'Proxy-Connection: Keep-Alive',
            proxy.CRLF
        ]))
        self.proxy.run_once()
        self.assertIsNotNone(self.proxy.plugins['HttpProxyPlugin'].server)
        self.assertEqual(self.proxy.client.buffer, proxy.HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT)
        mock_server_connection.assert_called_once()
        server.connect.assert_called_once()
        server.queue.assert_not_called()
        server.closed = False

        parser = proxy.HttpParser(proxy.HttpParser.types.RESPONSE_PARSER)
        parser.parse(self.proxy.client.buffer)
        self.assertEqual(parser.state, proxy.HttpParser.states.HEADERS_COMPLETE)
        self.assertEqual(int(parser.code), 200)

        # Dispatch tunnel established response to client
        self.proxy.run_once()
        self.assertEqual(self.proxy.client.buffer_size(), 0)

        self.proxy.client.conn.queue(proxy.CRLF.join([
            b'GET / HTTP/1.1',
            b'Host: localhost:%d' % self.http_server_port,
            b'User-Agent: proxy.py/%s' % proxy.version,
            proxy.CRLF
        ]))
        self.proxy.run_once()
        server.queue.assert_called_once_with(proxy.CRLF.join([
            b'GET / HTTP/1.1',
            b'Host: localhost:%d' % self.http_server_port,
            b'User-Agent: proxy.py/%s' % proxy.version,
            proxy.CRLF
        ]))
        server.flush.assert_not_called()

        self.proxy.run_once()
        self.assertEqual(server.queue.call_count, 1)
        server.flush.assert_called_once()

    @mock.patch('select.select')
    def test_proxy_connection_failed(self, mock_select):
        """An unresolvable upstream host surfaces as ``ProxyConnectionFailed``."""
        mock_select.return_value = ([self._conn], [], [])
        self.proxy.client.conn.queue(proxy.CRLF.join([
            b'GET http://unknown.domain HTTP/1.1',
            b'Host: unknown.domain',
            proxy.CRLF
        ]))
        with self.assertRaises(proxy.ProxyConnectionFailed):
            self.proxy.run_once()

    @mock.patch('select.select')
    def test_proxy_authentication_failed(self, mock_select):
        """A request without Proxy-Authorization is rejected when auth is configured."""
        mock_select.return_value = ([self._conn], [], [])
        config = proxy.HttpProtocolConfig(auth_code=b'Basic %s' % base64.b64encode(b'user:pass'))
        config.plugins = proxy.load_plugins('proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin')
        self.proxy = proxy.HttpProtocolHandler(proxy.TcpClientConnection(self._conn, self._addr),
                                               config=config)
        self.proxy.client.conn.queue(proxy.CRLF.join([
            b'GET http://abhinavsingh.com HTTP/1.1',
            b'Host: abhinavsingh.com',
            proxy.CRLF
        ]))
        with self.assertRaises(proxy.ProxyAuthenticationFailed):
            self.proxy.run_once()

    @mock.patch('select.select')
    @mock.patch('proxy.TcpServerConnection')
    def test_authenticated_proxy_http_get(self, mock_server_connection, mock_select):
        """Authenticated GET is forwarded without the Proxy-Authorization header."""
        mock_select.return_value = ([self._conn], [], [])
        server = mock_server_connection.return_value
        server.connect.return_value = True

        client = proxy.TcpClientConnection(self._conn, self._addr)
        config = proxy.HttpProtocolConfig(auth_code=b'Basic %s' % base64.b64encode(b'user:pass'))
        config.plugins = proxy.load_plugins('proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin')
        self.proxy = proxy.HttpProtocolHandler(client, config=config)
        self.proxy.client.conn.queue(b'GET http://localhost:%d HTTP/1.1' % self.http_server_port)
        self.proxy.run_once()
        self.assertEqual(self.proxy.request.state, proxy.HttpParser.states.INITIALIZED)

        self.proxy.client.conn.queue(proxy.CRLF)
        self.proxy.run_once()
        self.assertEqual(self.proxy.request.state, proxy.HttpParser.states.LINE_RCVD)

        self.proxy.client.conn.queue(proxy.CRLF.join([
            b'User-Agent: proxy.py/%s' % proxy.version,
            b'Host: localhost:%d' % self.http_server_port,
            b'Accept: */*',
            b'Proxy-Connection: Keep-Alive',
            b'Proxy-Authorization: Basic dXNlcjpwYXNz',
            proxy.CRLF
        ]))
        self.proxy.run_once()
        self.assertEqual(self.proxy.request.state, proxy.HttpParser.states.COMPLETE)
        mock_server_connection.assert_called_once()
        server.connect.assert_called_once()
        server.closed = False
        server.queue.assert_called_once_with(proxy.CRLF.join([
            b'GET / HTTP/1.1',
            b'User-Agent: proxy.py/%s' % proxy.version,
            b'Host: localhost:%d' % self.http_server_port,
            b'Accept: */*',
            b'Via: 1.1 proxy.py v%s' % proxy.version,
            b'Connection: Close',
            proxy.CRLF
        ]))

    @mock.patch('select.select')
    @mock.patch('proxy.TcpServerConnection')
    def test_authenticated_proxy_http_tunnel(self, mock_server_connection, mock_select):
        """Authenticated CONNECT opens a tunnel and forwards subsequent bytes."""
        server = mock_server_connection.return_value
        server.connect.return_value = True
        mock_select.side_effect = [([self._conn], [], []), ([self._conn], [], []), ([], [server.conn], [])]

        config = proxy.HttpProtocolConfig(auth_code=b'Basic %s' % base64.b64encode(b'user:pass'))
        config.plugins = proxy.load_plugins('proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin')
        self.proxy = proxy.HttpProtocolHandler(proxy.TcpClientConnection(self._conn, self._addr),
                                               config=config)
        self.proxy.client.conn.queue(proxy.CRLF.join([
            b'CONNECT localhost:%d HTTP/1.1' % self.http_server_port,
            b'Host: localhost:%d' % self.http_server_port,
            b'User-Agent: proxy.py/%s' % proxy.version,
            b'Proxy-Connection: Keep-Alive',
            b'Proxy-Authorization: Basic dXNlcjpwYXNz',
            proxy.CRLF
        ]))
        self.proxy.run_once()
        self.assertIsNotNone(self.proxy.plugins['HttpProxyPlugin'].server)
        self.assertEqual(self.proxy.client.buffer, proxy.HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT)
        mock_server_connection.assert_called_once()
        server.connect.assert_called_once()
        server.closed = False

        parser = proxy.HttpParser(proxy.HttpParser.types.RESPONSE_PARSER)
        parser.parse(self.proxy.client.buffer)
        self.assertEqual(parser.state, proxy.HttpParser.states.HEADERS_COMPLETE)
        self.assertEqual(int(parser.code), 200)

        self.proxy.client.flush()
        self.assertEqual(self.proxy.client.buffer_size(), 0)

        self.proxy.client.conn.queue(proxy.CRLF.join([
            b'GET / HTTP/1.1',
            b'Host: localhost:%d' % self.http_server_port,
            b'User-Agent: proxy.py/%s' % proxy.version,
            proxy.CRLF
        ]))
        self.proxy.run_once()
        server.queue.assert_called_once_with(proxy.CRLF.join([
            b'GET / HTTP/1.1',
            b'Host: localhost:%d' % self.http_server_port,
            b'User-Agent: proxy.py/%s' % proxy.version,
            proxy.CRLF
        ]))

        self.proxy.run_once()
        server.flush.assert_called_once()
class TestWorker(unittest.TestCase):
    """Tests for ``proxy.Worker`` consuming operations from a ``multiprocessing.Queue``."""

    def setUp(self):
        self.queue = multiprocessing.Queue()
        self.worker = proxy.Worker(self.queue)

    def tearDown(self):
        # Release the queue's pipe and feeder thread. Without this, every test
        # leaked file descriptors and a background thread.
        self.queue.close()
        self.queue.join_thread()

    @mock.patch('proxy.HttpProtocolHandler')
    def test_shutdown_op(self, mock_http_proxy):
        """A lone SHUTDOWN operation stops the worker without spawning a handler."""
        self.queue.put((proxy.Worker.operations.SHUTDOWN, None))
        self.worker.run()  # Worker should consume the prior shutdown operation
        self.assertFalse(mock_http_proxy.called)

    @mock.patch('proxy.HttpProtocolHandler')
    def test_spawns_http_proxy_threads(self, mock_http_proxy):
        """An HTTP_PROTOCOL operation makes the worker construct a protocol handler."""
        self.queue.put((proxy.Worker.operations.HTTP_PROTOCOL, None))
        self.queue.put((proxy.Worker.operations.SHUTDOWN, None))
        self.worker.run()
        self.assertTrue(mock_http_proxy.called)
class TestHttpRequestRejected(unittest.TestCase):
    """Tests for the raw response bytes built by ``proxy.HttpRequestRejected.response``."""

    def setUp(self):
        self.request = proxy.HttpParser(proxy.HttpParser.types.REQUEST_PARSER)

    def test_empty_response(self):
        rejected = proxy.HttpRequestRejected()
        actual = rejected.response(self.request)
        self.assertEqual(actual, None)

    def test_status_code_response(self):
        rejected = proxy.HttpRequestRejected(status_code=b'200 OK')
        actual = rejected.response(self.request)
        expected_lines = [
            b'HTTP/1.1 200 OK',
            proxy.PROXY_AGENT_HEADER,
            proxy.CRLF,
        ]
        self.assertEqual(actual, proxy.CRLF.join(expected_lines))

    def test_body_response(self):
        rejected = proxy.HttpRequestRejected(status_code=b'404 NOT FOUND', body=b'Nothing here')
        actual = rejected.response(self.request)
        expected_lines = [
            b'HTTP/1.1 404 NOT FOUND',
            proxy.PROXY_AGENT_HEADER,
            b'Content-Length: 12',
            proxy.CRLF,
            b'Nothing here',
        ]
        self.assertEqual(actual, proxy.CRLF.join(expected_lines))
class TestMain(unittest.TestCase):
    """Tests for ``proxy.main`` CLI wiring plus the text/bytes and open-file-limit helpers."""

    @mock.patch('proxy.HttpProtocolConfig')
    @mock.patch('proxy.set_open_file_limit')
    @mock.patch('proxy.MultiCoreRequestDispatcher')
    def test_main(self, mock_multicore_dispatcher, mock_set_open_file_limit, mock_config):
        """``--basic-auth`` becomes a Basic auth code; everything else uses defaults."""
        proxy.main(['--basic-auth', 'user:pass'])
        self.assertTrue(mock_set_open_file_limit.called)
        dispatcher_kwargs = {
            'hostname': proxy.DEFAULT_IPV4_HOSTNAME,
            'port': proxy.DEFAULT_PORT,
            'ipv4': proxy.DEFAULT_IPV4,
            'backlog': proxy.DEFAULT_BACKLOG,
            'num_workers': proxy.DEFAULT_NUM_WORKERS,
            'config': mock_config.return_value,
        }
        mock_multicore_dispatcher.assert_called_with(**dispatcher_kwargs)
        config_kwargs = {
            'auth_code': b'Basic dXNlcjpwYXNz',
            'client_recvbuf_size': proxy.DEFAULT_CLIENT_RECVBUF_SIZE,
            'server_recvbuf_size': proxy.DEFAULT_SERVER_RECVBUF_SIZE,
            'pac_file': proxy.DEFAULT_PAC_FILE,
            'pac_file_url_path': proxy.DEFAULT_PAC_FILE_URL_PATH,
        }
        mock_config.assert_called_with(**config_kwargs)

    @mock.patch('builtins.print')
    @mock.patch('proxy.HttpProtocolConfig')
    @mock.patch('proxy.set_open_file_limit')
    @mock.patch('proxy.MultiCoreRequestDispatcher')
    def test_main_version(self, mock_multicore_dispatcher, mock_set_open_file_limit, mock_config, mock_print):
        """``--version`` prints the version and exits before any server setup."""
        with self.assertRaises(SystemExit):
            proxy.main(['--version'])
        mock_print.assert_called_with(proxy.text_(proxy.version))
        for untouched in (mock_multicore_dispatcher, mock_set_open_file_limit, mock_config):
            untouched.assert_not_called()

    @mock.patch('builtins.print')
    @mock.patch('proxy.HttpProtocolConfig')
    @mock.patch('proxy.set_open_file_limit')
    @mock.patch('proxy.MultiCoreRequestDispatcher')
    @mock.patch('proxy.is_py3')
    def test_main_py3_runs(self, mock_is_py3, mock_multicore_dispatcher, mock_set_open_file_limit,
                           mock_config, mock_print):
        """On Python 3, ``main`` sets up the server without printing anything."""
        mock_is_py3.return_value = True
        proxy.main([])
        mock_is_py3.assert_called()
        mock_print.assert_not_called()
        for used in (mock_multicore_dispatcher, mock_set_open_file_limit, mock_config):
            used.assert_called()

    @mock.patch('builtins.print')
    @mock.patch('proxy.HttpProtocolConfig')
    @mock.patch('proxy.set_open_file_limit')
    @mock.patch('proxy.MultiCoreRequestDispatcher')
    @mock.patch('proxy.is_py3')
    def test_main_py2_exit(self, mock_is_py3, mock_multicore_dispatcher, mock_set_open_file_limit,
                           mock_config, mock_print):
        """On Python 2, ``main`` prints a deprecation notice and exits."""
        mock_is_py3.return_value = False
        with self.assertRaises(SystemExit):
            proxy.main([])
        mock_print.assert_called_with('DEPRECATION')
        mock_is_py3.assert_called()
        for untouched in (mock_multicore_dispatcher, mock_set_open_file_limit, mock_config):
            untouched.assert_not_called()

    def test_text(self):
        self.assertEqual('hello', proxy.text_(b'hello'))

    def test_text_nochange(self):
        self.assertEqual('hello', proxy.text_('hello'))

    def test_bytes(self):
        self.assertEqual(b'hello', proxy.bytes_('hello'))

    def test_bytes_nochange(self):
        self.assertEqual(b'hello', proxy.bytes_(b'hello'))

    @unittest.skipIf(os.name == 'nt', 'Open file limit tests disabled for Windows')
    @mock.patch('resource.getrlimit', return_value=(128, 1024))
    @mock.patch('resource.setrlimit', return_value=None)
    def test_set_open_file_limit(self, mock_set_rlimit, mock_get_rlimit):
        """A soft limit below the requested value is raised; the hard limit is kept."""
        proxy.set_open_file_limit(256)
        nofile = resource.RLIMIT_NOFILE
        mock_get_rlimit.assert_called_with(nofile)
        mock_set_rlimit.assert_called_with(nofile, (256, 1024))

    @unittest.skipIf(os.name == 'nt', 'Open file limit tests disabled for Windows')
    @mock.patch('resource.getrlimit', return_value=(256, 1024))
    @mock.patch('resource.setrlimit', return_value=None)
    def test_set_open_file_limit_not_called(self, mock_set_rlimit, mock_get_rlimit):
        """Requesting the current soft limit leaves the limits untouched."""
        proxy.set_open_file_limit(256)
        mock_get_rlimit.assert_called_with(resource.RLIMIT_NOFILE)
        mock_set_rlimit.assert_not_called()

    @unittest.skipIf(os.name == 'nt', 'Open file limit tests disabled for Windows')
    @mock.patch('resource.getrlimit', return_value=(256, 1024))
    @mock.patch('resource.setrlimit', return_value=None)
    def test_set_open_file_limit_not_called1(self, mock_set_rlimit, mock_get_rlimit):
        """Requesting a value equal to the hard limit leaves the limits untouched."""
        proxy.set_open_file_limit(1024)
        mock_get_rlimit.assert_called_with(resource.RLIMIT_NOFILE)
        mock_set_rlimit.assert_not_called()
if __name__ == '__main__':
    # Flag the proxy module as running under the test-suite, then hand
    # control to unittest's command-line runner for this module.
    proxy.UNDER_TEST = True
    unittest.main(module='__main__')