refactor h2 trailer support

This allows trailers without the initial Trailer announcement header. In HTTP/2, a stream ends with whichever frame carries the END_STREAM flag; in the case of trailers, that is a final HEADERS frame sent after all the DATA frames. We therefore do not need to check explicitly for the trailer announcement header, but can simply wait until the response message / stream has ended.
Thomas Kriechbaumer 2020-07-06 01:01:48 +02:00
parent 288ce65d73
commit 828ba0c2e7
2 changed files with 73 additions and 70 deletions
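For context on the mechanism described above: with the hyper-h2 library that this proxy layer is built on, trailers surface as a dedicated TrailersReceived event, i.e. a final HEADERS frame after the DATA frames, and whichever frame carries END_STREAM also yields a StreamEnded event. A minimal, illustrative receive-loop sketch (not part of this commit; the function name and prints are examples only):

import h2.connection
import h2.events

def handle_h2_bytes(conn: h2.connection.H2Connection, data: bytes) -> None:
    # Illustrative sketch: where trailers show up in the h2 event stream.
    for event in conn.receive_data(data):
        if isinstance(event, h2.events.ResponseReceived):
            print("response headers:", event.headers)
        elif isinstance(event, h2.events.DataReceived):
            print("body chunk:", len(event.data), "bytes")
            # hand flow-control credit back for the consumed DATA frame
            conn.acknowledge_received_data(event.flow_controlled_length, event.stream_id)
        elif isinstance(event, h2.events.TrailersReceived):
            # trailers: a final HEADERS frame after all DATA frames; no
            # "Trailer" announcement header is needed to receive this event
            print("trailers:", event.headers)
        elif isinstance(event, h2.events.StreamEnded):
            # emitted for the frame that carried the END_STREAM flag
            print("stream", event.stream_id, "ended")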


@@ -10,10 +10,12 @@ body.
from mitmproxy import http
from mitmproxy.net.http import Headers
def request(flow: http.HTTPFlow):
if flow.request.trailers:
print("HTTP Trailers detected! Request contains:", flow.request.trailers)
def response(flow: http.HTTPFlow):
if flow.response.trailers:
print("HTTP Trailers detected! Response contains:", flow.response.trailers)


@@ -180,22 +180,22 @@ class Http2Layer(base.Layer):
headers = mitmproxy.net.http.Headers([[k, v] for k, v in event.headers])
self.streams[eid] = Http2SingleStreamLayer(self, self.connections[self.client_conn], eid, headers)
self.streams[eid].timestamp_start = time.time()
self.streams[eid].no_body = (event.stream_ended is not None)
self.streams[eid].no_request_body = (event.stream_ended is not None)
if event.priority_updated is not None:
self.streams[eid].priority_exclusive = event.priority_updated.exclusive
self.streams[eid].priority_depends_on = event.priority_updated.depends_on
self.streams[eid].priority_weight = event.priority_updated.weight
self.streams[eid].handled_priority_event = event.priority_updated
self.streams[eid].start()
self.streams[eid].request_arrived.set()
self.streams[eid].request_message.arrived.set()
return True
def _handle_response_received(self, eid, event):
headers = mitmproxy.net.http.Headers([[k, v] for k, v in event.headers])
self.streams[eid].queued_data_length = 0
self.streams[eid].timestamp_start = time.time()
self.streams[eid].response_headers = headers
self.streams[eid].response_arrived.set()
self.streams[eid].response_message.headers = headers
self.streams[eid].response_message.arrived.set()
return True
def _handle_data_received(self, eid, event, source_conn):
@@ -220,7 +220,7 @@ class Http2Layer(base.Layer):
def _handle_stream_ended(self, eid):
self.streams[eid].timestamp_end = time.time()
self.streams[eid].data_finished.set()
self.streams[eid].stream_ended.set()
return True
def _handle_stream_reset(self, eid, event, is_server, other_conn):
@@ -236,9 +236,7 @@ class Http2Layer(base.Layer):
def _handle_trailers(self, eid, event, is_server, other_conn):
trailers = mitmproxy.net.http.Headers([[k, v] for k, v in event.headers])
# TODO: support request trailers as well!
self.streams[eid].response_trailers = trailers
self.streams[eid].response_trailers_arrived.set()
self.streams[eid].trailers = trailers
return True
def _handle_remote_settings_changed(self, event, other_conn):
@@ -285,8 +283,8 @@ class Http2Layer(base.Layer):
self.streams[event.pushed_stream_id].pushed = True
self.streams[event.pushed_stream_id].parent_stream_id = parent_eid
self.streams[event.pushed_stream_id].timestamp_end = time.time()
self.streams[event.pushed_stream_id].request_arrived.set()
self.streams[event.pushed_stream_id].request_data_finished.set()
self.streams[event.pushed_stream_id].request_message.arrived.set()
self.streams[event.pushed_stream_id].request_message.stream_ended.set()
self.streams[event.pushed_stream_id].start()
return True
@@ -400,6 +398,16 @@ def detect_zombie_stream(func): # pragma: no cover
class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThread):
class Message:
def __init__(self, headers=None):
self.headers: Optional[mitmproxy.net.http.Headers] = headers # headers are the first thing to be received on a new stream
self.data_queue: queue.Queue[bytes] = queue.Queue() # contains raw contents of DATA frames
self.queued_data_length = 0 # used to enforce mitmproxy's config.options.body_size_limit
self.trailers: Optional[mitmproxy.net.http.Headers] = None # trailers are received after stream_ended is set
self.arrived = threading.Event() # indicates the HEADERS+CONTINUATION frames have been received
self.stream_ended = threading.Event() # indicates that a frame with the END_STREAM flag has been received
def __init__(self, ctx, h2_connection, stream_id: int, request_headers: mitmproxy.net.http.Headers) -> None:
super().__init__(
ctx, name="Http2SingleStreamLayer-{}".format(stream_id)
@@ -408,28 +416,15 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
self.zombie: Optional[float] = None
self.client_stream_id: int = stream_id
self.server_stream_id: Optional[int] = None
self.request_headers = request_headers
self.response_headers: Optional[mitmproxy.net.http.Headers] = None
self.pushed = False
self.timestamp_start: Optional[float] = None
self.timestamp_end: Optional[float] = None
self.request_arrived = threading.Event()
self.request_data_queue: queue.Queue[bytes] = queue.Queue()
self.request_queued_data_length = 0
self.request_data_finished = threading.Event()
self.request_trailers_arrived = threading.Event()
self.request_trailers = None
self.request_message = self.Message(request_headers)
self.response_message = self.Message()
self.response_arrived = threading.Event()
self.response_data_queue: queue.Queue[bytes] = queue.Queue()
self.response_queued_data_length = 0
self.response_data_finished = threading.Event()
self.response_trailers_arrived = threading.Event()
self.response_trailers = None
self.no_body = False
self.no_request_body = False
self.priority_exclusive: bool
self.priority_depends_on: Optional[int] = None
@@ -439,12 +434,10 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
def kill(self):
if not self.zombie:
self.zombie = time.time()
self.request_data_finished.set()
self.request_arrived.set()
self.request_trailers_arrived.set()
self.response_arrived.set()
self.response_data_finished.set()
self.response_trailers_arrived.set()
self.request_message.stream_ended.set()
self.request_message.arrived.set()
self.response_message.arrived.set()
self.response_message.stream_ended.set()
def connect(self): # pragma: no cover
raise exceptions.Http2ProtocolException("HTTP2 layer should already have a connection.")
@@ -462,28 +455,44 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
@property
def data_queue(self):
if self.response_arrived.is_set():
return self.response_data_queue
if self.response_message.arrived.is_set():
return self.response_message.data_queue
else:
return self.request_data_queue
return self.request_message.data_queue
@property
def queued_data_length(self):
if self.response_arrived.is_set():
return self.response_queued_data_length
if self.response_message.arrived.is_set():
return self.response_message.queued_data_length
else:
return self.request_queued_data_length
return self.request_message.queued_data_length
@queued_data_length.setter
def queued_data_length(self, v):
self.request_queued_data_length = v
self.request_message.queued_data_length = v
@property
def data_finished(self):
if self.response_arrived.is_set():
return self.response_data_finished
def stream_ended(self):
# This indicates that all message headers, the full message body, and all trailers have been received
# https://tools.ietf.org/html/rfc7540#section-8.1
if self.response_message.arrived.is_set():
return self.response_message.stream_ended
else:
return self.request_data_finished
return self.request_message.stream_ended
@property
def trailers(self):
if self.response_message.arrived.is_set():
return self.response_message.trailers
else:
return self.request_message.trailers
@trailers.setter
def trailers(self, v):
if self.response_message.arrived.is_set():
self.response_message.trailers = v
else:
self.request_message.trailers = v
def raise_zombie(self, pre_command=None): # pragma: no cover
connection_closed = self.h2_connection.state_machine.state == h2.connection.ConnectionState.CLOSED
@@ -494,13 +503,13 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
@detect_zombie_stream
def read_request_headers(self, flow):
self.request_arrived.wait()
self.request_message.arrived.wait()
self.raise_zombie()
if self.pushed:
flow.metadata['h2-pushed-stream'] = True
first_line_format, method, scheme, host, port, path = http2.parse_headers(self.request_headers)
first_line_format, method, scheme, host, port, path = http2.parse_headers(self.request_message.headers)
return http.HTTPRequest(
first_line_format,
method,
@@ -509,7 +518,7 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
port,
path,
b"HTTP/2.0",
self.request_headers,
self.request_message.headers,
None,
timestamp_start=self.timestamp_start,
timestamp_end=self.timestamp_end,
@@ -518,27 +527,23 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
@detect_zombie_stream
def read_request_body(self, request):
if not request.stream:
self.request_data_finished.wait()
self.request_message.stream_ended.wait()
while True:
try:
yield self.request_data_queue.get(timeout=0.1)
yield self.request_message.data_queue.get(timeout=0.1)
except queue.Empty: # pragma: no cover
pass
if self.request_data_finished.is_set():
if self.request_message.stream_ended.is_set():
self.raise_zombie()
while self.request_data_queue.qsize() > 0:
yield self.request_data_queue.get()
while self.request_message.data_queue.qsize() > 0:
yield self.request_message.data_queue.get()
break
self.raise_zombie()
@detect_zombie_stream
def read_request_trailers(self, request):
if "trailer" in request.headers:
self.request_trailers_arrived.wait()
self.raise_zombie()
return self.request_trailers
return None
return self.request_message.trailers
@detect_zombie_stream
def send_request_headers(self, request):
@@ -589,7 +594,7 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
self.raise_zombie,
self.server_stream_id,
headers,
end_stream=self.no_body,
end_stream=self.no_request_body,
priority_exclusive=priority_exclusive,
priority_depends_on=priority_depends_on,
priority_weight=priority_weight,
@@ -606,7 +611,7 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
# nothing to do here
return
if not self.no_body:
if not self.no_request_body:
self.connections[self.server_conn].safe_send_body(
self.raise_zombie,
self.server_stream_id,
@@ -625,12 +630,12 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
@detect_zombie_stream
def read_response_headers(self):
self.response_arrived.wait()
self.response_message.arrived.wait()
self.raise_zombie()
status_code = int(self.response_headers.get(':status', 502))
headers = self.response_headers.copy()
status_code = int(self.response_message.headers.get(':status', 502))
headers = self.response_message.headers.copy()
headers.pop(":status", None)
return http.HTTPResponse(
@@ -647,23 +652,19 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
def read_response_body(self, request, response):
while True:
try:
yield self.response_data_queue.get(timeout=0.1)
yield self.response_message.data_queue.get(timeout=0.1)
except queue.Empty: # pragma: no cover
pass
if self.response_data_finished.is_set():
if self.response_message.stream_ended.is_set():
self.raise_zombie()
while self.response_data_queue.qsize() > 0:
yield self.response_data_queue.get()
while self.response_message.data_queue.qsize() > 0:
yield self.response_message.data_queue.get()
break
self.raise_zombie()
@detect_zombie_stream
def read_response_trailers(self, request, response):
if "trailer" in response.headers:
self.response_trailers_arrived.wait()
self.raise_zombie()
return self.response_trailers
return None
return self.response_message.trailers
@detect_zombie_stream
def send_response_headers(self, response):