http2: general improvements
This commit is contained in:
parent
eeaed93a83
commit
8ea157775d
|
@ -26,12 +26,13 @@ class HTTP2Protocol(object):
|
|||
)
|
||||
|
||||
# "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"
|
||||
CLIENT_CONNECTION_PREFACE = '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'
|
||||
CLIENT_CONNECTION_PREFACE = '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex')
|
||||
|
||||
ALPN_PROTO_H2 = 'h2'
|
||||
|
||||
def __init__(self, tcp_client):
|
||||
self.tcp_client = tcp_client
|
||||
def __init__(self, tcp_handler, is_server=False):
|
||||
self.tcp_handler = tcp_handler
|
||||
self.is_server = is_server
|
||||
|
||||
self.http2_settings = frame.HTTP2_DEFAULT_SETTINGS.copy()
|
||||
self.current_stream_id = None
|
||||
|
@ -39,28 +40,39 @@ class HTTP2Protocol(object):
|
|||
self.decoder = Decoder()
|
||||
|
||||
def check_alpn(self):
|
||||
alp = self.tcp_client.get_alpn_proto_negotiated()
|
||||
alp = self.tcp_handler.get_alpn_proto_negotiated()
|
||||
if alp != self.ALPN_PROTO_H2:
|
||||
raise NotImplementedError(
|
||||
"HTTP2Protocol can not handle unknown ALP: %s" % alp)
|
||||
return True
|
||||
|
||||
def perform_connection_preface(self):
|
||||
self.tcp_client.wfile.write(
|
||||
bytes(self.CLIENT_CONNECTION_PREFACE.decode('hex')))
|
||||
self.send_frame(frame.SettingsFrame(state=self))
|
||||
|
||||
# read server settings frame
|
||||
frm = frame.Frame.from_file(self.tcp_client.rfile, self)
|
||||
def _receive_settings(self):
|
||||
frm = frame.Frame.from_file(self.tcp_handler.rfile, self)
|
||||
assert isinstance(frm, frame.SettingsFrame)
|
||||
self._apply_settings(frm.settings)
|
||||
|
||||
# read setting ACK frame
|
||||
def _read_settings_ack(self):
|
||||
settings_ack_frame = self.read_frame()
|
||||
assert isinstance(settings_ack_frame, frame.SettingsFrame)
|
||||
assert settings_ack_frame.flags & frame.Frame.FLAG_ACK
|
||||
assert len(settings_ack_frame.settings) == 0
|
||||
|
||||
def perform_server_connection_preface(self):
|
||||
magic_length = len(self.CLIENT_CONNECTION_PREFACE)
|
||||
magic = self.tcp_handler.rfile.safe_read(magic_length)
|
||||
assert magic == self.CLIENT_CONNECTION_PREFACE
|
||||
|
||||
self.send_frame(frame.SettingsFrame(state=self))
|
||||
self._receive_settings()
|
||||
self._read_settings_ack()
|
||||
|
||||
def perform_client_connection_preface(self):
|
||||
self.tcp_handler.wfile.write(self.CLIENT_CONNECTION_PREFACE)
|
||||
|
||||
self.send_frame(frame.SettingsFrame(state=self))
|
||||
self._receive_settings()
|
||||
self._read_settings_ack()
|
||||
|
||||
def next_stream_id(self):
|
||||
if self.current_stream_id is None:
|
||||
self.current_stream_id = 1
|
||||
|
@ -70,11 +82,11 @@ class HTTP2Protocol(object):
|
|||
|
||||
def send_frame(self, frame):
|
||||
raw_bytes = frame.to_bytes()
|
||||
self.tcp_client.wfile.write(raw_bytes)
|
||||
self.tcp_client.wfile.flush()
|
||||
self.tcp_handler.wfile.write(raw_bytes)
|
||||
self.tcp_handler.wfile.flush()
|
||||
|
||||
def read_frame(self):
|
||||
frm = frame.Frame.from_file(self.tcp_client.rfile, self)
|
||||
frm = frame.Frame.from_file(self.tcp_handler.rfile, self)
|
||||
if isinstance(frm, frame.SettingsFrame):
|
||||
self._apply_settings(frm.settings)
|
||||
|
||||
|
@ -139,25 +151,36 @@ class HTTP2Protocol(object):
|
|||
self._create_body(body, stream_id)))
|
||||
|
||||
def read_response(self):
|
||||
headers, body = self._receive_transmission()
|
||||
return headers[':status'], headers, body
|
||||
|
||||
def read_request(self):
|
||||
return self._receive_transmission()
|
||||
|
||||
def _receive_transmission(self):
|
||||
body_expected = True
|
||||
|
||||
header_block_fragment = b''
|
||||
body = b''
|
||||
|
||||
while True:
|
||||
frm = self.read_frame()
|
||||
if isinstance(frm, frame.HeadersFrame):
|
||||
if isinstance(frm, frame.HeadersFrame) or isinstance(frm, frame.ContinuationFrame):
|
||||
header_block_fragment += frm.header_block_fragment
|
||||
if frm.flags | frame.Frame.FLAG_END_HEADERS:
|
||||
if frm.flags & frame.Frame.FLAG_END_HEADERS:
|
||||
if frm.flags & frame.Frame.FLAG_END_STREAM:
|
||||
body_expected = False
|
||||
break
|
||||
|
||||
while True:
|
||||
while body_expected:
|
||||
frm = self.read_frame()
|
||||
if isinstance(frm, frame.DataFrame):
|
||||
body += frm.payload
|
||||
if frm.flags | frame.Frame.FLAG_END_STREAM:
|
||||
if frm.flags & frame.Frame.FLAG_END_STREAM:
|
||||
break
|
||||
|
||||
headers = {}
|
||||
for header, value in self.decoder.decode(header_block_fragment):
|
||||
headers[header] = value
|
||||
|
||||
return headers[':status'], headers, body
|
||||
return headers, body
|
||||
|
|
|
@ -50,7 +50,39 @@ class TestCheckALPNMismatch(test.ServerTestBase):
|
|||
tutils.raises(NotImplementedError, protocol.check_alpn)
|
||||
|
||||
|
||||
class TestPerformConnectionPreface(test.ServerTestBase):
|
||||
class TestPerformServerConnectionPreface(test.ServerTestBase):
|
||||
class handler(tcp.BaseHandler):
|
||||
|
||||
def handle(self):
|
||||
# send magic
|
||||
self.wfile.write(\
|
||||
'505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex'))
|
||||
self.wfile.flush()
|
||||
|
||||
# send empty settings frame
|
||||
self.wfile.write('000000040000000000'.decode('hex'))
|
||||
self.wfile.flush()
|
||||
|
||||
# check empty settings frame
|
||||
assert self.rfile.read(9) ==\
|
||||
'000000040000000000'.decode('hex')
|
||||
|
||||
# check settings acknowledgement
|
||||
assert self.rfile.read(9) == \
|
||||
'000000040100000000'.decode('hex')
|
||||
|
||||
# send settings acknowledgement
|
||||
self.wfile.write('000000040100000000'.decode('hex'))
|
||||
self.wfile.flush()
|
||||
|
||||
def test_perform_server_connection_preface(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
protocol = http2.HTTP2Protocol(c)
|
||||
protocol.perform_server_connection_preface()
|
||||
|
||||
|
||||
class TestPerformClientConnectionPreface(test.ServerTestBase):
|
||||
class handler(tcp.BaseHandler):
|
||||
|
||||
def handle(self):
|
||||
|
@ -74,14 +106,11 @@ class TestPerformConnectionPreface(test.ServerTestBase):
|
|||
self.wfile.write('000000040100000000'.decode('hex'))
|
||||
self.wfile.flush()
|
||||
|
||||
ssl = True
|
||||
|
||||
def test_perform_connection_preface(self):
|
||||
def test_perform_client_connection_preface(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
c.convert_to_ssl()
|
||||
protocol = http2.HTTP2Protocol(c)
|
||||
protocol.perform_connection_preface()
|
||||
protocol.perform_client_connection_preface()
|
||||
|
||||
|
||||
class TestStreamIds():
|
||||
|
|
Loading…
Reference in New Issue