From 828ba0c2e79d8c54806a1c9eefb6007abc8aabc0 Mon Sep 17 00:00:00 2001
From: Thomas Kriechbaumer
Date: Mon, 6 Jul 2020 01:01:48 +0200
Subject: [PATCH] refactor h2 trailer support

This allows trailers without the initial "Trailer" header announcement.

In HTTP/2 the stream ends with any frame containing END_STREAM. In the
case of trailers, that is a final HEADERS frame sent after all the DATA
frames. Therefore we do not need to explicitly check for the trailer
announcement header, but can simply wait until the response message /
stream has ended.
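
For illustration only (not part of the change below): a minimal sketch of
the hyper-h2 event sequence this refactoring relies on, assuming plain h2
usage outside of mitmproxy; the function name and return shape here are
made up for the example.

    import h2.connection
    import h2.events


    def read_response_events(conn: h2.connection.H2Connection, data: bytes):
        """Collect headers, body, and optional trailers from one socket read."""
        headers, body, trailers, ended = None, b"", None, False
        for event in conn.receive_data(data):
            if isinstance(event, h2.events.ResponseReceived):
                headers = event.headers      # initial HEADERS (+ CONTINUATION)
            elif isinstance(event, h2.events.DataReceived):
                body += event.data           # DATA frames
                conn.acknowledge_received_data(
                    event.flow_controlled_length, event.stream_id)
            elif isinstance(event, h2.events.TrailersReceived):
                trailers = event.headers     # trailing HEADERS frame
            elif isinstance(event, h2.events.StreamEnded):
                ended = True                 # END_STREAM seen: message complete
        return headers, body, trailers, ended

No "Trailer" announcement header needs to be inspected: if trailers exist,
they simply arrive as the last HEADERS frame of the stream, immediately
before or together with END_STREAM.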
---
 examples/addons/http-trailers.py  |   2 +
 mitmproxy/proxy/protocol/http2.py | 141 +++++++++++++++----------------
 2 files changed, 73 insertions(+), 70 deletions(-)

diff --git a/examples/addons/http-trailers.py b/examples/addons/http-trailers.py
index 77b9d1a40..d85965c13 100644
--- a/examples/addons/http-trailers.py
+++ b/examples/addons/http-trailers.py
@@ -10,10 +10,12 @@ body.
 from mitmproxy import http
 from mitmproxy.net.http import Headers
 
+
 def request(flow: http.HTTPFlow):
     if flow.request.trailers:
         print("HTTP Trailers detected! Request contains:", flow.request.trailers)
 
+
 def response(flow: http.HTTPFlow):
     if flow.response.trailers:
         print("HTTP Trailers detected! Response contains:", flow.response.trailers)
diff --git a/mitmproxy/proxy/protocol/http2.py b/mitmproxy/proxy/protocol/http2.py
index 602946f65..5da91ac24 100644
--- a/mitmproxy/proxy/protocol/http2.py
+++ b/mitmproxy/proxy/protocol/http2.py
@@ -180,22 +180,22 @@ class Http2Layer(base.Layer):
         headers = mitmproxy.net.http.Headers([[k, v] for k, v in event.headers])
         self.streams[eid] = Http2SingleStreamLayer(self, self.connections[self.client_conn], eid, headers)
         self.streams[eid].timestamp_start = time.time()
-        self.streams[eid].no_body = (event.stream_ended is not None)
+        self.streams[eid].no_request_body = (event.stream_ended is not None)
         if event.priority_updated is not None:
             self.streams[eid].priority_exclusive = event.priority_updated.exclusive
             self.streams[eid].priority_depends_on = event.priority_updated.depends_on
             self.streams[eid].priority_weight = event.priority_updated.weight
             self.streams[eid].handled_priority_event = event.priority_updated
         self.streams[eid].start()
-        self.streams[eid].request_arrived.set()
+        self.streams[eid].request_message.arrived.set()
         return True
 
     def _handle_response_received(self, eid, event):
         headers = mitmproxy.net.http.Headers([[k, v] for k, v in event.headers])
         self.streams[eid].queued_data_length = 0
         self.streams[eid].timestamp_start = time.time()
-        self.streams[eid].response_headers = headers
-        self.streams[eid].response_arrived.set()
+        self.streams[eid].response_message.headers = headers
+        self.streams[eid].response_message.arrived.set()
         return True
 
     def _handle_data_received(self, eid, event, source_conn):
@@ -220,7 +220,7 @@ class Http2Layer(base.Layer):
 
     def _handle_stream_ended(self, eid):
         self.streams[eid].timestamp_end = time.time()
-        self.streams[eid].data_finished.set()
+        self.streams[eid].stream_ended.set()
         return True
 
     def _handle_stream_reset(self, eid, event, is_server, other_conn):
@@ -236,9 +236,7 @@ class Http2Layer(base.Layer):
 
     def _handle_trailers(self, eid, event, is_server, other_conn):
         trailers = mitmproxy.net.http.Headers([[k, v] for k, v in event.headers])
-        # TODO: support request trailers as well!
-        self.streams[eid].response_trailers = trailers
-        self.streams[eid].response_trailers_arrived.set()
+        self.streams[eid].trailers = trailers
         return True
 
     def _handle_remote_settings_changed(self, event, other_conn):
@@ -285,8 +283,8 @@ class Http2Layer(base.Layer):
         self.streams[event.pushed_stream_id].pushed = True
         self.streams[event.pushed_stream_id].parent_stream_id = parent_eid
         self.streams[event.pushed_stream_id].timestamp_end = time.time()
-        self.streams[event.pushed_stream_id].request_arrived.set()
-        self.streams[event.pushed_stream_id].request_data_finished.set()
+        self.streams[event.pushed_stream_id].request_message.arrived.set()
+        self.streams[event.pushed_stream_id].request_message.stream_ended.set()
         self.streams[event.pushed_stream_id].start()
         return True
 
@@ -400,6 +398,16 @@ def detect_zombie_stream(func):  # pragma: no cover
 
 class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThread):
 
+    class Message:
+        def __init__(self, headers=None):
+            self.headers: Optional[mitmproxy.net.http.Headers] = headers  # headers are the first thing to be received on a new stream
+            self.data_queue: queue.Queue[bytes] = queue.Queue()  # contains raw contents of DATA frames
+            self.queued_data_length = 0  # used to enforce mitmproxy's config.options.body_size_limit
+            self.trailers: Optional[mitmproxy.net.http.Headers] = None  # trailers are received after stream_ended is set
+
+            self.arrived = threading.Event()  # indicates the HEADERS+CONTINUATION frames have been received
+            self.stream_ended = threading.Event()  # indicates that a frame with the END_STREAM flag has been received
+
     def __init__(self, ctx, h2_connection, stream_id: int, request_headers: mitmproxy.net.http.Headers) -> None:
         super().__init__(
             ctx, name="Http2SingleStreamLayer-{}".format(stream_id)
@@ -408,28 +416,15 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
         self.zombie: Optional[float] = None
         self.client_stream_id: int = stream_id
         self.server_stream_id: Optional[int] = None
-        self.request_headers = request_headers
-        self.response_headers: Optional[mitmproxy.net.http.Headers] = None
         self.pushed = False
 
         self.timestamp_start: Optional[float] = None
         self.timestamp_end: Optional[float] = None
 
-        self.request_arrived = threading.Event()
-        self.request_data_queue: queue.Queue[bytes] = queue.Queue()
-        self.request_queued_data_length = 0
-        self.request_data_finished = threading.Event()
-        self.request_trailers_arrived = threading.Event()
-        self.request_trailers = None
+        self.request_message = self.Message(request_headers)
+        self.response_message = self.Message()
 
-        self.response_arrived = threading.Event()
-        self.response_data_queue: queue.Queue[bytes] = queue.Queue()
-        self.response_queued_data_length = 0
-        self.response_data_finished = threading.Event()
-        self.response_trailers_arrived = threading.Event()
-        self.response_trailers = None
-
-        self.no_body = False
+        self.no_request_body = False
 
         self.priority_exclusive: bool
         self.priority_depends_on: Optional[int] = None
@@ -439,12 +434,10 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
     def kill(self):
         if not self.zombie:
             self.zombie = time.time()
-            self.request_data_finished.set()
-            self.request_arrived.set()
-            self.request_trailers_arrived.set()
-            self.response_arrived.set()
-            self.response_data_finished.set()
-            self.response_trailers_arrived.set()
+            self.request_message.stream_ended.set()
+            self.request_message.arrived.set()
+            self.response_message.arrived.set()
+            self.response_message.stream_ended.set()
 
     def connect(self):  # pragma: no cover
         raise exceptions.Http2ProtocolException("HTTP2 layer should already have a connection.")
@@ -462,28 +455,44 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
 
     @property
     def data_queue(self):
-        if self.response_arrived.is_set():
-            return self.response_data_queue
+        if self.response_message.arrived.is_set():
+            return self.response_message.data_queue
         else:
-            return self.request_data_queue
+            return self.request_message.data_queue
 
     @property
     def queued_data_length(self):
-        if self.response_arrived.is_set():
-            return self.response_queued_data_length
+        if self.response_message.arrived.is_set():
+            return self.response_message.queued_data_length
         else:
-            return self.request_queued_data_length
+            return self.request_message.queued_data_length
 
     @queued_data_length.setter
     def queued_data_length(self, v):
-        self.request_queued_data_length = v
+        self.request_message.queued_data_length = v
 
     @property
-    def data_finished(self):
-        if self.response_arrived.is_set():
-            return self.response_data_finished
+    def stream_ended(self):
+        # This indicates that all message headers, the full message body, and all trailers have been received
+        # https://tools.ietf.org/html/rfc7540#section-8.1
+        if self.response_message.arrived.is_set():
+            return self.response_message.stream_ended
         else:
-            return self.request_data_finished
+            return self.request_message.stream_ended
+
+    @property
+    def trailers(self):
+        if self.response_message.arrived.is_set():
+            return self.response_message.trailers
+        else:
+            return self.request_message.trailers
+
+    @trailers.setter
+    def trailers(self, v):
+        if self.response_message.arrived.is_set():
+            self.response_message.trailers = v
+        else:
+            self.request_message.trailers = v
 
     def raise_zombie(self, pre_command=None):  # pragma: no cover
         connection_closed = self.h2_connection.state_machine.state == h2.connection.ConnectionState.CLOSED
@@ -494,13 +503,13 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
 
     @detect_zombie_stream
     def read_request_headers(self, flow):
-        self.request_arrived.wait()
+        self.request_message.arrived.wait()
         self.raise_zombie()
 
         if self.pushed:
             flow.metadata['h2-pushed-stream'] = True
 
-        first_line_format, method, scheme, host, port, path = http2.parse_headers(self.request_headers)
+        first_line_format, method, scheme, host, port, path = http2.parse_headers(self.request_message.headers)
         return http.HTTPRequest(
             first_line_format,
             method,
@@ -509,7 +518,7 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
             port,
             path,
             b"HTTP/2.0",
-            self.request_headers,
+            self.request_message.headers,
             None,
             timestamp_start=self.timestamp_start,
             timestamp_end=self.timestamp_end,
@@ -518,27 +527,23 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
     @detect_zombie_stream
     def read_request_body(self, request):
         if not request.stream:
-            self.request_data_finished.wait()
+            self.request_message.stream_ended.wait()
 
         while True:
             try:
-                yield self.request_data_queue.get(timeout=0.1)
+                yield self.request_message.data_queue.get(timeout=0.1)
             except queue.Empty:  # pragma: no cover
                 pass
-            if self.request_data_finished.is_set():
+            if self.request_message.stream_ended.is_set():
                 self.raise_zombie()
-                while self.request_data_queue.qsize() > 0:
-                    yield self.request_data_queue.get()
+                while self.request_message.data_queue.qsize() > 0:
+                    yield self.request_message.data_queue.get()
                 break
             self.raise_zombie()
 
     @detect_zombie_stream
     def read_request_trailers(self, request):
-        if "trailer" in request.headers:
-            self.request_trailers_arrived.wait()
-            self.raise_zombie()
-            return self.request_trailers
-        return None
+        return self.request_message.trailers
 
     @detect_zombie_stream
     def send_request_headers(self, request):
@@ -589,7 +594,7 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
             self.raise_zombie,
             self.server_stream_id,
             headers,
-            end_stream=self.no_body,
+            end_stream=self.no_request_body,
             priority_exclusive=priority_exclusive,
             priority_depends_on=priority_depends_on,
             priority_weight=priority_weight,
@@ -606,7 +611,7 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
             # nothing to do here
             return
 
-        if not self.no_body:
+        if not self.no_request_body:
             self.connections[self.server_conn].safe_send_body(
                 self.raise_zombie,
                 self.server_stream_id,
@@ -625,12 +630,12 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
 
     @detect_zombie_stream
     def read_response_headers(self):
-        self.response_arrived.wait()
+        self.response_message.arrived.wait()
 
         self.raise_zombie()
 
-        status_code = int(self.response_headers.get(':status', 502))
-        headers = self.response_headers.copy()
+        status_code = int(self.response_message.headers.get(':status', 502))
+        headers = self.response_message.headers.copy()
         headers.pop(":status", None)
 
         return http.HTTPResponse(
@@ -647,23 +652,19 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
     def read_response_body(self, request, response):
         while True:
             try:
-                yield self.response_data_queue.get(timeout=0.1)
+                yield self.response_message.data_queue.get(timeout=0.1)
             except queue.Empty:  # pragma: no cover
                 pass
-            if self.response_data_finished.is_set():
+            if self.response_message.stream_ended.is_set():
                 self.raise_zombie()
-                while self.response_data_queue.qsize() > 0:
-                    yield self.response_data_queue.get()
+                while self.response_message.data_queue.qsize() > 0:
+                    yield self.response_message.data_queue.get()
                 break
             self.raise_zombie()
 
     @detect_zombie_stream
     def read_response_trailers(self, request, response):
-        if "trailer" in response.headers:
-            self.response_trailers_arrived.wait()
-            self.raise_zombie()
-            return self.response_trailers
-        return None
+        return self.response_message.trailers
 
     @detect_zombie_stream
     def send_response_headers(self, response):