Do not store raw response bytes for CONNECT requests.

Fixes #67 and addresses #66 too.
This commit is contained in:
Abhinav Singh 2019-09-13 11:03:26 -07:00
parent 3e92faba32
commit 80c73a4798
2 changed files with 31 additions and 15 deletions

View File

@ -18,7 +18,6 @@ import logging
import multiprocessing
import os
import queue
import select
import socket
import sys
import threading
@ -26,6 +25,8 @@ from collections import namedtuple
from typing import Dict, List, Tuple
from urllib import parse as urlparse
import select
if os.name != 'nt':
import resource
@ -382,10 +383,15 @@ class HttpParser:
self.type: HttpParser.types = parser_type
self.state: HttpParser.states = HttpParser.states.INITIALIZED
# Raw bytes as passed to the parse(raw) method, and their total size
self.bytes: bytes = b''
self.total_size: int = 0
# Buffer to hold unprocessed bytes
self.buffer: bytes = b''
self.headers: Dict[bytes, Tuple[bytes, bytes]] = dict()
# Could this simply default to b'' so that the attribute can be typed as bytes?
self.body = None
@ -415,12 +421,14 @@ class HttpParser:
raise Exception('Invalid request\n%s' % self.bytes)
def is_chunked_encoded_response(self):
return self.type == HttpParser.types.RESPONSE_PARSER and \
b'transfer-encoding' in self.headers and \
self.headers[b'transfer-encoding'][1].lower() == b'chunked'
return self.type == HttpParser.types.RESPONSE_PARSER and b'transfer-encoding' in self.headers and \
self.headers[b'transfer-encoding'][1].lower() == b'chunked'
def parse(self, raw):
self.bytes += raw
self.total_size += len(raw)
# Prepend past buffer
raw = self.buffer + raw
self.buffer = b''
@ -829,9 +837,7 @@ class HttpProxyPlugin(HttpProtocolBasePlugin):
if not self.request.method == b'CONNECT':
self.response.parse(raw)
else:
# Only purpose of increasing memory footprint is to print response length in access log
# Not worth it? Optimize to only persist lengths?
self.response.bytes += raw
self.response.total_size += len(raw)
# queue raw data for client
self.client.queue(raw)
@ -890,12 +896,12 @@ class HttpProxyPlugin(HttpProtocolBasePlugin):
logger.info(
'%s:%s - %s %s:%s - %s bytes' % (self.client.addr[0], self.client.addr[1],
text_(self.request.method), text_(host),
text_(port), len(self.response.bytes)))
text_(port), self.response.total_size))
elif self.request.method:
logger.info('%s:%s - %s %s:%s%s - %s %s - %s bytes' % (
self.client.addr[0], self.client.addr[1], text_(self.request.method), text_(host), port,
text_(self.request.build_url()), text_(self.response.code), text_(self.response.reason),
len(self.response.bytes)))
self.response.total_size))
def authenticate(self, headers):
if self.config.auth_code:
@ -1242,7 +1248,7 @@ def main(args) -> None:
pac_file=args.pac_file,
pac_file_url_path=args.pac_file_url_path,
disable_headers=[header.lower() for header in args.disable_headers.split(COMMA) if
header.strip() is not ''])
header.strip() != ''])
if config.pac_file is not None:
args.enable_web_server = True

View File

@ -287,7 +287,9 @@ class TestHttpParser(unittest.TestCase):
b'Host: %s',
proxy.CRLF
])
self.parser.parse(raw % (b'https://example.com/path/dir/?a=b&c=d#p=q', b'example.com'))
pkt = raw % (b'https://example.com/path/dir/?a=b&c=d#p=q', b'example.com')
self.parser.parse(pkt)
self.assertEqual(self.parser.total_size, len(pkt))
self.assertEqual(self.parser.build_url(), b'/path/dir/?a=b&c=d#p=q')
self.assertEqual(self.parser.method, b'GET')
self.assertEqual(self.parser.url.hostname, b'example.com')
@ -303,33 +305,41 @@ class TestHttpParser(unittest.TestCase):
self.assertEqual(self.parser.build_url(), b'/None')
def test_line_rcvd_to_rcving_headers_state_change(self):
self.parser.parse(b'GET http://localhost HTTP/1.1')
pkt = b'GET http://localhost HTTP/1.1'
self.parser.parse(pkt)
self.assertEqual(self.parser.total_size, len(pkt))
self.assert_state_change_with_crlf(proxy.HttpParser.states.INITIALIZED,
proxy.HttpParser.states.LINE_RCVD,
proxy.HttpParser.states.RCVING_HEADERS)
def test_get_partial_parse1(self):
self.parser.parse(proxy.CRLF.join([
pkt = proxy.CRLF.join([
b'GET http://localhost:8080 HTTP/1.1'
]))
])
self.parser.parse(pkt)
self.assertEqual(self.parser.total_size, len(pkt))
self.assertEqual(self.parser.method, None)
self.assertEqual(self.parser.url, None)
self.assertEqual(self.parser.version, None)
self.assertEqual(self.parser.state, proxy.HttpParser.states.INITIALIZED)
self.parser.parse(proxy.CRLF)
self.assertEqual(self.parser.total_size, len(pkt) + len(proxy.CRLF))
self.assertEqual(self.parser.method, b'GET')
self.assertEqual(self.parser.url.hostname, b'localhost')
self.assertEqual(self.parser.url.port, 8080)
self.assertEqual(self.parser.version, b'HTTP/1.1')
self.assertEqual(self.parser.state, proxy.HttpParser.states.LINE_RCVD)
self.parser.parse(b'Host: localhost:8080')
host_hdr = b'Host: localhost:8080'
self.parser.parse(host_hdr)
self.assertEqual(self.parser.total_size, len(pkt) + len(proxy.CRLF) + len(host_hdr))
self.assertDictEqual(self.parser.headers, dict())
self.assertEqual(self.parser.buffer, b'Host: localhost:8080')
self.assertEqual(self.parser.state, proxy.HttpParser.states.LINE_RCVD)
self.parser.parse(proxy.CRLF * 2)
self.assertEqual(self.parser.total_size, len(pkt) + (3 * len(proxy.CRLF)) + len(host_hdr))
self.assertDictContainsSubset({b'host': (b'Host', b'localhost:8080')}, self.parser.headers)
self.assertEqual(self.parser.state, proxy.HttpParser.states.COMPLETE)