diff --git a/proxy.py b/proxy.py index 16bc42ac..ae7baa91 100755 --- a/proxy.py +++ b/proxy.py @@ -79,19 +79,29 @@ class ChunkParser(object): def __init__(self): self.state = CHUNK_PARSER_STATE_WAITING_FOR_SIZE - self.body = b'' - self.chunk = b'' - self.size = None + self.body = b'' # Parsed chunks + self.chunk = b'' # Partial chunk received + self.size = None # Expected size of next following chunk def parse(self, data): more = True if len(data) > 0 else False - while more: more, data = self.process(data) + while more: + more, data = self.process(data) def process(self, data): if self.state == CHUNK_PARSER_STATE_WAITING_FOR_SIZE: + # Consume prior chunk in buffer + # in case chunk size without CRLF was received + data = self.chunk + data + self.chunk = b'' + # Extract following chunk data size line, data = HttpParser.split(data) - self.size = int(line, 16) - self.state = CHUNK_PARSER_STATE_WAITING_FOR_DATA + if not line: # CRLF not received + self.chunk = data + data = b'' + else: + self.size = int(line, 16) + self.state = CHUNK_PARSER_STATE_WAITING_FOR_DATA elif self.state == CHUNK_PARSER_STATE_WAITING_FOR_DATA: remaining = self.size - len(self.chunk) self.chunk += data[:remaining] @@ -388,8 +398,7 @@ class Proxy(threading.Thread): elif self.request.url: host, port = self.request.url.hostname, self.request.url.port if self.request.url.port else 80 else: - # TODO(abhinavsingh): Gracefully return invalid request in such cases - pass + raise Exception('Invalid request\n%s' % self.request.raw) self.server = Server(host, port) try: diff --git a/tests.py b/tests.py index 7c94ca5b..d570b612 100644 --- a/tests.py +++ b/tests.py @@ -24,6 +24,43 @@ class TestChunkParser(unittest.TestCase): self.assertEqual(self.parser.body, b'Wikipedia in\r\n\r\nchunks.') self.assertEqual(self.parser.state, CHUNK_PARSER_STATE_COMPLETE) + def test_chunk_parse_issue_27(self): + self.parser.parse(b'3') + self.assertEqual(self.parser.chunk, b'3') + self.assertEqual(self.parser.size, None) + self.assertEqual(self.parser.body, b'') + self.assertEqual(self.parser.state, CHUNK_PARSER_STATE_WAITING_FOR_SIZE) + self.parser.parse(b'\r\n') + self.assertEqual(self.parser.chunk, b'') + self.assertEqual(self.parser.size, 3) + self.assertEqual(self.parser.body, b'') + self.assertEqual(self.parser.state, CHUNK_PARSER_STATE_WAITING_FOR_DATA) + self.parser.parse(b'abc') + self.assertEqual(self.parser.chunk, b'') + self.assertEqual(self.parser.size, None) + self.assertEqual(self.parser.body, b'abc') + self.assertEqual(self.parser.state, CHUNK_PARSER_STATE_WAITING_FOR_SIZE) + self.parser.parse(b'\r\n') + self.assertEqual(self.parser.chunk, b'') + self.assertEqual(self.parser.size, None) + self.assertEqual(self.parser.body, b'abc') + self.assertEqual(self.parser.state, CHUNK_PARSER_STATE_WAITING_FOR_SIZE) + self.parser.parse(b'4\r\n') + self.assertEqual(self.parser.chunk, b'') + self.assertEqual(self.parser.size, 4) + self.assertEqual(self.parser.body, b'abc') + self.assertEqual(self.parser.state, CHUNK_PARSER_STATE_WAITING_FOR_DATA) + self.parser.parse(b'defg\r\n0') + self.assertEqual(self.parser.chunk, b'0') + self.assertEqual(self.parser.size, None) + self.assertEqual(self.parser.body, b'abcdefg') + self.assertEqual(self.parser.state, CHUNK_PARSER_STATE_WAITING_FOR_SIZE) + self.parser.parse(b'\r\n\r\n') + self.assertEqual(self.parser.chunk, b'') + self.assertEqual(self.parser.size, None) + self.assertEqual(self.parser.body, b'abcdefg') + self.assertEqual(self.parser.state, CHUNK_PARSER_STATE_COMPLETE) + class TestHttpParser(unittest.TestCase):