diff --git a/Makefile b/Makefile index 201a7593..422ea38b 100644 --- a/Makefile +++ b/Makefile @@ -8,7 +8,7 @@ clean: find . -name '*~' -exec rm -f {} + test: - nosetests -v --with-coverage --cover-package=proxy --cover-erase --cover-html --nocapture + python tests.py -v package: python setup.py sdist diff --git a/proxy.py b/proxy.py index 52af804b..8eca9a14 100644 --- a/proxy.py +++ b/proxy.py @@ -210,12 +210,12 @@ class HttpParser(object): return line, data class Connection(object): - """TCP connection abstraction.""" + """TCP server/client connection abstraction.""" def __init__(self, what): self.buffer = '' self.closed = False - self.what = what + self.what = what # server or client def send(self, data): return self.conn.send(data) @@ -251,7 +251,7 @@ class Connection(object): logger.debug('flushed %d bytes to %s' % (sent, self.what)) class Server(Connection): - """Established connection to destination server.""" + """Establish connection to destination server.""" def __init__(self, host, port): super(Server, self).__init__('server') @@ -473,22 +473,20 @@ class Proxy(multiprocessing.Process): self._access_log() logger.debug('Closing proxy for connection %r at address %r' % (self.client.conn, self.client.addr)) -class Http(object): - """HTTP server implementation. - - Listens on configured (host, port) and spawns a new process - for handling every accepted HTTP connection. Spawned process - takes care of proxying the HTTP request. - """ +class TCP(object): + """TCP server implementation.""" def __init__(self, hostname='127.0.0.1', port=8899, backlog=100): self.hostname = hostname self.port = port self.backlog = backlog + def handle(self, client): + raise NotImplementedError() + def run(self): try: - logger.info('Starting proxy server on port %d' % self.port) + logger.info('Starting server on port %d' % self.port) self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) self.socket.bind((self.hostname, self.port)) @@ -497,33 +495,44 @@ class Http(object): conn, addr = self.socket.accept() logger.debug('Accepted connection %r at address %r' % (conn, addr)) client = Client(conn, addr) - proc = Proxy(client) - proc.daemon = True - proc.start() - logger.debug('Started process %r to handle connection %r' % (proc, conn)) + self.handle(client) except Exception as e: logger.exception('Exception while running the server %r' % e) finally: logger.info('Closing server socket') self.socket.close() +class HTTP(TCP): + """HTTP proxy server implementation. + + Spawns new process to proxy accepted client connection. + """ + + def handle(self, client): + proc = Proxy(client) + proc.daemon = True + proc.start() + logger.debug('Started process %r to handle connection %r' % (proc, client.conn)) + def main(): parser = argparse.ArgumentParser( description='proxy.py v%s' % __version__, epilog='Having difficulty using proxy.py? Report at: %s/issues/new' % __homepage__ ) + parser.add_argument('--hostname', default='127.0.0.1', help='Default: 127.0.0.1') parser.add_argument('--port', default='8899', help='Default: 8899') parser.add_argument('--log-level', default='INFO', help='DEBUG, INFO, WARNING, ERROR, CRITICAL') args = parser.parse_args() - hostname = args.hostname - port = int(args.port) logging.basicConfig(level=getattr(logging, args.log_level), format='%(asctime)s - %(process)d - %(message)s') + hostname = args.hostname + port = int(args.port) + try: - http = Http(hostname, port) - http.run() + proxy = HTTP(hostname, port) + proxy.run() except KeyboardInterrupt: pass diff --git a/tests.py b/tests.py new file mode 100644 index 00000000..32d7ec42 --- /dev/null +++ b/tests.py @@ -0,0 +1,327 @@ +import unittest +import proxy +from proxy import * + +class TestChunkParser(unittest.TestCase): + + def setUp(self): + self.parser = ChunkParser() + + def test_chunk_parse(self): + self.parser.parse(''.join([ + '4\r\n', + 'Wiki\r\n', + '5\r\n', + 'pedia\r\n', + 'E\r\n', + ' in\r\n\r\nchunks.\r\n', + '0\r\n', + '\r\n' + ])) + self.assertEqual(self.parser.chunk, '') + self.assertEqual(self.parser.size, None) + self.assertEqual(self.parser.body, 'Wikipedia in\r\n\r\nchunks.') + self.assertEqual(self.parser.state, CHUNK_PARSER_STATE_COMPLETE) + +class TestHttpParser(unittest.TestCase): + + def setUp(self): + self.parser = HttpParser() + + def test_get_full_parse(self): + raw = CRLF.join([ + "GET %s HTTP/1.1", + "Host: %s", + CRLF + ]) + self.parser.parse(raw % ('https://example.com/path/dir/?a=b&c=d#p=q', 'example.com')) + self.assertEqual(self.parser.build_url(), '/path/dir/?a=b&c=d#p=q') + self.assertEqual(self.parser.method, "GET") + self.assertEqual(self.parser.url.hostname, "example.com") + self.assertEqual(self.parser.url.port, None) + self.assertEqual(self.parser.version, "HTTP/1.1") + self.assertEqual(self.parser.state, HTTP_PARSER_STATE_COMPLETE) + self.assertDictContainsSubset({'host':('Host', 'example.com')}, self.parser.headers) + self.assertEqual(raw % ('/path/dir/?a=b&c=d#p=q', 'example.com'), self.parser.build(del_headers=['host'], add_headers=[('Host', 'example.com')])) + + def test_build_url_none(self): + self.assertEqual(self.parser.build_url(), '/None') + + def test_line_rcvd_to_rcving_headers_state_change(self): + self.parser.parse("GET http://localhost HTTP/1.1") + self.assertEqual(self.parser.state, HTTP_PARSER_STATE_INITIALIZED) + self.parser.parse(CRLF) + self.assertEqual(self.parser.state, HTTP_PARSER_STATE_LINE_RCVD) + self.parser.parse(CRLF) + self.assertEqual(self.parser.state, HTTP_PARSER_STATE_RCVING_HEADERS) + + def test_get_partial_parse1(self): + self.parser.parse(CRLF.join([ + "GET http://localhost:8080 HTTP/1.1" + ])) + self.assertEqual(self.parser.method, None) + self.assertEqual(self.parser.url, None) + self.assertEqual(self.parser.version, None) + self.assertEqual(self.parser.state, HTTP_PARSER_STATE_INITIALIZED) + + self.parser.parse(CRLF) + self.assertEqual(self.parser.method, "GET") + self.assertEqual(self.parser.url.hostname, "localhost") + self.assertEqual(self.parser.url.port, 8080) + self.assertEqual(self.parser.version, "HTTP/1.1") + self.assertEqual(self.parser.state, HTTP_PARSER_STATE_LINE_RCVD) + + self.parser.parse("Host: localhost:8080") + self.assertDictEqual(self.parser.headers, dict()) + self.assertEqual(self.parser.buffer, "Host: localhost:8080") + self.assertEqual(self.parser.state, HTTP_PARSER_STATE_LINE_RCVD) + + self.parser.parse(CRLF*2) + self.assertDictContainsSubset({'host':('Host', 'localhost:8080')}, self.parser.headers) + self.assertEqual(self.parser.state, HTTP_PARSER_STATE_COMPLETE) + + def test_get_partial_parse2(self): + self.parser.parse(CRLF.join([ + "GET http://localhost:8080 HTTP/1.1", + "Host: " + ])) + self.assertEqual(self.parser.method, "GET") + self.assertEqual(self.parser.url.hostname, "localhost") + self.assertEqual(self.parser.url.port, 8080) + self.assertEqual(self.parser.version, "HTTP/1.1") + self.assertEqual(self.parser.buffer, "Host: ") + self.assertEqual(self.parser.state, HTTP_PARSER_STATE_LINE_RCVD) + + self.parser.parse("localhost:8080%s" % CRLF) + self.assertDictContainsSubset({'host': ('Host', 'localhost:8080')}, self.parser.headers) + self.assertEqual(self.parser.buffer, "") + self.assertEqual(self.parser.state, HTTP_PARSER_STATE_RCVING_HEADERS) + + self.parser.parse("Content-Type: text/plain%s" % CRLF) + self.assertEqual(self.parser.buffer, "") + self.assertDictContainsSubset({'content-type': ('Content-Type', 'text/plain')}, self.parser.headers) + self.assertEqual(self.parser.state, HTTP_PARSER_STATE_RCVING_HEADERS) + + self.parser.parse(CRLF) + self.assertEqual(self.parser.state, HTTP_PARSER_STATE_COMPLETE) + + def test_post_full_parse(self): + raw = CRLF.join([ + "POST %s HTTP/1.1", + "Host: localhost", + "Content-Length: 7", + "Content-Type: application/x-www-form-urlencoded%s" % CRLF, + "a=b&c=d" + ]) + self.parser.parse(raw % 'http://localhost') + self.assertEqual(self.parser.method, "POST") + self.assertEqual(self.parser.url.hostname, "localhost") + self.assertEqual(self.parser.url.port, None) + self.assertEqual(self.parser.version, "HTTP/1.1") + self.assertDictContainsSubset({'content-type': ('Content-Type', 'application/x-www-form-urlencoded')}, self.parser.headers) + self.assertDictContainsSubset({'content-length': ('Content-Length', '7')}, self.parser.headers) + self.assertEqual(self.parser.body, "a=b&c=d") + self.assertEqual(self.parser.buffer, "") + self.assertEqual(self.parser.state, HTTP_PARSER_STATE_COMPLETE) + self.assertEqual(len(self.parser.build()), len(raw % '/')) + + def test_post_partial_parse(self): + self.parser.parse(CRLF.join([ + "POST http://localhost HTTP/1.1", + "Host: localhost", + "Content-Length: 7", + "Content-Type: application/x-www-form-urlencoded" + ])) + self.assertEqual(self.parser.method, "POST") + self.assertEqual(self.parser.url.hostname, "localhost") + self.assertEqual(self.parser.url.port, None) + self.assertEqual(self.parser.version, "HTTP/1.1") + self.assertEqual(self.parser.state, HTTP_PARSER_STATE_RCVING_HEADERS) + + self.parser.parse(CRLF) + self.assertEqual(self.parser.state, HTTP_PARSER_STATE_RCVING_HEADERS) + + self.parser.parse(CRLF) + self.assertEqual(self.parser.state, HTTP_PARSER_STATE_HEADERS_COMPLETE) + + self.parser.parse("a=b") + self.assertEqual(self.parser.state, HTTP_PARSER_STATE_RCVING_BODY) + self.assertEqual(self.parser.body, "a=b") + self.assertEqual(self.parser.buffer, "") + + self.parser.parse("&c=d") + self.assertEqual(self.parser.state, HTTP_PARSER_STATE_COMPLETE) + self.assertEqual(self.parser.body, "a=b&c=d") + self.assertEqual(self.parser.buffer, "") + + def test_response_parse(self): + self.parser.type = HTTP_RESPONSE_PARSER + self.parser.parse(''.join([ + 'HTTP/1.1 301 Moved Permanently\r\n', + 'Location: http://www.google.com/\r\n', + 'Content-Type: text/html; charset=UTF-8\r\n', + 'Date: Wed, 22 May 2013 14:07:29 GMT\r\n', + 'Expires: Fri, 21 Jun 2013 14:07:29 GMT\r\n', + 'Cache-Control: public, max-age=2592000\r\n', + 'Server: gws\r\n', + 'Content-Length: 219\r\n', + 'X-XSS-Protection: 1; mode=block\r\n', + 'X-Frame-Options: SAMEORIGIN\r\n\r\n', + '
\n