http2: general improvements

This commit is contained in:
Thomas Kriechbaumer 2015-06-11 15:38:32 +02:00
parent eeaed93a83
commit 8ea157775d
2 changed files with 78 additions and 26 deletions

View File

@ -26,12 +26,13 @@ class HTTP2Protocol(object):
)
# "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"
CLIENT_CONNECTION_PREFACE = '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'
CLIENT_CONNECTION_PREFACE = '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex')
ALPN_PROTO_H2 = 'h2'
def __init__(self, tcp_client):
self.tcp_client = tcp_client
def __init__(self, tcp_handler, is_server=False):
self.tcp_handler = tcp_handler
self.is_server = is_server
self.http2_settings = frame.HTTP2_DEFAULT_SETTINGS.copy()
self.current_stream_id = None
@ -39,28 +40,39 @@ class HTTP2Protocol(object):
self.decoder = Decoder()
def check_alpn(self):
alp = self.tcp_client.get_alpn_proto_negotiated()
alp = self.tcp_handler.get_alpn_proto_negotiated()
if alp != self.ALPN_PROTO_H2:
raise NotImplementedError(
"HTTP2Protocol can not handle unknown ALP: %s" % alp)
return True
def perform_connection_preface(self):
self.tcp_client.wfile.write(
bytes(self.CLIENT_CONNECTION_PREFACE.decode('hex')))
self.send_frame(frame.SettingsFrame(state=self))
# read server settings frame
frm = frame.Frame.from_file(self.tcp_client.rfile, self)
def _receive_settings(self):
frm = frame.Frame.from_file(self.tcp_handler.rfile, self)
assert isinstance(frm, frame.SettingsFrame)
self._apply_settings(frm.settings)
# read setting ACK frame
def _read_settings_ack(self):
settings_ack_frame = self.read_frame()
assert isinstance(settings_ack_frame, frame.SettingsFrame)
assert settings_ack_frame.flags & frame.Frame.FLAG_ACK
assert len(settings_ack_frame.settings) == 0
def perform_server_connection_preface(self):
magic_length = len(self.CLIENT_CONNECTION_PREFACE)
magic = self.tcp_handler.rfile.safe_read(magic_length)
assert magic == self.CLIENT_CONNECTION_PREFACE
self.send_frame(frame.SettingsFrame(state=self))
self._receive_settings()
self._read_settings_ack()
def perform_client_connection_preface(self):
self.tcp_handler.wfile.write(self.CLIENT_CONNECTION_PREFACE)
self.send_frame(frame.SettingsFrame(state=self))
self._receive_settings()
self._read_settings_ack()
def next_stream_id(self):
if self.current_stream_id is None:
self.current_stream_id = 1
@ -70,11 +82,11 @@ class HTTP2Protocol(object):
def send_frame(self, frame):
raw_bytes = frame.to_bytes()
self.tcp_client.wfile.write(raw_bytes)
self.tcp_client.wfile.flush()
self.tcp_handler.wfile.write(raw_bytes)
self.tcp_handler.wfile.flush()
def read_frame(self):
frm = frame.Frame.from_file(self.tcp_client.rfile, self)
frm = frame.Frame.from_file(self.tcp_handler.rfile, self)
if isinstance(frm, frame.SettingsFrame):
self._apply_settings(frm.settings)
@ -139,25 +151,36 @@ class HTTP2Protocol(object):
self._create_body(body, stream_id)))
def read_response(self):
headers, body = self._receive_transmission()
return headers[':status'], headers, body
def read_request(self):
return self._receive_transmission()
def _receive_transmission(self):
body_expected = True
header_block_fragment = b''
body = b''
while True:
frm = self.read_frame()
if isinstance(frm, frame.HeadersFrame):
if isinstance(frm, frame.HeadersFrame) or isinstance(frm, frame.ContinuationFrame):
header_block_fragment += frm.header_block_fragment
if frm.flags | frame.Frame.FLAG_END_HEADERS:
if frm.flags & frame.Frame.FLAG_END_HEADERS:
if frm.flags & frame.Frame.FLAG_END_STREAM:
body_expected = False
break
while True:
while body_expected:
frm = self.read_frame()
if isinstance(frm, frame.DataFrame):
body += frm.payload
if frm.flags | frame.Frame.FLAG_END_STREAM:
if frm.flags & frame.Frame.FLAG_END_STREAM:
break
headers = {}
for header, value in self.decoder.decode(header_block_fragment):
headers[header] = value
return headers[':status'], headers, body
return headers, body

View File

@ -50,7 +50,39 @@ class TestCheckALPNMismatch(test.ServerTestBase):
tutils.raises(NotImplementedError, protocol.check_alpn)
class TestPerformConnectionPreface(test.ServerTestBase):
class TestPerformServerConnectionPreface(test.ServerTestBase):
class handler(tcp.BaseHandler):
def handle(self):
# send magic
self.wfile.write(\
'505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex'))
self.wfile.flush()
# send empty settings frame
self.wfile.write('000000040000000000'.decode('hex'))
self.wfile.flush()
# check empty settings frame
assert self.rfile.read(9) ==\
'000000040000000000'.decode('hex')
# check settings acknowledgement
assert self.rfile.read(9) == \
'000000040100000000'.decode('hex')
# send settings acknowledgement
self.wfile.write('000000040100000000'.decode('hex'))
self.wfile.flush()
def test_perform_server_connection_preface(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
protocol = http2.HTTP2Protocol(c)
protocol.perform_server_connection_preface()
class TestPerformClientConnectionPreface(test.ServerTestBase):
class handler(tcp.BaseHandler):
def handle(self):
@ -74,14 +106,11 @@ class TestPerformConnectionPreface(test.ServerTestBase):
self.wfile.write('000000040100000000'.decode('hex'))
self.wfile.flush()
ssl = True
def test_perform_connection_preface(self):
def test_perform_client_connection_preface(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
c.convert_to_ssl()
protocol = http2.HTTP2Protocol(c)
protocol.perform_connection_preface()
protocol.perform_client_connection_preface()
class TestStreamIds():