diff --git a/mitmproxy/protocol/http2.py b/mitmproxy/protocol/http2.py index 1cc127921..98728c8ad 100644 --- a/mitmproxy/protocol/http2.py +++ b/mitmproxy/protocol/http2.py @@ -9,6 +9,7 @@ import six from h2.connection import H2Connection from h2.exceptions import StreamClosedError from h2 import events +from hyperframe.frame import PriorityFrame from netlib.tcp import ssl_read_select from netlib.exceptions import HttpException @@ -169,11 +170,12 @@ class Http2Layer(Layer): # Some streams might be still sending data to the client. return False elif isinstance(event, events.PushedStreamReceived): - # pushed stream ids should be uniq and not dependent on race conditions + # pushed stream ids should be unique and not dependent on race conditions # only the parent stream id must be looked up first parent_eid = self.server_to_client_stream_ids[event.parent_stream_id] with self.client_conn.h2.lock: self.client_conn.h2.push_stream(parent_eid, event.pushed_stream_id, event.headers) + self.client_conn.send(self.client_conn.h2.data_to_send()) headers = Headers([[str(k), str(v)] for k, v in event.headers]) headers['x-mitmproxy-pushed'] = 'true' @@ -184,6 +186,17 @@ class Http2Layer(Layer): self.streams[event.pushed_stream_id].timestamp_end = time.time() self.streams[event.pushed_stream_id].request_data_finished.set() self.streams[event.pushed_stream_id].start() + elif isinstance(event, events.PriorityUpdated): + stream_id = event.stream_id + if stream_id in self.streams.keys() and self.streams[stream_id].server_stream_id: + stream_id = self.streams[stream_id].server_stream_id + + depends_on = event.depends_on + if depends_on in self.streams.keys() and self.streams[depends_on].server_stream_id: + depends_on = self.streams[depends_on].server_stream_id + + frame = PriorityFrame(stream_id, depends_on, event.weight, event.exclusive) + self.server_conn.send(frame.serialize()) elif isinstance(event, events.TrailersReceived): raise NotImplementedError() @@ -196,6 +209,11 @@ class Http2Layer(Layer): if zombie and zombie <= death_time: self.streams.pop(stream_id, None) + def _kill_all_streams(self): + for stream in self.streams.values(): + if not stream.zombie: + stream.zombie = time.time() + def __call__(self): if self.server_conn: self._initiate_server_conn() @@ -217,9 +235,7 @@ class Http2Layer(Layer): raw_frame = b''.join(http2_read_raw_frame(source_conn.rfile)) except: # read frame failed: connection closed - # kill all streams - for stream in self.streams.values(): - stream.zombie = time.time() + self._kill_all_streams() return incoming_events = source_conn.h2.receive_data(raw_frame) @@ -227,6 +243,8 @@ class Http2Layer(Layer): for event in incoming_events: if not self._handle_event(event, source_conn, other_conn, is_server): + # connection terminated: GoAway + self._kill_all_streams() return self._cleanup_streams()