diff --git a/Procfile b/Procfile new file mode 100644 index 00000000..4c58bb8c --- /dev/null +++ b/Procfile @@ -0,0 +1,2 @@ +# See https://devcenter.heroku.com/articles/procfile +web: python3 proxy.py --hostname 0.0.0.0 --port $PORT diff --git a/proxy.py b/proxy.py index 05acef09..4e3012ec 100755 --- a/proxy.py +++ b/proxy.py @@ -93,7 +93,7 @@ CRLF, COLON, WHITESPACE, COMMA = b'\r\n', b':', b' ', ',' PROXY_AGENT_HEADER = b'Proxy-agent: proxy.py v' + version -class TcpConnection(object): +class TcpConnection: """TCP server/client connection abstraction.""" types = namedtuple('TcpConnectionTypes', ( @@ -171,7 +171,7 @@ class TcpClientConnection(TcpConnection): self.addr: Tuple[str, int] = addr -class TcpServer(object): +class TcpServer: """TcpServer server implementation. Inheritor MUST implement `handle` method. It accepts an instance of `TcpClientConnection`. @@ -311,7 +311,7 @@ class Worker(multiprocessing.Process): break -class ChunkParser(object): +class ChunkParser: """HTTP chunked encoding response parser.""" states = namedtuple('ChunkParserStates', ( @@ -361,7 +361,7 @@ class ChunkParser(object): return len(raw) > 0, raw -class HttpParser(object): +class HttpParser: """HTTP request/response parser.""" states = namedtuple('HttpParserStates', ( @@ -623,7 +623,7 @@ class HttpRequestRejected(HttpProtocolException): return CRLF.join(pkt) if len(pkt) > 0 else None -class HttpProtocolConfig(object): +class HttpProtocolConfig: """Holds various configuration values applicable to HttpProtocolHandler. This config class helps us avoid passing around bunch of key/value pairs across methods. @@ -643,7 +643,7 @@ class HttpProtocolConfig(object): self.disable_headers = disable_headers -class HttpProtocolBasePlugin(object): +class HttpProtocolBasePlugin: """Base HttpProtocolHandler Plugin class. Implement various lifecycle event methods to customize behavior.""" @@ -729,7 +729,7 @@ class ProxyAuthenticationFailed(HttpProtocolException): return self.RESPONSE_PKT -class HttpProxyBasePlugin(object): +class HttpProxyBasePlugin: """Base HttpProxyPlugin Plugin class. Implement various lifecycle event methods to customize behavior.""" diff --git a/tests.py b/tests.py index e40ad9c9..191647ec 100644 --- a/tests.py +++ b/tests.py @@ -17,6 +17,7 @@ import time import unittest import errno import proxy +from typing import Dict from contextlib import closing from http.server import HTTPServer, BaseHTTPRequestHandler from threading import Thread @@ -302,11 +303,9 @@ class TestHttpParser(unittest.TestCase): def test_line_rcvd_to_rcving_headers_state_change(self): self.parser.parse(b'GET http://localhost HTTP/1.1') - self.assertEqual(self.parser.state, proxy.HttpParser.states.INITIALIZED) - self.parser.parse(proxy.CRLF) - self.assertEqual(self.parser.state, proxy.HttpParser.states.LINE_RCVD) - self.parser.parse(proxy.CRLF) - self.assertEqual(self.parser.state, proxy.HttpParser.states.RCVING_HEADERS) + self.assert_state_change_with_crlf(proxy.HttpParser.states.INITIALIZED, + proxy.HttpParser.states.LINE_RCVD, + proxy.HttpParser.states.RCVING_HEADERS) def test_get_partial_parse1(self): self.parser.parse(proxy.CRLF.join([ @@ -379,6 +378,14 @@ class TestHttpParser(unittest.TestCase): self.assertEqual(self.parser.state, proxy.HttpParser.states.COMPLETE) self.assertEqual(len(self.parser.build()), len(raw % b'/')) + def assert_state_change_with_crlf(self, initial_state: proxy.HttpParser.states, + next_state: proxy.HttpParser.states, final_state: proxy.HttpParser.states): + self.assertEqual(self.parser.state, initial_state) + self.parser.parse(proxy.CRLF) + self.assertEqual(self.parser.state, next_state) + self.parser.parse(proxy.CRLF) + self.assertEqual(self.parser.state, final_state) + def test_post_partial_parse(self): self.parser.parse(proxy.CRLF.join([ b'POST http://localhost HTTP/1.1', @@ -390,13 +397,9 @@ class TestHttpParser(unittest.TestCase): self.assertEqual(self.parser.url.hostname, b'localhost') self.assertEqual(self.parser.url.port, None) self.assertEqual(self.parser.version, b'HTTP/1.1') - self.assertEqual(self.parser.state, proxy.HttpParser.states.RCVING_HEADERS) - - self.parser.parse(proxy.CRLF) - self.assertEqual(self.parser.state, proxy.HttpParser.states.RCVING_HEADERS) - - self.parser.parse(proxy.CRLF) - self.assertEqual(self.parser.state, proxy.HttpParser.states.HEADERS_COMPLETE) + self.assert_state_change_with_crlf(proxy.HttpParser.states.RCVING_HEADERS, + proxy.HttpParser.states.RCVING_HEADERS, + proxy.HttpParser.states.HEADERS_COMPLETE) self.parser.parse(b'a=b') self.assertEqual(self.parser.state, proxy.HttpParser.states.RCVING_BODY) @@ -540,6 +543,10 @@ class TestHttpParser(unittest.TestCase): self.assertEqual(self.parser.body, b'Wikipedia in\r\n\r\nchunks.') self.assertEqual(self.parser.state, proxy.HttpParser.states.COMPLETE) + def assertDictContainsSubset(self, subset: Dict, dictionary: Dict): + for k in subset.keys(): + self.assertTrue(k in dictionary) + class MockTcpConnection(object): @@ -616,24 +623,24 @@ class TestProxy(unittest.TestCase): b'Proxy-Connection: Keep-Alive', proxy.CRLF ])) - self.proxy.run_once() - self.assertEqual(self.proxy.request.state, proxy.HttpParser.states.COMPLETE) - mock_server_connection.assert_called_once() - server.connect.assert_called_once() - server.closed = False - server.queue.assert_called_once_with(proxy.CRLF.join([ - b'GET / HTTP/1.1', - b'User-Agent: proxy.py/%s' % proxy.version, - b'Host: localhost:%d' % self.http_server_port, - b'Accept: */*', - b'Via: %s' % b'1.1 proxy.py v%s' % proxy.version, - b'Connection: Close', - proxy.CRLF - ])) - + self.assert_data_queued(mock_server_connection, server) self.proxy.run_once() server.flush.assert_called_once() + def assert_tunnel_response(self, mock_server_connection, server): + self.proxy.run_once() + self.assertFalse(self.proxy.plugins['HttpProxyPlugin'].server is None) + self.assertEqual(self.proxy.client.buffer, proxy.HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT) + mock_server_connection.assert_called_once() + server.connect.assert_called_once() + server.queue.assert_not_called() + server.closed = False + + parser = proxy.HttpParser(proxy.HttpParser.types.RESPONSE_PARSER) + parser.parse(self.proxy.client.buffer) + self.assertEqual(parser.state, proxy.HttpParser.states.HEADERS_COMPLETE) + self.assertEqual(int(parser.code), 200) + @mock.patch('select.select') @mock.patch('proxy.TcpServerConnection') def test_http_tunnel(self, mock_server_connection, mock_select): @@ -650,37 +657,11 @@ class TestProxy(unittest.TestCase): b'Proxy-Connection: Keep-Alive', proxy.CRLF ])) - self.proxy.run_once() - self.assertFalse(self.proxy.plugins['HttpProxyPlugin'].server is None) - self.assertEqual(self.proxy.client.buffer, proxy.HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT) - mock_server_connection.assert_called_once() - server.connect.assert_called_once() - server.queue.assert_not_called() - server.closed = False - - parser = proxy.HttpParser(proxy.HttpParser.types.RESPONSE_PARSER) - parser.parse(self.proxy.client.buffer) - self.assertEqual(parser.state, proxy.HttpParser.states.HEADERS_COMPLETE) - self.assertEqual(int(parser.code), 200) + self.assert_tunnel_response(mock_server_connection, server) # Dispatch tunnel established response to client self.proxy.run_once() - self.assertEqual(self.proxy.client.buffer_size(), 0) - - self.proxy.client.conn.queue(proxy.CRLF.join([ - b'GET / HTTP/1.1', - b'Host: localhost:%d' % self.http_server_port, - b'User-Agent: proxy.py/%s' % proxy.version, - proxy.CRLF - ])) - self.proxy.run_once() - server.queue.assert_called_once_with(proxy.CRLF.join([ - b'GET / HTTP/1.1', - b'Host: localhost:%d' % self.http_server_port, - b'User-Agent: proxy.py/%s' % proxy.version, - proxy.CRLF - ])) - server.flush.assert_not_called() + self.assert_data_queued_to_server(server) self.proxy.run_once() self.assertEqual(server.queue.call_count, 1) @@ -740,12 +721,14 @@ class TestProxy(unittest.TestCase): b'Proxy-Authorization: Basic dXNlcjpwYXNz', proxy.CRLF ])) + self.assert_data_queued(mock_server_connection, server) + + def assert_data_queued(self, mock_server_connection, server): self.proxy.run_once() self.assertEqual(self.proxy.request.state, proxy.HttpParser.states.COMPLETE) mock_server_connection.assert_called_once() server.connect.assert_called_once() server.closed = False - server.queue.assert_called_once_with(proxy.CRLF.join([ b'GET / HTTP/1.1', b'User-Agent: proxy.py/%s' % proxy.version, @@ -775,18 +758,14 @@ class TestProxy(unittest.TestCase): b'Proxy-Authorization: Basic dXNlcjpwYXNz', proxy.CRLF ])) - self.proxy.run_once() - self.assertFalse(self.proxy.plugins['HttpProxyPlugin'].server is None) - self.assertEqual(self.proxy.client.buffer, proxy.HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT) - mock_server_connection.assert_called_once() - server.connect.assert_called_once() - server.closed = False - - parser = proxy.HttpParser(proxy.HttpParser.types.RESPONSE_PARSER) - parser.parse(self.proxy.client.buffer) - self.assertEqual(parser.state, proxy.HttpParser.states.HEADERS_COMPLETE) - self.assertEqual(int(parser.code), 200) + self.assert_tunnel_response(mock_server_connection, server) self.proxy.client.flush() + self.assert_data_queued_to_server(server) + + self.proxy.run_once() + server.flush.assert_called_once() + + def assert_data_queued_to_server(self, server): self.assertEqual(self.proxy.client.buffer_size(), 0) self.proxy.client.conn.queue(proxy.CRLF.join([ @@ -802,9 +781,7 @@ class TestProxy(unittest.TestCase): b'User-Agent: proxy.py/%s' % proxy.version, proxy.CRLF ])) - - self.proxy.run_once() - server.flush.assert_called_once() + server.flush.assert_not_called() class TestWorker(unittest.TestCase):