Merge pull request #1481 from Kriechi/http2-simplify-zombies

http2: simplify zombie detection
Thomas Kriechbaumer 2016-08-16 10:10:39 +02:00 committed by GitHub
commit bfe22e739c
2 changed files with 43 additions and 39 deletions
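
In short: the repeated inline checks of the form "if self.zombie: raise exceptions.Http2ProtocolException("Zombie Stream")" are collapsed into a single raise_zombie() helper on the stream layer, and a detect_zombie_stream decorator applies that check on entry to and exit from the read/send methods. A minimal, self-contained sketch of the pattern (the DummyStream class and the local Http2ProtocolException stand-in are illustrative only, not part of this commit):

import functools


class Http2ProtocolException(Exception):
    # Stand-in for mitmproxy's exceptions.Http2ProtocolException.
    pass


def detect_zombie_stream(func):
    # Check for a dead stream both before and after the wrapped call,
    # mirroring the decorator added in this commit.
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        self.raise_zombie()
        result = func(self, *args, **kwargs)
        self.raise_zombie()
        return result
    return wrapper


class DummyStream:
    # Illustrative stand-in for Http2SingleStreamLayer.
    def __init__(self):
        self.zombie = None  # set to a non-None value once the stream dies

    def raise_zombie(self, pre_command=None):
        if self.zombie is not None:
            if pre_command is not None:
                pre_command()  # e.g. release a held lock before raising
            raise Http2ProtocolException("Zombie Stream")

    @detect_zombie_stream
    def read_request(self):
        return b"request data"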

@@ -3,6 +3,7 @@ from __future__ import absolute_import, print_function, division
 import threading
 import time
 import traceback
+import functools
 
 import h2.exceptions
 import six
@@ -54,21 +55,18 @@ class SafeH2Connection(connection.H2Connection):
             self.update_settings(new_settings)
             self.conn.send(self.data_to_send())
 
-    def safe_send_headers(self, is_zombie, stream_id, headers, **kwargs):
+    def safe_send_headers(self, raise_zombie, stream_id, headers, **kwargs):
         with self.lock:
-            if is_zombie():  # pragma: no cover
-                raise exceptions.Http2ProtocolException("Zombie Stream")
+            raise_zombie()
             self.send_headers(stream_id, headers.fields, **kwargs)
             self.conn.send(self.data_to_send())
 
-    def safe_send_body(self, is_zombie, stream_id, chunks):
+    def safe_send_body(self, raise_zombie, stream_id, chunks):
         for chunk in chunks:
             position = 0
             while position < len(chunk):
                 self.lock.acquire()
-                if is_zombie():  # pragma: no cover
-                    self.lock.release()
-                    raise exceptions.Http2ProtocolException("Zombie Stream")
+                raise_zombie(self.lock.release)
                 max_outbound_frame_size = self.max_outbound_frame_size
                 frame_chunk = chunk[position:position + max_outbound_frame_size]
                 if self.local_flow_control_window(stream_id) < len(frame_chunk):
@@ -84,8 +82,7 @@ class SafeH2Connection(connection.H2Connection):
                     self.lock.release()
                 position += max_outbound_frame_size
 
         with self.lock:
-            if is_zombie():  # pragma: no cover
-                raise exceptions.Http2ProtocolException("Zombie Stream")
+            raise_zombie()
             self.end_stream(stream_id)
             self.conn.send(self.data_to_send())
@@ -344,6 +341,17 @@ class Http2Layer(base.Layer):
             self._kill_all_streams()
 
 
+def detect_zombie_stream(func):
+    @functools.wraps(func)
+    def wrapper(self, *args, **kwargs):
+        self.raise_zombie()
+        result = func(self, *args, **kwargs)
+        self.raise_zombie()
+        return result
+
+    return wrapper
+
+
 class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread):
 
     def __init__(self, ctx, stream_id, request_headers):
@@ -412,15 +420,16 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread):
     def queued_data_length(self, v):
         self.request_queued_data_length = v
 
-    def is_zombie(self):
-        return self.zombie is not None
+    def raise_zombie(self, pre_command=None):
+        if self.zombie is not None:
+            if pre_command is not None:
+                pre_command()
+            raise exceptions.Http2ProtocolException("Zombie Stream")
 
+    @detect_zombie_stream
     def read_request(self):
         self.request_data_finished.wait()
 
-        if self.zombie:  # pragma: no cover
-            raise exceptions.Http2ProtocolException("Zombie Stream")
-
         data = []
         while self.request_data_queue.qsize() > 0:
             data.append(self.request_data_queue.get())
@@ -445,15 +454,14 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread):
     def read_request_body(self, request):  # pragma: no cover
         raise NotImplementedError()
 
+    @detect_zombie_stream
     def send_request(self, message):
         if self.pushed:
             # nothing to do here
             return
 
         while True:
-            if self.zombie:  # pragma: no cover
-                raise exceptions.Http2ProtocolException("Zombie Stream")
-
+            self.raise_zombie()
             self.server_conn.h2.lock.acquire()
 
             max_streams = self.server_conn.h2.remote_settings.max_concurrent_streams
@@ -467,8 +475,7 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread):
                 break
 
         # We must not assign a stream id if we are already a zombie.
-        if self.zombie:  # pragma: no cover
-            raise exceptions.Http2ProtocolException("Zombie Stream")
+        self.raise_zombie()
         self.server_stream_id = self.server_conn.h2.get_next_available_stream_id()
         self.server_to_client_stream_ids[self.server_stream_id] = self.client_stream_id
@@ -490,7 +497,7 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread):
         try:
             self.server_conn.h2.safe_send_headers(
-                self.is_zombie,
+                self.raise_zombie,
                 self.server_stream_id,
                 headers,
                 end_stream=self.no_body,
@@ -505,19 +512,16 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread):
         if not self.no_body:
             self.server_conn.h2.safe_send_body(
-                self.is_zombie,
+                self.raise_zombie,
                 self.server_stream_id,
                 [message.body]
             )
 
-        if self.zombie:  # pragma: no cover
-            raise exceptions.Http2ProtocolException("Zombie Stream")
-
+    @detect_zombie_stream
     def read_response_headers(self):
         self.response_arrived.wait()
 
-        if self.zombie:  # pragma: no cover
-            raise exceptions.Http2ProtocolException("Zombie Stream")
+        self.raise_zombie()
 
         status_code = int(self.response_headers.get(':status', 502))
         headers = self.response_headers.copy()
@@ -533,6 +537,7 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread):
             timestamp_end=self.timestamp_end,
         )
 
+    @detect_zombie_stream
     def read_response_body(self, request, response):
         while True:
             try:
@@ -540,14 +545,13 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread):
             except queue.Empty:  # pragma: no cover
                 pass
             if self.response_data_finished.is_set():
-                if self.zombie:  # pragma: no cover
-                    raise exceptions.Http2ProtocolException("Zombie Stream")
+                self.raise_zombie()
                 while self.response_data_queue.qsize() > 0:
                     yield self.response_data_queue.get()
                 break
-            if self.zombie:  # pragma: no cover
-                raise exceptions.Http2ProtocolException("Zombie Stream")
+            self.raise_zombie()
 
+    @detect_zombie_stream
     def send_response_headers(self, response):
         headers = response.headers.copy()
         headers.insert(0, ":status", str(response.status_code))
@@ -556,21 +560,21 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread):
                 del headers[forbidden_header]
 
         with self.client_conn.h2.lock:
             self.client_conn.h2.safe_send_headers(
-                self.is_zombie,
+                self.raise_zombie,
                 self.client_stream_id,
                 headers
             )
 
-        if self.zombie:  # pragma: no cover
-            raise exceptions.Http2ProtocolException("Zombie Stream")
-
+    @detect_zombie_stream
     def send_response_body(self, _response, chunks):
         self.client_conn.h2.safe_send_body(
-            self.is_zombie,
+            self.raise_zombie,
             self.client_stream_id,
             chunks
         )
 
-        if self.zombie:  # pragma: no cover
-            raise exceptions.Http2ProtocolException("Zombie Stream")
-
     def __call__(self):
         raise EnvironmentError('Http2SingleStreamLayer must be run as thread')
 
     def run(self):
         layer = http.HttpLayer(self, self.mode)
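
The safe_send_body path in SafeH2Connection holds a lock when the check runs, which is why the old release-then-raise sequence becomes raise_zombie(self.lock.release): the optional pre_command callback runs just before the exception is raised. A small usage sketch, reusing the illustrative DummyStream and Http2ProtocolException from the sketch above (the zombie value is arbitrary):

import threading

lock = threading.Lock()
stream = DummyStream()
stream.zombie = 1471334439.0  # any non-None value marks the stream as dead

lock.acquire()
try:
    # Mirrors raise_zombie(self.lock.release) in safe_send_body: the
    # callback releases the lock just before the exception propagates.
    stream.raise_zombie(lock.release)
except Http2ProtocolException:
    assert not lock.locked()  # the callback released the lock first

The test-file change below is unrelated to zombie handling: it renames the loop variable id to stream_id so it no longer shadows Python's built-in id().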


@@ -849,15 +849,15 @@ class TestMaxConcurrentStreams(_Http2Test):
     def test_max_concurrent_streams(self):
         client, h2_conn = self._setup_connection()
         new_streams = [1, 3, 5, 7, 9, 11]
-        for id in new_streams:
+        for stream_id in new_streams:
             # this will exceed MAX_CONCURRENT_STREAMS on the server connection
             # and cause mitmproxy to throttle stream creation to the server
-            self._send_request(client.wfile, h2_conn, stream_id=id, headers=[
+            self._send_request(client.wfile, h2_conn, stream_id=stream_id, headers=[
                 (':authority', "127.0.0.1:{}".format(self.server.server.address.port)),
                 (':method', 'GET'),
                 (':scheme', 'https'),
                 (':path', '/'),
-                ('X-Stream-ID', str(id)),
+                ('X-Stream-ID', str(stream_id)),
             ])
 
         ended_streams = 0