diff --git a/mitmproxy/protocol/http2.py b/mitmproxy/protocol/http2.py index eb5586cb2..0e42d619a 100644 --- a/mitmproxy/protocol/http2.py +++ b/mitmproxy/protocol/http2.py @@ -3,6 +3,7 @@ from __future__ import absolute_import, print_function, division import threading import time import traceback +import functools import h2.exceptions import six @@ -54,21 +55,18 @@ class SafeH2Connection(connection.H2Connection): self.update_settings(new_settings) self.conn.send(self.data_to_send()) - def safe_send_headers(self, is_zombie, stream_id, headers, **kwargs): + def safe_send_headers(self, raise_zombie, stream_id, headers, **kwargs): with self.lock: - if is_zombie(): # pragma: no cover - raise exceptions.Http2ProtocolException("Zombie Stream") + raise_zombie() self.send_headers(stream_id, headers.fields, **kwargs) self.conn.send(self.data_to_send()) - def safe_send_body(self, is_zombie, stream_id, chunks): + def safe_send_body(self, raise_zombie, stream_id, chunks): for chunk in chunks: position = 0 while position < len(chunk): self.lock.acquire() - if is_zombie(): # pragma: no cover - self.lock.release() - raise exceptions.Http2ProtocolException("Zombie Stream") + raise_zombie(self.lock.release) max_outbound_frame_size = self.max_outbound_frame_size frame_chunk = chunk[position:position + max_outbound_frame_size] if self.local_flow_control_window(stream_id) < len(frame_chunk): @@ -84,8 +82,7 @@ class SafeH2Connection(connection.H2Connection): self.lock.release() position += max_outbound_frame_size with self.lock: - if is_zombie(): # pragma: no cover - raise exceptions.Http2ProtocolException("Zombie Stream") + raise_zombie() self.end_stream(stream_id) self.conn.send(self.data_to_send()) @@ -344,6 +341,17 @@ class Http2Layer(base.Layer): self._kill_all_streams() +def detect_zombie_stream(func): + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + self.raise_zombie() + result = func(self, *args, **kwargs) + self.raise_zombie() + return result + + return wrapper + + class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread): def __init__(self, ctx, stream_id, request_headers): @@ -412,15 +420,16 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread) def queued_data_length(self, v): self.request_queued_data_length = v - def is_zombie(self): - return self.zombie is not None + def raise_zombie(self, pre_command=None): + if self.zombie is not None: + if pre_command is not None: + pre_command() + raise exceptions.Http2ProtocolException("Zombie Stream") + @detect_zombie_stream def read_request(self): self.request_data_finished.wait() - if self.zombie: # pragma: no cover - raise exceptions.Http2ProtocolException("Zombie Stream") - data = [] while self.request_data_queue.qsize() > 0: data.append(self.request_data_queue.get()) @@ -445,15 +454,14 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread) def read_request_body(self, request): # pragma: no cover raise NotImplementedError() + @detect_zombie_stream def send_request(self, message): if self.pushed: # nothing to do here return while True: - if self.zombie: # pragma: no cover - raise exceptions.Http2ProtocolException("Zombie Stream") - + self.raise_zombie() self.server_conn.h2.lock.acquire() max_streams = self.server_conn.h2.remote_settings.max_concurrent_streams @@ -467,8 +475,7 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread) break # We must not assign a stream id if we are already a zombie. - if self.zombie: # pragma: no cover - raise exceptions.Http2ProtocolException("Zombie Stream") + self.raise_zombie() self.server_stream_id = self.server_conn.h2.get_next_available_stream_id() self.server_to_client_stream_ids[self.server_stream_id] = self.client_stream_id @@ -490,7 +497,7 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread) try: self.server_conn.h2.safe_send_headers( - self.is_zombie, + self.raise_zombie, self.server_stream_id, headers, end_stream=self.no_body, @@ -505,19 +512,16 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread) if not self.no_body: self.server_conn.h2.safe_send_body( - self.is_zombie, + self.raise_zombie, self.server_stream_id, [message.body] ) - if self.zombie: # pragma: no cover - raise exceptions.Http2ProtocolException("Zombie Stream") - + @detect_zombie_stream def read_response_headers(self): self.response_arrived.wait() - if self.zombie: # pragma: no cover - raise exceptions.Http2ProtocolException("Zombie Stream") + self.raise_zombie() status_code = int(self.response_headers.get(':status', 502)) headers = self.response_headers.copy() @@ -533,6 +537,7 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread) timestamp_end=self.timestamp_end, ) + @detect_zombie_stream def read_response_body(self, request, response): while True: try: @@ -540,14 +545,13 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread) except queue.Empty: # pragma: no cover pass if self.response_data_finished.is_set(): - if self.zombie: # pragma: no cover - raise exceptions.Http2ProtocolException("Zombie Stream") + self.raise_zombie() while self.response_data_queue.qsize() > 0: yield self.response_data_queue.get() break - if self.zombie: # pragma: no cover - raise exceptions.Http2ProtocolException("Zombie Stream") + self.raise_zombie() + @detect_zombie_stream def send_response_headers(self, response): headers = response.headers.copy() headers.insert(0, ":status", str(response.status_code)) @@ -556,21 +560,21 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread) del headers[forbidden_header] with self.client_conn.h2.lock: self.client_conn.h2.safe_send_headers( - self.is_zombie, + self.raise_zombie, self.client_stream_id, headers ) - if self.zombie: # pragma: no cover - raise exceptions.Http2ProtocolException("Zombie Stream") + @detect_zombie_stream def send_response_body(self, _response, chunks): self.client_conn.h2.safe_send_body( - self.is_zombie, + self.raise_zombie, self.client_stream_id, chunks ) - if self.zombie: # pragma: no cover - raise exceptions.Http2ProtocolException("Zombie Stream") + + def __call__(self): + raise EnvironmentError('Http2SingleStreamLayer must be run as thread') def run(self): layer = http.HttpLayer(self, self.mode) diff --git a/test/mitmproxy/test_protocol_http2.py b/test/mitmproxy/test_protocol_http2.py index f0fa9a404..873c89c37 100644 --- a/test/mitmproxy/test_protocol_http2.py +++ b/test/mitmproxy/test_protocol_http2.py @@ -849,15 +849,15 @@ class TestMaxConcurrentStreams(_Http2Test): def test_max_concurrent_streams(self): client, h2_conn = self._setup_connection() new_streams = [1, 3, 5, 7, 9, 11] - for id in new_streams: + for stream_id in new_streams: # this will exceed MAX_CONCURRENT_STREAMS on the server connection # and cause mitmproxy to throttle stream creation to the server - self._send_request(client.wfile, h2_conn, stream_id=id, headers=[ + self._send_request(client.wfile, h2_conn, stream_id=stream_id, headers=[ (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), (':method', 'GET'), (':scheme', 'https'), (':path', '/'), - ('X-Stream-ID', str(id)), + ('X-Stream-ID', str(stream_id)), ]) ended_streams = 0