Add the http_connect event for HTTP CONNECT requests
This commit is contained in:
parent
bc01a146b0
commit
38f8d9e541
|
@ -13,6 +13,7 @@ Events = frozenset([
|
|||
"tcp_error",
|
||||
"tcp_end",
|
||||
|
||||
"http_connect",
|
||||
"request",
|
||||
"requestheaders",
|
||||
"response",
|
||||
|
|
|
@ -255,6 +255,10 @@ class Master:
|
|||
def next_layer(self, top_layer):
|
||||
pass
|
||||
|
||||
@controller.handler
|
||||
def http_connect(self, f):
|
||||
pass
|
||||
|
||||
@controller.handler
|
||||
def error(self, f):
|
||||
pass
|
||||
|
|
|
@ -176,6 +176,30 @@ class HttpLayer(base.Layer):
|
|||
# don't throw an error for disconnects that happen before/between requests.
|
||||
return False
|
||||
|
||||
# Regular Proxy Mode: Handle CONNECT
|
||||
if self.mode is HTTPMode.regular and request.first_line_format == "authority":
|
||||
self.connect_request = True
|
||||
# The standards are silent on what we should do with a CONNECT
|
||||
# request body, so although it's not common, it's allowed.
|
||||
request.data.content = b"".join(self.read_request_body(request))
|
||||
request.timestamp_end = time.time()
|
||||
|
||||
self.channel.ask("http_connect", f)
|
||||
|
||||
try:
|
||||
self.set_server((request.host, request.port))
|
||||
except (exceptions.ProtocolException, exceptions.NetlibException) as e:
|
||||
# HTTPS tasting means that ordinary errors like resolution and
|
||||
# connection errors can happen here.
|
||||
self.send_error_response(502, repr(e))
|
||||
f.error = flow.Error(str(e))
|
||||
self.channel.ask("error", f)
|
||||
return False
|
||||
self.send_response(http.make_connect_response(request.data.http_version))
|
||||
layer = self.ctx.next_layer(self)
|
||||
layer()
|
||||
return False
|
||||
|
||||
f.request = request
|
||||
self.channel.ask("requestheaders", f)
|
||||
|
||||
|
@ -207,23 +231,6 @@ class HttpLayer(base.Layer):
|
|||
|
||||
f.request = request
|
||||
|
||||
try:
|
||||
# Regular Proxy Mode: Handle CONNECT
|
||||
if self.mode is HTTPMode.regular and request.first_line_format == "authority":
|
||||
self.connect_request = True
|
||||
self.set_server((request.host, request.port))
|
||||
self.send_response(http.make_connect_response(request.data.http_version))
|
||||
layer = self.ctx.next_layer(self)
|
||||
layer()
|
||||
return False
|
||||
except (exceptions.ProtocolException, exceptions.NetlibException) as e:
|
||||
# HTTPS tasting means that ordinary errors like resolution and
|
||||
# connection errors can happen here.
|
||||
self.send_error_response(502, repr(e))
|
||||
f.error = flow.Error(str(e))
|
||||
self.channel.ask("error", f)
|
||||
return False
|
||||
|
||||
# update host header in reverse proxy mode
|
||||
if self.config.options.mode == "reverse":
|
||||
f.request.headers["Host"] = self.config.upstream_server.address.host
|
||||
|
|
|
@ -37,6 +37,8 @@ class SequenceTester:
|
|||
|
||||
|
||||
class TestBasic(tservers.HTTPProxyTest, SequenceTester):
|
||||
ssl = True
|
||||
|
||||
def test_requestheaders(self):
|
||||
|
||||
def hdrs(f):
|
||||
|
@ -50,7 +52,7 @@ class TestBasic(tservers.HTTPProxyTest, SequenceTester):
|
|||
with self.addon(Eventer(requestheaders=hdrs, request=req)):
|
||||
p = self.pathoc()
|
||||
with p.connect():
|
||||
assert p.request("get:'%s/p/200':b@10" % self.server.urlbase).status_code == 200
|
||||
assert p.request("get:'/p/200':b@10").status_code == 200
|
||||
|
||||
def test_100_continue_fail(self):
|
||||
e = Eventer()
|
||||
|
@ -59,10 +61,20 @@ class TestBasic(tservers.HTTPProxyTest, SequenceTester):
|
|||
with p.connect():
|
||||
p.request(
|
||||
"""
|
||||
get:'%s/p/200'
|
||||
get:'/p/200'
|
||||
h'expect'='100-continue'
|
||||
h'content-length'='1000'
|
||||
da
|
||||
""" % self.server.urlbase
|
||||
"""
|
||||
)
|
||||
assert e.called[-1] == "requestheaders"
|
||||
|
||||
def test_connect(self):
|
||||
e = Eventer()
|
||||
with self.addon(e):
|
||||
p = self.pathoc()
|
||||
with p.connect():
|
||||
p.request("get:'/p/200:b@1'")
|
||||
assert "http_connect" in e.called
|
||||
assert e.called.count("requestheaders") == 1
|
||||
assert e.called.count("request") == 1
|
||||
|
|
Loading…
Reference in New Issue