Merge pull request #1481 from Kriechi/http2-simplify-zombies
http2: simplify zombie detection
This commit is contained in:
commit
bfe22e739c
|
@ -3,6 +3,7 @@ from __future__ import absolute_import, print_function, division
|
|||
import threading
|
||||
import time
|
||||
import traceback
|
||||
import functools
|
||||
|
||||
import h2.exceptions
|
||||
import six
|
||||
|
@ -54,21 +55,18 @@ class SafeH2Connection(connection.H2Connection):
|
|||
self.update_settings(new_settings)
|
||||
self.conn.send(self.data_to_send())
|
||||
|
||||
def safe_send_headers(self, is_zombie, stream_id, headers, **kwargs):
|
||||
def safe_send_headers(self, raise_zombie, stream_id, headers, **kwargs):
|
||||
with self.lock:
|
||||
if is_zombie(): # pragma: no cover
|
||||
raise exceptions.Http2ProtocolException("Zombie Stream")
|
||||
raise_zombie()
|
||||
self.send_headers(stream_id, headers.fields, **kwargs)
|
||||
self.conn.send(self.data_to_send())
|
||||
|
||||
def safe_send_body(self, is_zombie, stream_id, chunks):
|
||||
def safe_send_body(self, raise_zombie, stream_id, chunks):
|
||||
for chunk in chunks:
|
||||
position = 0
|
||||
while position < len(chunk):
|
||||
self.lock.acquire()
|
||||
if is_zombie(): # pragma: no cover
|
||||
self.lock.release()
|
||||
raise exceptions.Http2ProtocolException("Zombie Stream")
|
||||
raise_zombie(self.lock.release)
|
||||
max_outbound_frame_size = self.max_outbound_frame_size
|
||||
frame_chunk = chunk[position:position + max_outbound_frame_size]
|
||||
if self.local_flow_control_window(stream_id) < len(frame_chunk):
|
||||
|
@ -84,8 +82,7 @@ class SafeH2Connection(connection.H2Connection):
|
|||
self.lock.release()
|
||||
position += max_outbound_frame_size
|
||||
with self.lock:
|
||||
if is_zombie(): # pragma: no cover
|
||||
raise exceptions.Http2ProtocolException("Zombie Stream")
|
||||
raise_zombie()
|
||||
self.end_stream(stream_id)
|
||||
self.conn.send(self.data_to_send())
|
||||
|
||||
|
@ -344,6 +341,17 @@ class Http2Layer(base.Layer):
|
|||
self._kill_all_streams()
|
||||
|
||||
|
||||
def detect_zombie_stream(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
self.raise_zombie()
|
||||
result = func(self, *args, **kwargs)
|
||||
self.raise_zombie()
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread):
|
||||
|
||||
def __init__(self, ctx, stream_id, request_headers):
|
||||
|
@ -412,15 +420,16 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
|
|||
def queued_data_length(self, v):
|
||||
self.request_queued_data_length = v
|
||||
|
||||
def is_zombie(self):
|
||||
return self.zombie is not None
|
||||
def raise_zombie(self, pre_command=None):
|
||||
if self.zombie is not None:
|
||||
if pre_command is not None:
|
||||
pre_command()
|
||||
raise exceptions.Http2ProtocolException("Zombie Stream")
|
||||
|
||||
@detect_zombie_stream
|
||||
def read_request(self):
|
||||
self.request_data_finished.wait()
|
||||
|
||||
if self.zombie: # pragma: no cover
|
||||
raise exceptions.Http2ProtocolException("Zombie Stream")
|
||||
|
||||
data = []
|
||||
while self.request_data_queue.qsize() > 0:
|
||||
data.append(self.request_data_queue.get())
|
||||
|
@ -445,15 +454,14 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
|
|||
def read_request_body(self, request): # pragma: no cover
|
||||
raise NotImplementedError()
|
||||
|
||||
@detect_zombie_stream
|
||||
def send_request(self, message):
|
||||
if self.pushed:
|
||||
# nothing to do here
|
||||
return
|
||||
|
||||
while True:
|
||||
if self.zombie: # pragma: no cover
|
||||
raise exceptions.Http2ProtocolException("Zombie Stream")
|
||||
|
||||
self.raise_zombie()
|
||||
self.server_conn.h2.lock.acquire()
|
||||
|
||||
max_streams = self.server_conn.h2.remote_settings.max_concurrent_streams
|
||||
|
@ -467,8 +475,7 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
|
|||
break
|
||||
|
||||
# We must not assign a stream id if we are already a zombie.
|
||||
if self.zombie: # pragma: no cover
|
||||
raise exceptions.Http2ProtocolException("Zombie Stream")
|
||||
self.raise_zombie()
|
||||
|
||||
self.server_stream_id = self.server_conn.h2.get_next_available_stream_id()
|
||||
self.server_to_client_stream_ids[self.server_stream_id] = self.client_stream_id
|
||||
|
@ -490,7 +497,7 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
|
|||
|
||||
try:
|
||||
self.server_conn.h2.safe_send_headers(
|
||||
self.is_zombie,
|
||||
self.raise_zombie,
|
||||
self.server_stream_id,
|
||||
headers,
|
||||
end_stream=self.no_body,
|
||||
|
@ -505,19 +512,16 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
|
|||
|
||||
if not self.no_body:
|
||||
self.server_conn.h2.safe_send_body(
|
||||
self.is_zombie,
|
||||
self.raise_zombie,
|
||||
self.server_stream_id,
|
||||
[message.body]
|
||||
)
|
||||
|
||||
if self.zombie: # pragma: no cover
|
||||
raise exceptions.Http2ProtocolException("Zombie Stream")
|
||||
|
||||
@detect_zombie_stream
|
||||
def read_response_headers(self):
|
||||
self.response_arrived.wait()
|
||||
|
||||
if self.zombie: # pragma: no cover
|
||||
raise exceptions.Http2ProtocolException("Zombie Stream")
|
||||
self.raise_zombie()
|
||||
|
||||
status_code = int(self.response_headers.get(':status', 502))
|
||||
headers = self.response_headers.copy()
|
||||
|
@ -533,6 +537,7 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
|
|||
timestamp_end=self.timestamp_end,
|
||||
)
|
||||
|
||||
@detect_zombie_stream
|
||||
def read_response_body(self, request, response):
|
||||
while True:
|
||||
try:
|
||||
|
@ -540,14 +545,13 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
|
|||
except queue.Empty: # pragma: no cover
|
||||
pass
|
||||
if self.response_data_finished.is_set():
|
||||
if self.zombie: # pragma: no cover
|
||||
raise exceptions.Http2ProtocolException("Zombie Stream")
|
||||
self.raise_zombie()
|
||||
while self.response_data_queue.qsize() > 0:
|
||||
yield self.response_data_queue.get()
|
||||
break
|
||||
if self.zombie: # pragma: no cover
|
||||
raise exceptions.Http2ProtocolException("Zombie Stream")
|
||||
self.raise_zombie()
|
||||
|
||||
@detect_zombie_stream
|
||||
def send_response_headers(self, response):
|
||||
headers = response.headers.copy()
|
||||
headers.insert(0, ":status", str(response.status_code))
|
||||
|
@ -556,21 +560,21 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
|
|||
del headers[forbidden_header]
|
||||
with self.client_conn.h2.lock:
|
||||
self.client_conn.h2.safe_send_headers(
|
||||
self.is_zombie,
|
||||
self.raise_zombie,
|
||||
self.client_stream_id,
|
||||
headers
|
||||
)
|
||||
if self.zombie: # pragma: no cover
|
||||
raise exceptions.Http2ProtocolException("Zombie Stream")
|
||||
|
||||
@detect_zombie_stream
|
||||
def send_response_body(self, _response, chunks):
|
||||
self.client_conn.h2.safe_send_body(
|
||||
self.is_zombie,
|
||||
self.raise_zombie,
|
||||
self.client_stream_id,
|
||||
chunks
|
||||
)
|
||||
if self.zombie: # pragma: no cover
|
||||
raise exceptions.Http2ProtocolException("Zombie Stream")
|
||||
|
||||
def __call__(self):
|
||||
raise EnvironmentError('Http2SingleStreamLayer must be run as thread')
|
||||
|
||||
def run(self):
|
||||
layer = http.HttpLayer(self, self.mode)
|
||||
|
|
|
@ -849,15 +849,15 @@ class TestMaxConcurrentStreams(_Http2Test):
|
|||
def test_max_concurrent_streams(self):
|
||||
client, h2_conn = self._setup_connection()
|
||||
new_streams = [1, 3, 5, 7, 9, 11]
|
||||
for id in new_streams:
|
||||
for stream_id in new_streams:
|
||||
# this will exceed MAX_CONCURRENT_STREAMS on the server connection
|
||||
# and cause mitmproxy to throttle stream creation to the server
|
||||
self._send_request(client.wfile, h2_conn, stream_id=id, headers=[
|
||||
self._send_request(client.wfile, h2_conn, stream_id=stream_id, headers=[
|
||||
(':authority', "127.0.0.1:{}".format(self.server.server.address.port)),
|
||||
(':method', 'GET'),
|
||||
(':scheme', 'https'),
|
||||
(':path', '/'),
|
||||
('X-Stream-ID', str(id)),
|
||||
('X-Stream-ID', str(stream_id)),
|
||||
])
|
||||
|
||||
ended_streams = 0
|
||||
|
|
Loading…
Reference in New Issue