From ae0551e3fa00e22e9e539c088c41d868547f2442 Mon Sep 17 00:00:00 2001 From: Abhinav Singh Date: Fri, 21 Feb 2014 03:26:27 +0530 Subject: [PATCH] refactor + add tests --- Makefile | 2 +- proxy.py | 47 ++++---- tests.py | 327 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 356 insertions(+), 20 deletions(-) create mode 100644 tests.py diff --git a/Makefile b/Makefile index 201a7593..422ea38b 100644 --- a/Makefile +++ b/Makefile @@ -8,7 +8,7 @@ clean: find . -name '*~' -exec rm -f {} + test: - nosetests -v --with-coverage --cover-package=proxy --cover-erase --cover-html --nocapture + python tests.py -v package: python setup.py sdist diff --git a/proxy.py b/proxy.py index 52af804b..8eca9a14 100644 --- a/proxy.py +++ b/proxy.py @@ -210,12 +210,12 @@ class HttpParser(object): return line, data class Connection(object): - """TCP connection abstraction.""" + """TCP server/client connection abstraction.""" def __init__(self, what): self.buffer = '' self.closed = False - self.what = what + self.what = what # server or client def send(self, data): return self.conn.send(data) @@ -251,7 +251,7 @@ class Connection(object): logger.debug('flushed %d bytes to %s' % (sent, self.what)) class Server(Connection): - """Established connection to destination server.""" + """Establish connection to destination server.""" def __init__(self, host, port): super(Server, self).__init__('server') @@ -473,22 +473,20 @@ class Proxy(multiprocessing.Process): self._access_log() logger.debug('Closing proxy for connection %r at address %r' % (self.client.conn, self.client.addr)) -class Http(object): - """HTTP server implementation. - - Listens on configured (host, port) and spawns a new process - for handling every accepted HTTP connection. Spawned process - takes care of proxying the HTTP request. - """ +class TCP(object): + """TCP server implementation.""" def __init__(self, hostname='127.0.0.1', port=8899, backlog=100): self.hostname = hostname self.port = port self.backlog = backlog + def handle(self, client): + raise NotImplementedError() + def run(self): try: - logger.info('Starting proxy server on port %d' % self.port) + logger.info('Starting server on port %d' % self.port) self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) self.socket.bind((self.hostname, self.port)) @@ -497,33 +495,44 @@ class Http(object): conn, addr = self.socket.accept() logger.debug('Accepted connection %r at address %r' % (conn, addr)) client = Client(conn, addr) - proc = Proxy(client) - proc.daemon = True - proc.start() - logger.debug('Started process %r to handle connection %r' % (proc, conn)) + self.handle(client) except Exception as e: logger.exception('Exception while running the server %r' % e) finally: logger.info('Closing server socket') self.socket.close() +class HTTP(TCP): + """HTTP proxy server implementation. + + Spawns new process to proxy accepted client connection. + """ + + def handle(self, client): + proc = Proxy(client) + proc.daemon = True + proc.start() + logger.debug('Started process %r to handle connection %r' % (proc, client.conn)) + def main(): parser = argparse.ArgumentParser( description='proxy.py v%s' % __version__, epilog='Having difficulty using proxy.py? Report at: %s/issues/new' % __homepage__ ) + parser.add_argument('--hostname', default='127.0.0.1', help='Default: 127.0.0.1') parser.add_argument('--port', default='8899', help='Default: 8899') parser.add_argument('--log-level', default='INFO', help='DEBUG, INFO, WARNING, ERROR, CRITICAL') args = parser.parse_args() - hostname = args.hostname - port = int(args.port) logging.basicConfig(level=getattr(logging, args.log_level), format='%(asctime)s - %(process)d - %(message)s') + hostname = args.hostname + port = int(args.port) + try: - http = Http(hostname, port) - http.run() + proxy = HTTP(hostname, port) + proxy.run() except KeyboardInterrupt: pass diff --git a/tests.py b/tests.py new file mode 100644 index 00000000..32d7ec42 --- /dev/null +++ b/tests.py @@ -0,0 +1,327 @@ +import unittest +import proxy +from proxy import * + +class TestChunkParser(unittest.TestCase): + + def setUp(self): + self.parser = ChunkParser() + + def test_chunk_parse(self): + self.parser.parse(''.join([ + '4\r\n', + 'Wiki\r\n', + '5\r\n', + 'pedia\r\n', + 'E\r\n', + ' in\r\n\r\nchunks.\r\n', + '0\r\n', + '\r\n' + ])) + self.assertEqual(self.parser.chunk, '') + self.assertEqual(self.parser.size, None) + self.assertEqual(self.parser.body, 'Wikipedia in\r\n\r\nchunks.') + self.assertEqual(self.parser.state, CHUNK_PARSER_STATE_COMPLETE) + +class TestHttpParser(unittest.TestCase): + + def setUp(self): + self.parser = HttpParser() + + def test_get_full_parse(self): + raw = CRLF.join([ + "GET %s HTTP/1.1", + "Host: %s", + CRLF + ]) + self.parser.parse(raw % ('https://example.com/path/dir/?a=b&c=d#p=q', 'example.com')) + self.assertEqual(self.parser.build_url(), '/path/dir/?a=b&c=d#p=q') + self.assertEqual(self.parser.method, "GET") + self.assertEqual(self.parser.url.hostname, "example.com") + self.assertEqual(self.parser.url.port, None) + self.assertEqual(self.parser.version, "HTTP/1.1") + self.assertEqual(self.parser.state, HTTP_PARSER_STATE_COMPLETE) + self.assertDictContainsSubset({'host':('Host', 'example.com')}, self.parser.headers) + self.assertEqual(raw % ('/path/dir/?a=b&c=d#p=q', 'example.com'), self.parser.build(del_headers=['host'], add_headers=[('Host', 'example.com')])) + + def test_build_url_none(self): + self.assertEqual(self.parser.build_url(), '/None') + + def test_line_rcvd_to_rcving_headers_state_change(self): + self.parser.parse("GET http://localhost HTTP/1.1") + self.assertEqual(self.parser.state, HTTP_PARSER_STATE_INITIALIZED) + self.parser.parse(CRLF) + self.assertEqual(self.parser.state, HTTP_PARSER_STATE_LINE_RCVD) + self.parser.parse(CRLF) + self.assertEqual(self.parser.state, HTTP_PARSER_STATE_RCVING_HEADERS) + + def test_get_partial_parse1(self): + self.parser.parse(CRLF.join([ + "GET http://localhost:8080 HTTP/1.1" + ])) + self.assertEqual(self.parser.method, None) + self.assertEqual(self.parser.url, None) + self.assertEqual(self.parser.version, None) + self.assertEqual(self.parser.state, HTTP_PARSER_STATE_INITIALIZED) + + self.parser.parse(CRLF) + self.assertEqual(self.parser.method, "GET") + self.assertEqual(self.parser.url.hostname, "localhost") + self.assertEqual(self.parser.url.port, 8080) + self.assertEqual(self.parser.version, "HTTP/1.1") + self.assertEqual(self.parser.state, HTTP_PARSER_STATE_LINE_RCVD) + + self.parser.parse("Host: localhost:8080") + self.assertDictEqual(self.parser.headers, dict()) + self.assertEqual(self.parser.buffer, "Host: localhost:8080") + self.assertEqual(self.parser.state, HTTP_PARSER_STATE_LINE_RCVD) + + self.parser.parse(CRLF*2) + self.assertDictContainsSubset({'host':('Host', 'localhost:8080')}, self.parser.headers) + self.assertEqual(self.parser.state, HTTP_PARSER_STATE_COMPLETE) + + def test_get_partial_parse2(self): + self.parser.parse(CRLF.join([ + "GET http://localhost:8080 HTTP/1.1", + "Host: " + ])) + self.assertEqual(self.parser.method, "GET") + self.assertEqual(self.parser.url.hostname, "localhost") + self.assertEqual(self.parser.url.port, 8080) + self.assertEqual(self.parser.version, "HTTP/1.1") + self.assertEqual(self.parser.buffer, "Host: ") + self.assertEqual(self.parser.state, HTTP_PARSER_STATE_LINE_RCVD) + + self.parser.parse("localhost:8080%s" % CRLF) + self.assertDictContainsSubset({'host': ('Host', 'localhost:8080')}, self.parser.headers) + self.assertEqual(self.parser.buffer, "") + self.assertEqual(self.parser.state, HTTP_PARSER_STATE_RCVING_HEADERS) + + self.parser.parse("Content-Type: text/plain%s" % CRLF) + self.assertEqual(self.parser.buffer, "") + self.assertDictContainsSubset({'content-type': ('Content-Type', 'text/plain')}, self.parser.headers) + self.assertEqual(self.parser.state, HTTP_PARSER_STATE_RCVING_HEADERS) + + self.parser.parse(CRLF) + self.assertEqual(self.parser.state, HTTP_PARSER_STATE_COMPLETE) + + def test_post_full_parse(self): + raw = CRLF.join([ + "POST %s HTTP/1.1", + "Host: localhost", + "Content-Length: 7", + "Content-Type: application/x-www-form-urlencoded%s" % CRLF, + "a=b&c=d" + ]) + self.parser.parse(raw % 'http://localhost') + self.assertEqual(self.parser.method, "POST") + self.assertEqual(self.parser.url.hostname, "localhost") + self.assertEqual(self.parser.url.port, None) + self.assertEqual(self.parser.version, "HTTP/1.1") + self.assertDictContainsSubset({'content-type': ('Content-Type', 'application/x-www-form-urlencoded')}, self.parser.headers) + self.assertDictContainsSubset({'content-length': ('Content-Length', '7')}, self.parser.headers) + self.assertEqual(self.parser.body, "a=b&c=d") + self.assertEqual(self.parser.buffer, "") + self.assertEqual(self.parser.state, HTTP_PARSER_STATE_COMPLETE) + self.assertEqual(len(self.parser.build()), len(raw % '/')) + + def test_post_partial_parse(self): + self.parser.parse(CRLF.join([ + "POST http://localhost HTTP/1.1", + "Host: localhost", + "Content-Length: 7", + "Content-Type: application/x-www-form-urlencoded" + ])) + self.assertEqual(self.parser.method, "POST") + self.assertEqual(self.parser.url.hostname, "localhost") + self.assertEqual(self.parser.url.port, None) + self.assertEqual(self.parser.version, "HTTP/1.1") + self.assertEqual(self.parser.state, HTTP_PARSER_STATE_RCVING_HEADERS) + + self.parser.parse(CRLF) + self.assertEqual(self.parser.state, HTTP_PARSER_STATE_RCVING_HEADERS) + + self.parser.parse(CRLF) + self.assertEqual(self.parser.state, HTTP_PARSER_STATE_HEADERS_COMPLETE) + + self.parser.parse("a=b") + self.assertEqual(self.parser.state, HTTP_PARSER_STATE_RCVING_BODY) + self.assertEqual(self.parser.body, "a=b") + self.assertEqual(self.parser.buffer, "") + + self.parser.parse("&c=d") + self.assertEqual(self.parser.state, HTTP_PARSER_STATE_COMPLETE) + self.assertEqual(self.parser.body, "a=b&c=d") + self.assertEqual(self.parser.buffer, "") + + def test_response_parse(self): + self.parser.type = HTTP_RESPONSE_PARSER + self.parser.parse(''.join([ + 'HTTP/1.1 301 Moved Permanently\r\n', + 'Location: http://www.google.com/\r\n', + 'Content-Type: text/html; charset=UTF-8\r\n', + 'Date: Wed, 22 May 2013 14:07:29 GMT\r\n', + 'Expires: Fri, 21 Jun 2013 14:07:29 GMT\r\n', + 'Cache-Control: public, max-age=2592000\r\n', + 'Server: gws\r\n', + 'Content-Length: 219\r\n', + 'X-XSS-Protection: 1; mode=block\r\n', + 'X-Frame-Options: SAMEORIGIN\r\n\r\n', + '\n301 Moved', + '\n

301 Moved

\nThe document has moved\nhere.\r\n\r\n' + ])) + self.assertEqual(self.parser.code, '301') + self.assertEqual(self.parser.reason, 'Moved Permanently') + self.assertEqual(self.parser.version, 'HTTP/1.1') + self.assertEqual(self.parser.body, '\n301 Moved\n

301 Moved

\nThe document has moved\nhere.\r\n\r\n') + self.assertDictContainsSubset({'content-length': ('Content-Length', '219')}, self.parser.headers) + self.assertEqual(self.parser.state, HTTP_PARSER_STATE_COMPLETE) + + def test_response_partial_parse(self): + self.parser.type = HTTP_RESPONSE_PARSER + self.parser.parse(''.join([ + 'HTTP/1.1 301 Moved Permanently\r\n', + 'Location: http://www.google.com/\r\n', + 'Content-Type: text/html; charset=UTF-8\r\n', + 'Date: Wed, 22 May 2013 14:07:29 GMT\r\n', + 'Expires: Fri, 21 Jun 2013 14:07:29 GMT\r\n', + 'Cache-Control: public, max-age=2592000\r\n', + 'Server: gws\r\n', + 'Content-Length: 219\r\n', + 'X-XSS-Protection: 1; mode=block\r\n', + 'X-Frame-Options: SAMEORIGIN\r\n' + ])) + self.assertDictContainsSubset({'x-frame-options': ('X-Frame-Options', 'SAMEORIGIN')}, self.parser.headers) + self.assertEqual(self.parser.state, HTTP_PARSER_STATE_RCVING_HEADERS) + self.parser.parse('\r\n') + self.assertEqual(self.parser.state, HTTP_PARSER_STATE_HEADERS_COMPLETE) + self.parser.parse('\n301 Moved') + self.assertEqual(self.parser.state, HTTP_PARSER_STATE_RCVING_BODY) + self.parser.parse('\n

301 Moved

\nThe document has moved\nhere.\r\n\r\n') + self.assertEqual(self.parser.state, HTTP_PARSER_STATE_COMPLETE) + + def test_chunked_response_parse(self): + self.parser.type = HTTP_RESPONSE_PARSER + self.parser.parse(''.join([ + 'HTTP/1.1 200 OK\r\n', + 'Content-Type: application/json\r\n', + 'Date: Wed, 22 May 2013 15:08:15 GMT\r\n', + 'Server: gunicorn/0.16.1\r\n', + 'transfer-encoding: chunked\r\n', + 'Connection: keep-alive\r\n\r\n', + '4\r\n', + 'Wiki\r\n', + '5\r\n', + 'pedia\r\n', + 'E\r\n', + ' in\r\n\r\nchunks.\r\n', + '0\r\n', + '\r\n' + ])) + self.assertEqual(self.parser.body, 'Wikipedia in\r\n\r\nchunks.') + self.assertEqual(self.parser.state, HTTP_PARSER_STATE_COMPLETE) + +class MockConnection(object): + + def __init__(self, buffer=''): + self.buffer = buffer + + def recv(self, bytes=8192): + data = self.buffer[:bytes] + self.buffer = self.buffer[bytes:] + return data + + def send(self, data): + return len(data) + + def queue(self, data): + self.buffer += data + +class TestProxy(unittest.TestCase): + + def setUp(self): + self._conn = MockConnection() + self._addr = ('127.0.0.1', 54382) + self.proxy = Proxy(Client(self._conn, self._addr)) + + def test_http_get(self): + self.proxy.client.conn.queue("GET http://httpbin.org/get HTTP/1.1%s" % CRLF) + self.proxy._process_request(self.proxy.client.recv()) + self.assertNotEqual(self.proxy.request.state, HTTP_PARSER_STATE_COMPLETE) + + self.proxy.client.conn.queue(CRLF.join([ + "User-Agent: curl/7.27.0", + "Host: httpbin.org", + "Accept: */*", + "Proxy-Connection: Keep-Alive", + CRLF + ])) + + self.proxy._process_request(self.proxy.client.recv()) + self.assertEqual(self.proxy.request.state, HTTP_PARSER_STATE_COMPLETE) + self.assertEqual(self.proxy.server.addr, ("httpbin.org", 80)) + + self.proxy.server.flush() + self.assertEqual(self.proxy.server.buffer_size(), 0) + + data = self.proxy.server.recv() + while data: + self.proxy._process_response(data) + if self.proxy.response.state == HTTP_PARSER_STATE_COMPLETE: + break + data = self.proxy.server.recv() + + self.assertEqual(self.proxy.response.state, HTTP_PARSER_STATE_COMPLETE) + self.assertEqual(int(self.proxy.response.code), 200) + + def test_https_get(self): + self.proxy.client.conn.queue(CRLF.join([ + "CONNECT httpbin.org:80 HTTP/1.1", + "Host: httpbin.org:80", + "User-Agent: curl/7.27.0", + "Proxy-Connection: Keep-Alive", + CRLF + ])) + self.proxy._process_request(self.proxy.client.recv()) + self.assertFalse(self.proxy.server == None) + self.assertEqual(self.proxy.client.buffer, self.proxy.connection_established_pkt) + + parser = HttpParser(HTTP_RESPONSE_PARSER) + parser.parse(self.proxy.client.buffer) + self.assertEqual(parser.state, HTTP_PARSER_STATE_HEADERS_COMPLETE) + self.assertEqual(int(parser.code), 200) + + self.proxy.client.flush() + self.assertEqual(self.proxy.client.buffer_size(), 0) + + self.proxy.client.conn.queue(CRLF.join([ + "GET /user-agent HTTP/1.1", + "Host: httpbin.org", + "User-Agent: curl/7.27.0", + CRLF + ])) + self.proxy._process_request(self.proxy.client.recv()) + self.proxy.server.flush() + self.assertEqual(self.proxy.server.buffer_size(), 0) + + parser = HttpParser(HTTP_RESPONSE_PARSER) + data = self.proxy.server.recv() + while data: + parser.parse(data) + if parser.state == HTTP_PARSER_STATE_COMPLETE: + break + data = self.proxy.server.recv() + + self.assertEqual(parser.state, HTTP_PARSER_STATE_COMPLETE) + self.assertEqual(int(parser.code), 200) + + def test_proxy_connection_failed(self): + with self.assertRaises(ProxyConnectionFailed): + self.proxy._process_request(CRLF.join([ + "GET http://unknown.domain HTTP/1.1", + "Host: unknown.domain", + CRLF + ])) + +if __name__ == '__main__': + unittest.main()