diff --git a/proxy.py b/proxy.py index 32bf76f4..bba62e0c 100755 --- a/proxy.py +++ b/proxy.py @@ -221,6 +221,12 @@ class HttpParser(object): elif self.state in (HttpParser.states.LINE_RCVD, HttpParser.states.RCVING_HEADERS): self.process_header(line) + # See `TestHttpParser.test_connect_request_without_host_header_request_parse` for details + if self.state == HttpParser.states.RCVING_HEADERS and \ + self.method == b'CONNECT' and \ + self.raw.endswith(CRLF * 2): + self.state = HttpParser.states.COMPLETE + if self.state == HttpParser.states.HEADERS_COMPLETE and \ self.type == HttpParser.types.REQUEST_PARSER and \ not self.method == b'POST' and \ @@ -275,12 +281,12 @@ class HttpParser(object): del_headers = [] for k in self.headers: if k not in del_headers: - req += self.build_header(self.headers[k][0], self.headers[k][1]) + req += self.build_header(self.headers[k][0], self.headers[k][1]) + CRLF if not add_headers: add_headers = [] for k in add_headers: - req += self.build_header(k[0], k[1]) + req += self.build_header(k[0], k[1]) + CRLF req += CRLF if self.body: @@ -290,7 +296,7 @@ class HttpParser(object): @staticmethod def build_header(k, v): - return k + b': ' + v + CRLF + return k + b': ' + v @staticmethod def split(data): diff --git a/tests.py b/tests.py index 9f55d2f2..5ca73df5 100644 --- a/tests.py +++ b/tests.py @@ -94,6 +94,17 @@ class TestHttpParser(unittest.TestCase): def setUp(self): self.parser = HttpParser(HttpParser.types.REQUEST_PARSER) + def test_build_header(self): + self.assertEqual(HttpParser.build_header(b'key', b'value'), b'key: value') + + def test_split(self): + self.assertEqual(HttpParser.split(b'CONNECT python.org:443 HTTP/1.0\r\n\r\n'), + (b'CONNECT python.org:443 HTTP/1.0', '\r\n')) + + def test_split_false_line(self): + self.assertEqual(HttpParser.split(b'CONNECT python.org:443 HTTP/1.0'), + (False, b'CONNECT python.org:443 HTTP/1.0')) + def test_get_full_parse(self): raw = CRLF.join([ b'GET %s HTTP/1.1', @@ -236,7 +247,7 @@ class TestHttpParser(unittest.TestCase): self.parser.parse(b'CONNECT pypi.org:443 HTTP/1.0\r\n\r\n') self.assertEqual(self.parser.method, b'CONNECT') self.assertEqual(self.parser.version, b'HTTP/1.0') - self.assertEqual(self.parser.state, HttpParser.states.RCVING_HEADERS) + self.assertEqual(self.parser.state, HttpParser.states.COMPLETE) def test_request_parse_without_content_length(self): """Case when incoming request doesn't contain a content-length header.