Test failure during 100-continue

Also:

- Remove duplicate and unused code
- Tighten scope of HttpReadDisconnect handler - we only want to ignore this for
the initial read, not for the entire block that includes things like the expect
handling.
This commit is contained in:
Aldo Cortesi 2016-11-05 10:59:56 +13:00
parent 82ac7d05a6
commit 3958940420
3 changed files with 40 additions and 26 deletions

View File

@ -5,9 +5,6 @@ from mitmproxy.net import socks
class Socks5Proxy(protocol.Layer, protocol.ServerConnectionMixin): class Socks5Proxy(protocol.Layer, protocol.ServerConnectionMixin):
def __init__(self, ctx):
super().__init__(ctx)
def __call__(self): def __call__(self):
try: try:
# Parse Client Greeting # Parse Client Greeting

View File

@ -160,27 +160,24 @@ class HttpLayer(base.Layer):
def _process_flow(self, f): def _process_flow(self, f):
try: try:
request = self.read_request_headers(f) try:
request = self.read_request_headers(f)
except exceptions.HttpReadDisconnect:
# don't throw an error for disconnects that happen before/between requests.
return False
f.request = request f.request = request
self.channel.ask("requestheaders", f) self.channel.ask("requestheaders", f)
request.data.content = b"".join(self.read_request_body(request))
request.timestamp_end = time.time()
if request.headers.get("expect", "").lower() == "100-continue": if request.headers.get("expect", "").lower() == "100-continue":
# TODO: We may have to use send_response_headers for HTTP2 here. # TODO: We may have to use send_response_headers for HTTP2 here.
self.send_response(http.expect_continue_response) self.send_response(http.expect_continue_response)
request.headers.pop("expect") request.headers.pop("expect")
request.content = b"".join(self.read_request_body(request))
request.timestamp_end = time.time() request.data.content = b"".join(self.read_request_body(request))
request.timestamp_end = time.time()
validate_request_form(self.mode, request) validate_request_form(self.mode, request)
if self.mode == "regular" and request.first_line_format == "absolute":
request.first_line_format = "relative"
except exceptions.HttpReadDisconnect:
# don't throw an error for disconnects that happen before/between requests.
return False
except exceptions.HttpException as e: except exceptions.HttpException as e:
# We optimistically guess there might be an HTTP client on the # We optimistically guess there might be an HTTP client on the
# other end # other end

View File

@ -3,15 +3,17 @@ import contextlib
from . import tservers from . import tservers
class EAddon: class Eventer:
def __init__(self, handlers): def __init__(self, **handlers):
self.failure = None self.failure = None
self.called = []
self.handlers = handlers self.handlers = handlers
for i in events.Events: for i in events.Events - {"tick"}:
def mkprox(): def mkprox():
evt = i evt = i
def prox(*args, **kwargs): def prox(*args, **kwargs):
self.called.append(evt)
if evt in self.handlers: if evt in self.handlers:
try: try:
handlers[evt](*args, **kwargs) handlers[evt](*args, **kwargs)
@ -26,23 +28,41 @@ class EAddon:
class SequenceTester: class SequenceTester:
@contextlib.contextmanager @contextlib.contextmanager
def events(self, **kwargs): def addon(self, addon):
m = EAddon(kwargs) self.master.addons.add(addon)
self.master.addons.add(m)
yield yield
self.master.addons.remove(m) self.master.addons.remove(addon)
if m.failure: if addon.failure:
raise m.failure raise addon.failure
class TestBasic(tservers.HTTPProxyTest, SequenceTester): class TestBasic(tservers.HTTPProxyTest, SequenceTester):
def test_requestheaders(self): def test_requestheaders(self):
def req(f): def hdrs(f):
assert f.request.headers assert f.request.headers
assert not f.request.content assert not f.request.content
with self.events(requestheaders=req): def req(f):
assert f.request.headers
assert f.request.content
with self.addon(Eventer(requestheaders=hdrs, request=req)):
p = self.pathoc() p = self.pathoc()
with p.connect(): with p.connect():
assert p.request("get:'%s/p/200':b@10" % self.server.urlbase).status_code == 200 assert p.request("get:'%s/p/200':b@10" % self.server.urlbase).status_code == 200
def test_100_continue_fail(self):
e = Eventer()
with self.addon(e):
p = self.pathoc()
with p.connect():
p.request(
"""
get:'%s/p/200'
h'expect'='100-continue'
h'content-length'='1000'
da
""" % self.server.urlbase
)
assert e.called[-1] == "requestheaders"