diff --git a/mitmproxy/protocol/http2.py b/mitmproxy/protocol/http2.py index 27c2a6642..ee66393fa 100644 --- a/mitmproxy/protocol/http2.py +++ b/mitmproxy/protocol/http2.py @@ -130,6 +130,7 @@ class Http2Layer(base.Layer): [repr(event)] ) + eid = None if hasattr(event, 'stream_id'): if is_server and event.stream_id % 2 == 1: eid = self.server_to_client_stream_ids[event.stream_id] @@ -137,83 +138,124 @@ class Http2Layer(base.Layer): eid = event.stream_id if isinstance(event, events.RequestReceived): - headers = netlib.http.Headers([[k, v] for k, v in event.headers]) - self.streams[eid] = Http2SingleStreamLayer(self, eid, headers) - self.streams[eid].timestamp_start = time.time() - self.streams[eid].no_body = (event.stream_ended is not None) - if event.priority_updated is not None: - self.streams[eid].priority_exclusive = event.priority_updated.exclusive - self.streams[eid].priority_depends_on = event.priority_updated.depends_on - self.streams[eid].priority_weight = event.priority_updated.weight - self.streams[eid].handled_priority_event = event.priority_updated - self.streams[eid].start() + return self._handle_request_received(eid, event) elif isinstance(event, events.ResponseReceived): - headers = netlib.http.Headers([[k, v] for k, v in event.headers]) - self.streams[eid].queued_data_length = 0 - self.streams[eid].timestamp_start = time.time() - self.streams[eid].response_headers = headers - self.streams[eid].response_arrived.set() + return self._handle_response_received(eid, event) elif isinstance(event, events.DataReceived): - if self.config.body_size_limit and self.streams[eid].queued_data_length > self.config.body_size_limit: - self.streams[eid].zombie = time.time() - source_conn.h2.safe_reset_stream(event.stream_id, 0x7) - self.log("HTTP body too large. Limit is {}.".format(self.config.body_size_limit), "info") - else: - self.streams[eid].data_queue.put(event.data) - self.streams[eid].queued_data_length += len(event.data) - source_conn.h2.safe_increment_flow_control(event.stream_id, event.flow_controlled_length) + return self._handle_data_received(eid, event, source_conn) elif isinstance(event, events.StreamEnded): - self.streams[eid].timestamp_end = time.time() - self.streams[eid].data_finished.set() + return self._handle_stream_ended(eid) elif isinstance(event, events.StreamReset): - self.streams[eid].zombie = time.time() - if eid in self.streams and event.error_code == 0x8: - if is_server: - other_stream_id = self.streams[eid].client_stream_id - else: - other_stream_id = self.streams[eid].server_stream_id - if other_stream_id is not None: - other_conn.h2.safe_reset_stream(other_stream_id, event.error_code) + return self._handle_stream_reset(eid, event, is_server, other_conn) elif isinstance(event, events.RemoteSettingsChanged): - new_settings = dict([(id, cs.new_value) for (id, cs) in six.iteritems(event.changed_settings)]) - other_conn.h2.safe_update_settings(new_settings) + return self._handle_remote_settings_changed(event, other_conn) elif isinstance(event, events.ConnectionTerminated): - if event.error_code == h2.errors.NO_ERROR: - # Do not immediately terminate the other connection. - # Some streams might be still sending data to the client. - return False - else: - # Something terrible has happened - kill everything! - self.client_conn.h2.close_connection( - error_code=event.error_code, - last_stream_id=event.last_stream_id, - additional_data=event.additional_data - ) - self.client_conn.send(self.client_conn.h2.data_to_send()) - self._kill_all_streams() - return False + return self._handle_connection_terminated(event) elif isinstance(event, events.PushedStreamReceived): - # pushed stream ids should be unique and not dependent on race conditions - # only the parent stream id must be looked up first - parent_eid = self.server_to_client_stream_ids[event.parent_stream_id] - with self.client_conn.h2.lock: - self.client_conn.h2.push_stream(parent_eid, event.pushed_stream_id, event.headers) - self.client_conn.send(self.client_conn.h2.data_to_send()) - - headers = netlib.http.Headers([[k, v] for k, v in event.headers]) - self.streams[event.pushed_stream_id] = Http2SingleStreamLayer(self, event.pushed_stream_id, headers) - self.streams[event.pushed_stream_id].timestamp_start = time.time() - self.streams[event.pushed_stream_id].pushed = True - self.streams[event.pushed_stream_id].parent_stream_id = parent_eid - self.streams[event.pushed_stream_id].timestamp_end = time.time() - self.streams[event.pushed_stream_id].request_data_finished.set() - self.streams[event.pushed_stream_id].start() + return self._handle_pushed_stream_received(event) elif isinstance(event, events.PriorityUpdated): - if eid in self.streams and self.streams[eid].handled_priority_event is event: - # this event was already handled during stream creation - # HeadersFrame + Priority information as RequestReceived - return True + return self._handle_priority_updated(eid, event) + elif isinstance(event, events.TrailersReceived): + raise NotImplementedError('TrailersReceived not implemented') + # fail-safe for unhandled events + return True + + def _handle_request_received(self, eid, event): + headers = netlib.http.Headers([[k, v] for k, v in event.headers]) + self.streams[eid] = Http2SingleStreamLayer(self, eid, headers) + self.streams[eid].timestamp_start = time.time() + self.streams[eid].no_body = (event.stream_ended is not None) + if event.priority_updated is not None: + self.streams[eid].priority_exclusive = event.priority_updated.exclusive + self.streams[eid].priority_depends_on = event.priority_updated.depends_on + self.streams[eid].priority_weight = event.priority_updated.weight + self.streams[eid].handled_priority_event = event.priority_updated + self.streams[eid].start() + return True + + def _handle_response_received(self, eid, event): + headers = netlib.http.Headers([[k, v] for k, v in event.headers]) + self.streams[eid].queued_data_length = 0 + self.streams[eid].timestamp_start = time.time() + self.streams[eid].response_headers = headers + self.streams[eid].response_arrived.set() + return True + + def _handle_data_received(self, eid, event, source_conn): + if self.config.body_size_limit and self.streams[eid].queued_data_length > self.config.body_size_limit: + self.streams[eid].zombie = time.time() + source_conn.h2.safe_reset_stream(event.stream_id, h2.errors.REFUSED_STREAM) + self.log("HTTP body too large. Limit is {}.".format(self.config.body_size_limit), "info") + else: + self.streams[eid].data_queue.put(event.data) + self.streams[eid].queued_data_length += len(event.data) + source_conn.h2.safe_increment_flow_control(event.stream_id, event.flow_controlled_length) + return True + + def _handle_stream_ended(self, eid): + self.streams[eid].timestamp_end = time.time() + self.streams[eid].data_finished.set() + return True + + def _handle_stream_reset(self, eid, event, is_server, other_conn): + self.streams[eid].zombie = time.time() + if eid in self.streams and event.error_code == h2.errors.CANCEL: + if is_server: + other_stream_id = self.streams[eid].client_stream_id + else: + other_stream_id = self.streams[eid].server_stream_id + if other_stream_id is not None: + other_conn.h2.safe_reset_stream(other_stream_id, event.error_code) + return True + + def _handle_remote_settings_changed(self, event, other_conn): + new_settings = dict([(key, cs.new_value) for (key, cs) in six.iteritems(event.changed_settings)]) + other_conn.h2.safe_update_settings(new_settings) + return True + + def _handle_connection_terminated(self, event): + if event.error_code != h2.errors.NO_ERROR: + # Something terrible has happened - kill everything! + self.client_conn.h2.close_connection( + error_code=event.error_code, + last_stream_id=event.last_stream_id, + additional_data=event.additional_data + ) + self.client_conn.send(self.client_conn.h2.data_to_send()) + self._kill_all_streams() + else: + """ + Do not immediately terminate the other connection. + Some streams might be still sending data to the client. + """ + return False + + def _handle_pushed_stream_received(self, event): + # pushed stream ids should be unique and not dependent on race conditions + # only the parent stream id must be looked up first + parent_eid = self.server_to_client_stream_ids[event.parent_stream_id] + with self.client_conn.h2.lock: + self.client_conn.h2.push_stream(parent_eid, event.pushed_stream_id, event.headers) + self.client_conn.send(self.client_conn.h2.data_to_send()) + + headers = netlib.http.Headers([[k, v] for k, v in event.headers]) + self.streams[event.pushed_stream_id] = Http2SingleStreamLayer(self, event.pushed_stream_id, headers) + self.streams[event.pushed_stream_id].timestamp_start = time.time() + self.streams[event.pushed_stream_id].pushed = True + self.streams[event.pushed_stream_id].parent_stream_id = parent_eid + self.streams[event.pushed_stream_id].timestamp_end = time.time() + self.streams[event.pushed_stream_id].request_data_finished.set() + self.streams[event.pushed_stream_id].start() + return True + + def _handle_priority_updated(self, eid, event): + if eid in self.streams and self.streams[eid].handled_priority_event is event: + # this event was already handled during stream creation + # HeadersFrame + Priority information as RequestReceived + return True + + with self.server_conn.h2.lock: mapped_stream_id = event.stream_id if mapped_stream_id in self.streams and self.streams[mapped_stream_id].server_stream_id: # if the stream is already up and running and was sent to the server @@ -225,17 +267,13 @@ class Http2Layer(base.Layer): self.streams[eid].priority_depends_on = event.depends_on self.streams[eid].priority_weight = event.weight - with self.server_conn.h2.lock: - self.server_conn.h2.prioritize( - mapped_stream_id, - weight=event.weight, - depends_on=self._map_depends_on_stream_id(mapped_stream_id, event.depends_on), - exclusive=event.exclusive - ) - self.server_conn.send(self.server_conn.h2.data_to_send()) - elif isinstance(event, events.TrailersReceived): - raise NotImplementedError("TrailersReceived not implemented") - + self.server_conn.h2.prioritize( + mapped_stream_id, + weight=event.weight, + depends_on=self._map_depends_on_stream_id(mapped_stream_id, event.depends_on), + exclusive=event.exclusive + ) + self.server_conn.send(self.server_conn.h2.data_to_send()) return True def _map_depends_on_stream_id(self, stream_id, depends_on): @@ -337,6 +375,15 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread) self.priority_weight = None self.handled_priority_event = None + def check_close_connection(self, flow): + # This layer only handles a single stream. + # RFC 7540 8.1: An HTTP request/response exchange fully consumes a single stream. + return True + + def set_server(self, *args, **kwargs): # pragma: no cover + # do not mess with the server connection - all streams share it. + pass + @property def data_queue(self): if self.response_arrived.is_set(): @@ -428,15 +475,25 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread) headers.insert(0, ":method", message.method) headers.insert(0, ":scheme", message.scheme) + priority_exclusive = None + priority_depends_on = None + priority_weight = None + if self.handled_priority_event: + # only send priority information if they actually came with the original HeadersFrame + # and not if they got updated before/after with a PriorityFrame + priority_exclusive = self.priority_exclusive + priority_depends_on = self._map_depends_on_stream_id(self.server_stream_id, self.priority_depends_on) + priority_weight = self.priority_weight + try: self.server_conn.h2.safe_send_headers( self.is_zombie, self.server_stream_id, headers, end_stream=self.no_body, - priority_exclusive=self.priority_exclusive, - priority_depends_on=self._map_depends_on_stream_id(self.server_stream_id, self.priority_depends_on), - priority_weight=self.priority_weight, + priority_exclusive=priority_exclusive, + priority_depends_on=priority_depends_on, + priority_weight=priority_weight, ) except Exception as e: # pragma: no cover raise e @@ -477,7 +534,7 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread) while True: try: yield self.response_data_queue.get(timeout=1) - except queue.Empty: + except queue.Empty: # pragma: no cover pass if self.response_data_finished.is_set(): if self.zombie: # pragma: no cover @@ -512,19 +569,7 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread) if self.zombie: # pragma: no cover raise exceptions.Http2ProtocolException("Zombie Stream") - def check_close_connection(self, flow): - # This layer only handles a single stream. - # RFC 7540 8.1: An HTTP request/response exchange fully consumes a single stream. - return True - - def set_server(self, *args, **kwargs): # pragma: no cover - # do not mess with the server connection - all streams share it. - pass - def run(self): - self() - - def __call__(self): layer = http.HttpLayer(self, self.mode) try: diff --git a/test/mitmproxy/test_protocol_http2.py b/test/mitmproxy/test_protocol_http2.py index a100ac2dd..55a303040 100644 --- a/test/mitmproxy/test_protocol_http2.py +++ b/test/mitmproxy/test_protocol_http2.py @@ -150,13 +150,17 @@ class _Http2TestBase(object): stream_id=1, headers=[], body=b'', + end_stream=None, priority_exclusive=None, priority_depends_on=None, priority_weight=None): + if end_stream is None: + end_stream = (len(body) == 0) + h2_conn.send_headers( stream_id=stream_id, headers=headers, - end_stream=(len(body) == 0), + end_stream=end_stream, priority_exclusive=priority_exclusive, priority_depends_on=priority_depends_on, priority_weight=priority_weight, @@ -375,6 +379,153 @@ class TestRequestWithPriority(_Http2Test): assert 'priority_weight' not in self.master.state.flows[0].response.headers +@requires_alpn +class TestPriority(_Http2Test): + priority_data = None + + @classmethod + def handle_server_event(self, event, h2_conn, rfile, wfile): + if isinstance(event, h2.events.ConnectionTerminated): + return False + elif isinstance(event, h2.events.PriorityUpdated): + self.priority_data = (event.exclusive, event.depends_on, event.weight) + elif isinstance(event, h2.events.RequestReceived): + import warnings + with warnings.catch_warnings(): + # Ignore UnicodeWarning: + # h2/utilities.py:64: UnicodeWarning: Unicode equal comparison + # failed to convert both arguments to Unicode - interpreting + # them as being unequal. + # elif header[0] in (b'cookie', u'cookie') and len(header[1]) < 20: + + warnings.simplefilter("ignore") + + headers = [(':status', '200')] + h2_conn.send_headers(event.stream_id, headers) + h2_conn.end_stream(event.stream_id) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + return True + + def test_priority(self): + client, h2_conn = self._setup_connection() + + h2_conn.prioritize(1, exclusive=True, depends_on=0, weight=42) + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + self._send_request( + client.wfile, + h2_conn, + headers=[ + (':authority', "127.0.0.1:%s" % self.server.server.address.port), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/'), + ], + ) + + done = False + while not done: + try: + raw = b''.join(framereader.http2_read_raw_frame(client.rfile)) + events = h2_conn.receive_data(raw) + except HttpException: + print(traceback.format_exc()) + assert False + + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + for event in events: + if isinstance(event, h2.events.StreamEnded): + done = True + + h2_conn.close_connection() + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + assert len(self.master.state.flows) == 1 + assert self.priority_data == (True, 0, 42) + + +@requires_alpn +class TestPriorityWithExistingStream(_Http2Test): + priority_data = [] + + @classmethod + def handle_server_event(self, event, h2_conn, rfile, wfile): + if isinstance(event, h2.events.ConnectionTerminated): + return False + elif isinstance(event, h2.events.PriorityUpdated): + self.priority_data.append((event.exclusive, event.depends_on, event.weight)) + elif isinstance(event, h2.events.RequestReceived): + assert not event.priority_updated + + import warnings + with warnings.catch_warnings(): + # Ignore UnicodeWarning: + # h2/utilities.py:64: UnicodeWarning: Unicode equal comparison + # failed to convert both arguments to Unicode - interpreting + # them as being unequal. + # elif header[0] in (b'cookie', u'cookie') and len(header[1]) < 20: + + warnings.simplefilter("ignore") + + headers = [(':status', '200')] + h2_conn.send_headers(event.stream_id, headers) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + elif isinstance(event, h2.events.StreamEnded): + h2_conn.end_stream(event.stream_id) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + return True + + def test_priority_with_existing_stream(self): + client, h2_conn = self._setup_connection() + + self._send_request( + client.wfile, + h2_conn, + headers=[ + (':authority', "127.0.0.1:%s" % self.server.server.address.port), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/'), + ], + end_stream=False, + ) + + h2_conn.prioritize(1, exclusive=True, depends_on=0, weight=42) + h2_conn.end_stream(1) + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + done = False + while not done: + try: + raw = b''.join(framereader.http2_read_raw_frame(client.rfile)) + events = h2_conn.receive_data(raw) + except HttpException: + print(traceback.format_exc()) + assert False + + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + for event in events: + if isinstance(event, h2.events.StreamEnded): + done = True + + h2_conn.close_connection() + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + assert len(self.master.state.flows) == 1 + assert self.priority_data == [(True, 0, 42)] + + @requires_alpn class TestStreamResetFromServer(_Http2Test):