From 39589404209a9980c0a07137f367f70c103e3113 Mon Sep 17 00:00:00 2001
From: Aldo Cortesi
Date: Sat, 5 Nov 2016 10:59:56 +1300
Subject: [PATCH] Test failure during 100-continue

Also:

- Remove duplicate and unused code
- Tighten scope of HttpReadDisconnect handler - we only want to ignore this
  for the initial read, not for the entire block that includes things like
  the expect handling.
---
 mitmproxy/proxy/modes/socks_proxy.py |  3 --
 mitmproxy/proxy/protocol/http.py     | 21 ++++++--------
 test/mitmproxy/test_eventsequence.py | 42 ++++++++++++++++++++--------
 3 files changed, 40 insertions(+), 26 deletions(-)

diff --git a/mitmproxy/proxy/modes/socks_proxy.py b/mitmproxy/proxy/modes/socks_proxy.py
index 042580373..3121b731a 100644
--- a/mitmproxy/proxy/modes/socks_proxy.py
+++ b/mitmproxy/proxy/modes/socks_proxy.py
@@ -5,9 +5,6 @@ from mitmproxy.net import socks
 
 
 class Socks5Proxy(protocol.Layer, protocol.ServerConnectionMixin):
-    def __init__(self, ctx):
-        super().__init__(ctx)
-
     def __call__(self):
         try:
             # Parse Client Greeting
diff --git a/mitmproxy/proxy/protocol/http.py b/mitmproxy/proxy/protocol/http.py
index 4caaf1e30..9fe83ff61 100644
--- a/mitmproxy/proxy/protocol/http.py
+++ b/mitmproxy/proxy/protocol/http.py
@@ -160,27 +160,24 @@ class HttpLayer(base.Layer):
 
     def _process_flow(self, f):
         try:
-            request = self.read_request_headers(f)
+            try:
+                request = self.read_request_headers(f)
+            except exceptions.HttpReadDisconnect:
+                # don't throw an error for disconnects that happen before/between requests.
+                return False
+
             f.request = request
             self.channel.ask("requestheaders", f)
-            request.data.content = b"".join(self.read_request_body(request))
-            request.timestamp_end = time.time()
 
             if request.headers.get("expect", "").lower() == "100-continue":
                 # TODO: We may have to use send_response_headers for HTTP2 here.
                 self.send_response(http.expect_continue_response)
                 request.headers.pop("expect")
-                request.content = b"".join(self.read_request_body(request))
-                request.timestamp_end = time.time()
+
+            request.data.content = b"".join(self.read_request_body(request))
+            request.timestamp_end = time.time()
 
             validate_request_form(self.mode, request)
-
-            if self.mode == "regular" and request.first_line_format == "absolute":
-                request.first_line_format = "relative"
-
-        except exceptions.HttpReadDisconnect:
-            # don't throw an error for disconnects that happen before/between requests.
-            return False
         except exceptions.HttpException as e:
             # We optimistically guess there might be an HTTP client on the
             # other end
diff --git a/test/mitmproxy/test_eventsequence.py b/test/mitmproxy/test_eventsequence.py
index 7fdbce1b9..31c57e82d 100644
--- a/test/mitmproxy/test_eventsequence.py
+++ b/test/mitmproxy/test_eventsequence.py
@@ -3,15 +3,17 @@ import contextlib
 from . import tservers
 
 
-class EAddon:
-    def __init__(self, handlers):
+class Eventer:
+    def __init__(self, **handlers):
         self.failure = None
+        self.called = []
         self.handlers = handlers
-        for i in events.Events:
+        for i in events.Events - {"tick"}:
             def mkprox():
                 evt = i
 
                 def prox(*args, **kwargs):
+                    self.called.append(evt)
                     if evt in self.handlers:
                         try:
                             handlers[evt](*args, **kwargs)
@@ -26,23 +28,41 @@ class EAddon:
 
 class SequenceTester:
     @contextlib.contextmanager
-    def events(self, **kwargs):
-        m = EAddon(kwargs)
-        self.master.addons.add(m)
+    def addon(self, addon):
+        self.master.addons.add(addon)
         yield
-        self.master.addons.remove(m)
-        if m.failure:
-            raise m.failure
+        self.master.addons.remove(addon)
+        if addon.failure:
+            raise addon.failure
 
 
 class TestBasic(tservers.HTTPProxyTest, SequenceTester):
     def test_requestheaders(self):
 
-        def req(f):
+        def hdrs(f):
             assert f.request.headers
             assert not f.request.content
 
-        with self.events(requestheaders=req):
+        def req(f):
+            assert f.request.headers
+            assert f.request.content
+
+        with self.addon(Eventer(requestheaders=hdrs, request=req)):
             p = self.pathoc()
             with p.connect():
                 assert p.request("get:'%s/p/200':b@10" % self.server.urlbase).status_code == 200
+
+    def test_100_continue_fail(self):
+        e = Eventer()
+        with self.addon(e):
+            p = self.pathoc()
+            with p.connect():
+                p.request(
+                    """
+                    get:'%s/p/200'
+                    h'expect'='100-continue'
+                    h'content-length'='1000'
+                    da
+                    """ % self.server.urlbase
+                )
+        assert e.called[-1] == "requestheaders"