diff --git a/docs/ci.sh b/docs/ci.sh index 159d0b50d..3e2e9bc76 100755 --- a/docs/ci.sh +++ b/docs/ci.sh @@ -2,7 +2,6 @@ set -o errexit set -o pipefail -set -o nounset # set -o xtrace # This script gets run from CI to render and upload docs for the master branch. diff --git a/examples/addons/events-websocket-specific.py b/examples/addons/events-websocket-specific.py index 17cbd0795..1a499cd97 100644 --- a/examples/addons/events-websocket-specific.py +++ b/examples/addons/events-websocket-specific.py @@ -4,7 +4,7 @@ import mitmproxy.websocket class Events: - # Websocket lifecycle + # WebSocket lifecycle def websocket_handshake(self, flow: mitmproxy.http.HTTPFlow): """ Called when a client wants to establish a WebSocket connection. The @@ -15,7 +15,7 @@ class Events: def websocket_start(self, flow: mitmproxy.websocket.WebSocketFlow): """ - A websocket connection has commenced. + A WebSocket connection has commenced. """ def websocket_message(self, flow: mitmproxy.websocket.WebSocketFlow): @@ -28,10 +28,10 @@ class Events: def websocket_error(self, flow: mitmproxy.websocket.WebSocketFlow): """ - A websocket connection has had an error. + A WebSocket connection has had an error. """ def websocket_end(self, flow: mitmproxy.websocket.WebSocketFlow): """ - A websocket connection has ended. + A WebSocket connection has ended. """ diff --git a/mitmproxy/flowfilter.py b/mitmproxy/flowfilter.py index e8626c629..9767d1a6e 100644 --- a/mitmproxy/flowfilter.py +++ b/mitmproxy/flowfilter.py @@ -39,11 +39,8 @@ from typing import Callable, ClassVar, Optional, Sequence, Type import pyparsing as pp -from mitmproxy import flow -from mitmproxy import http -from mitmproxy import tcp -from mitmproxy import websocket -from mitmproxy.net import websocket_utils +from mitmproxy import flow, http, tcp, websocket +from mitmproxy.net.websocket import check_handshake def only(*types): @@ -110,7 +107,7 @@ class FWebSocket(_Action): @only(http.HTTPFlow, websocket.WebSocketFlow) def __call__(self, f): m = ( - (isinstance(f, http.HTTPFlow) and f.request and websocket_utils.check_handshake(f.request.headers)) + (isinstance(f, http.HTTPFlow) and f.request and check_handshake(f.request.headers)) or isinstance(f, websocket.WebSocketFlow) ) return m diff --git a/mitmproxy/net/http/http2.py b/mitmproxy/net/http/http2.py new file mode 100644 index 000000000..7b6f5a8d4 --- /dev/null +++ b/mitmproxy/net/http/http2.py @@ -0,0 +1,28 @@ +import codecs + +from hyperframe.frame import Frame + +from mitmproxy import exceptions + + +def read_frame(rfile, parse=True): + """ + Reads a full HTTP/2 frame from a file-like object. + + Returns a parsed frame and the consumed bytes. + """ + header = rfile.safe_read(9) + length = int(codecs.encode(header[:3], 'hex_codec'), 16) + + if length == 4740180: + raise exceptions.HttpException("Length field looks more like HTTP/1.1:\n{}".format(rfile.read(-1))) + + body = rfile.safe_read(length) + + if parse: + frame, _ = Frame.parse_frame_header(header) + frame.parse_body(memoryview(body)) + else: + frame = None + + return frame, b''.join([header, body]) diff --git a/mitmproxy/net/http/http2/__init__.py b/mitmproxy/net/http/http2/__init__.py deleted file mode 100644 index 7027006be..000000000 --- a/mitmproxy/net/http/http2/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -from mitmproxy.net.http.http2.framereader import read_raw_frame, parse_frame -from mitmproxy.net.http.http2.utils import parse_headers - -__all__ = [ - "read_raw_frame", - "parse_frame", - "parse_headers", -] diff --git a/mitmproxy/net/http/http2/framereader.py b/mitmproxy/net/http/http2/framereader.py deleted file mode 100644 index 777d247d9..000000000 --- a/mitmproxy/net/http/http2/framereader.py +++ /dev/null @@ -1,25 +0,0 @@ -import codecs - -import hyperframe.frame -from mitmproxy import exceptions - - -def read_raw_frame(rfile): - header = rfile.safe_read(9) - length = int(codecs.encode(header[:3], 'hex_codec'), 16) - - if length == 4740180: - raise exceptions.HttpException("Length field looks more like HTTP/1.1:\n{}".format(rfile.read(-1))) - - body = rfile.safe_read(length) - return [header, body] - - -def parse_frame(header, body=None): - if body is None: - body = header[9:] - header = header[:9] - - frame, _ = hyperframe.frame.Frame.parse_frame_header(header) - frame.parse_body(memoryview(body)) - return frame diff --git a/mitmproxy/net/http/http2/utils.py b/mitmproxy/net/http/http2/utils.py deleted file mode 100644 index 4a553d8df..000000000 --- a/mitmproxy/net/http/http2/utils.py +++ /dev/null @@ -1,37 +0,0 @@ -from mitmproxy.net.http import url - - -def parse_headers(headers): - authority = headers.get(':authority', '').encode() - method = headers.get(':method', 'GET').encode() - scheme = headers.get(':scheme', 'https').encode() - path = headers.get(':path', '/').encode() - - headers.pop(":method", None) - headers.pop(":scheme", None) - headers.pop(":path", None) - - host = None - port = None - - if method == b'CONNECT': - raise NotImplementedError("CONNECT over HTTP/2 is not implemented.") - - if path == b'*' or path.startswith(b"/"): - first_line_format = "relative" - else: - first_line_format = "absolute" - scheme, host, port, _ = url.parse(path) - - if authority: - host, _, port = authority.partition(b':') - - if not host: - host = b'localhost' - - if not port: - port = 443 if scheme == b'https' else 80 - - port = int(port) - - return first_line_format, method, scheme, host, port, path diff --git a/mitmproxy/net/websocket_utils.py b/mitmproxy/net/websocket.py similarity index 78% rename from mitmproxy/net/websocket_utils.py rename to mitmproxy/net/websocket.py index f8608b07b..930b33d4c 100644 --- a/mitmproxy/net/websocket_utils.py +++ b/mitmproxy/net/websocket.py @@ -17,7 +17,13 @@ from mitmproxy.net import http from mitmproxy.utils import bits, strutils -def read_raw_frame(rfile): +def read_frame(rfile, parse=True): + """ + Reads a full WebSocket frame from a file-like object. + + Returns a parsed frame header, parsed frame, and the consumed bytes. + """ + consumed_bytes = b'' def consume(len): @@ -52,34 +58,36 @@ def read_raw_frame(rfile): masking_key = None masker = XorMaskerNull() - header = Header( - fin=fin, - rsv=RsvBits(rsv1, rsv2, rsv3), - opcode=opcode, - payload_len=payload_len, - masking_key=masking_key, - ) - masked_payload = consume(payload_len) - payload = masker.process(masked_payload) - frame = Frame( - opcode=opcode, - payload=payload, - frame_finished=fin, - message_finished=fin - ) + if parse: + header = Header( + fin=fin, + rsv=RsvBits(rsv1, rsv2, rsv3), + opcode=opcode, + payload_len=payload_len, + masking_key=masking_key, + ) + frame = Frame( + opcode=opcode, + payload=masker.process(masked_payload), + frame_finished=fin, + message_finished=fin + ) + else: + header = None + frame = None return header, frame, consumed_bytes def client_handshake_headers(version=None, key=None, protocol=None, extensions=None): """ - Create the headers for a valid HTTP upgrade request. If Key is not - specified, it is generated, and can be found in sec-websocket-key in - the returned header set. + Create the headers for a valid HTTP upgrade request. If Key is not + specified, it is generated, and can be found in sec-websocket-key in + the returned header set. - Returns an instance of http.Headers + Returns an instance of http.Headers """ if version is None: version = WEBSOCKET_VERSION @@ -100,9 +108,9 @@ def client_handshake_headers(version=None, key=None, protocol=None, extensions=N def server_handshake_headers(client_key, protocol=None, extensions=None): """ - The server response is a valid HTTP 101 response. + The server response is a valid HTTP 101 response. - Returns an instance of http.Headers + Returns an instance of http.Headers """ h = http.Headers( connection="upgrade", diff --git a/mitmproxy/proxy/protocol/__init__.py b/mitmproxy/proxy/protocol/__init__.py index 5860542a8..ccf85d006 100644 --- a/mitmproxy/proxy/protocol/__init__.py +++ b/mitmproxy/proxy/protocol/__init__.py @@ -15,7 +15,7 @@ mitmproxy connection may look as follows (outermost layer first): - Http1Layer - HttpLayer - TLSLayer - - WebsocketLayer (or TCPLayer) + - WebSocketLayer (or TCPLayer) Every layer acts as a read-only context for its inner layers (see :py:class:`Layer`). To communicate with an outer layer, a layer can use diff --git a/mitmproxy/proxy/protocol/http.py b/mitmproxy/proxy/protocol/http.py index 1b4e5502f..91b9212d2 100644 --- a/mitmproxy/proxy/protocol/http.py +++ b/mitmproxy/proxy/protocol/http.py @@ -10,7 +10,7 @@ from mitmproxy import http from mitmproxy import flow from mitmproxy.proxy.protocol import base from mitmproxy.proxy.protocol.websocket import WebSocketLayer -from mitmproxy.net import websocket_utils +from mitmproxy.net import websocket class _HttpTransmissionLayer(base.Layer): @@ -343,8 +343,8 @@ class HttpLayer(base.Layer): try: valid = ( - websocket_utils.check_handshake(request.headers) and - websocket_utils.check_client_version(request.headers) + websocket.check_handshake(request.headers) and + websocket.check_client_version(request.headers) ) if valid: f.metadata['websocket'] = True @@ -462,8 +462,8 @@ class HttpLayer(base.Layer): # received after e.g. a WebSocket upgrade request. # Check for WebSocket handshake is_websocket = ( - websocket_utils.check_handshake(f.request.headers) and - websocket_utils.check_handshake(f.response.headers) + websocket.check_handshake(f.request.headers) and + websocket.check_handshake(f.response.headers) ) if is_websocket and not self.config.options.websocket: self.log( diff --git a/mitmproxy/proxy/protocol/http2.py b/mitmproxy/proxy/protocol/http2.py index 787871e5d..f0547e426 100644 --- a/mitmproxy/proxy/protocol/http2.py +++ b/mitmproxy/proxy/protocol/http2.py @@ -360,7 +360,7 @@ class Http2Layer(base.Layer): with self.connections[source_conn].lock: try: - raw_frame = b''.join(http2.read_raw_frame(source_conn.rfile)) + _, consumed_bytes = http2.read_frame(source_conn.rfile) except: # read frame failed: connection closed self._kill_all_streams() @@ -370,7 +370,7 @@ class Http2Layer(base.Layer): self.log("HTTP/2 connection entered closed state already", "debug") return - incoming_events = self.connections[source_conn].receive_data(raw_frame) + incoming_events = self.connections[source_conn].receive_data(consumed_bytes) source_conn.send(self.connections[source_conn].data_to_send()) for event in incoming_events: diff --git a/mitmproxy/proxy/protocol/websocket.py b/mitmproxy/proxy/protocol/websocket.py index e7697169c..2b9742357 100644 --- a/mitmproxy/proxy/protocol/websocket.py +++ b/mitmproxy/proxy/protocol/websocket.py @@ -10,7 +10,7 @@ from wsproto.extensions import PerMessageDeflate from mitmproxy import exceptions, flow from mitmproxy.proxy.protocol import base -from mitmproxy.net import tcp, websocket_utils +from mitmproxy.net import tcp, websocket from mitmproxy.websocket import WebSocketFlow, WebSocketMessage from mitmproxy.utils import strutils @@ -200,7 +200,7 @@ class WebSocketLayer(base.Layer): other_conn = self.server_conn if conn == self.client_conn.connection else self.client_conn is_server = (source_conn == self.server_conn) - header, frame, consumed_bytes = websocket_utils.read_raw_frame(source_conn.rfile) + header, frame, consumed_bytes = websocket.read_frame(source_conn.rfile) self.log( "WebSocket Frame from {}: {}, {}".format( "server" if is_server else "client", diff --git a/mitmproxy/websocket.py b/mitmproxy/websocket.py index 124204e4d..f7bc7d397 100644 --- a/mitmproxy/websocket.py +++ b/mitmproxy/websocket.py @@ -6,7 +6,7 @@ from wsproto.frame_protocol import CloseReason from wsproto.frame_protocol import Opcode from mitmproxy import flow -from mitmproxy.net import websocket_utils +from mitmproxy.net import websocket from mitmproxy.coretypes import serializable from mitmproxy.utils import strutils, human @@ -58,7 +58,7 @@ class WebSocketMessage(serializable.Serializable): class WebSocketFlow(flow.Flow): """ - A WebSocketFlow is a simplified representation of a Websocket connection. + A WebSocketFlow is a simplified representation of a WebSocket connection. """ def __init__(self, client_conn, server_conn, handshake_flow, live=None): @@ -85,12 +85,12 @@ class WebSocketFlow(flow.Flow): self._inject_messages_server = queue.Queue(maxsize=1) if handshake_flow: - self.client_key = websocket_utils.get_client_key(handshake_flow.request.headers) - self.client_protocol = websocket_utils.get_protocol(handshake_flow.request.headers) - self.client_extensions = websocket_utils.get_extensions(handshake_flow.request.headers) - self.server_accept = websocket_utils.get_server_accept(handshake_flow.response.headers) - self.server_protocol = websocket_utils.get_protocol(handshake_flow.response.headers) - self.server_extensions = websocket_utils.get_extensions(handshake_flow.response.headers) + self.client_key = websocket.get_client_key(handshake_flow.request.headers) + self.client_protocol = websocket.get_protocol(handshake_flow.request.headers) + self.client_extensions = websocket.get_extensions(handshake_flow.request.headers) + self.server_accept = websocket.get_server_accept(handshake_flow.response.headers) + self.server_protocol = websocket.get_protocol(handshake_flow.response.headers) + self.server_extensions = websocket.get_extensions(handshake_flow.response.headers) else: self.client_key = '' self.client_protocol = '' diff --git a/pathod/language/http.py b/pathod/language/http.py index d496defc5..dc6de11e9 100644 --- a/pathod/language/http.py +++ b/pathod/language/http.py @@ -2,7 +2,7 @@ import abc import pyparsing as pp -from mitmproxy.net import websocket_utils +from mitmproxy.net import websocket from mitmproxy.net.http import status_codes, url, user_agents from . import base, exceptions, actions, message @@ -199,7 +199,7 @@ class Response(_HTTPMessage): 1, StatusCode(101) ) - headers = websocket_utils.server_handshake_headers( + headers = websocket.server_handshake_headers( settings.websocket_key ) for i in headers.fields: @@ -311,7 +311,7 @@ class Request(_HTTPMessage): 1, Method("get") ) - for i in websocket_utils.client_handshake_headers().fields: + for i in websocket.client_handshake_headers().fields: if not get_header(i[0], self.headers): tokens.append( Header( diff --git a/pathod/pathod.py b/pathod/pathod.py index 80ad0c31c..4bc307801 100644 --- a/pathod/pathod.py +++ b/pathod/pathod.py @@ -3,18 +3,14 @@ import logging import os import sys import threading -from mitmproxy.net import tcp, tls -from mitmproxy import certs as mcerts -from mitmproxy.net import websocket_utils -from mitmproxy import version import urllib -from mitmproxy import exceptions -from pathod import language -from pathod import utils -from pathod import log -from pathod import protocols import typing # noqa +from mitmproxy import certs as mcerts, exceptions, version +from mitmproxy.net import tcp, tls, websocket + +from pathod import language, utils, log, protocols + DEFAULT_CERT_DOMAIN = b"pathod.net" CONFDIR = "~/.mitmproxy" @@ -177,8 +173,8 @@ class PathodHandler(tcp.BaseHandler): m = utils.MemBool() - valid_websocket_handshake = websocket_utils.check_handshake(headers) - self.settings.websocket_key = websocket_utils.get_client_key(headers) + valid_websocket_handshake = websocket.check_handshake(headers) + self.settings.websocket_key = websocket.get_client_key(headers) # If this is a websocket initiation, we respond with a proper # server response, unless over-ridden. diff --git a/pathod/protocols/http2.py b/pathod/protocols/http2.py index c258ac09f..966a0e460 100644 --- a/pathod/protocols/http2.py +++ b/pathod/protocols/http2.py @@ -260,7 +260,8 @@ class HTTP2StateProtocol: def read_frame(self, hide=False): while True: - frm = http2.parse_frame(*http2.read_raw_frame(self.tcp_handler.rfile)) + frm, _ = http2.read_frame(self.tcp_handler.rfile) + if not hide and self.dump_frames: # pragma: no cover print("<< " + repr(frm)) diff --git a/test/mitmproxy/net/http/http2/__init__.py b/test/mitmproxy/net/http/http2/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/test/mitmproxy/net/http/http2/test_framereader.py b/test/mitmproxy/net/http/http2/test_framereader.py deleted file mode 100644 index 485ba69fc..000000000 --- a/test/mitmproxy/net/http/http2/test_framereader.py +++ /dev/null @@ -1,41 +0,0 @@ -import pytest -import codecs -from io import BytesIO -import hyperframe.frame - -from mitmproxy import exceptions -from mitmproxy.net.http.http2 import read_raw_frame, parse_frame - - -def test_read_raw_frame(): - raw = codecs.decode('000006000101234567666f6f626172', 'hex_codec') - bio = BytesIO(raw) - bio.safe_read = bio.read - - header, body = read_raw_frame(bio) - assert header - assert body - - -def test_read_raw_frame_failed(): - raw = codecs.decode('485454000000000000', 'hex_codec') - bio = BytesIO(raw) - bio.safe_read = bio.read - - with pytest.raises(exceptions.HttpException): - read_raw_frame(bio) - - -def test_parse_frame(): - f = parse_frame( - codecs.decode('000006000101234567', 'hex_codec'), - codecs.decode('666f6f626172', 'hex_codec') - ) - assert isinstance(f, hyperframe.frame.Frame) - - -def test_parse_frame_combined(): - f = parse_frame( - codecs.decode('000006000101234567666f6f626172', 'hex_codec'), - ) - assert isinstance(f, hyperframe.frame.Frame) diff --git a/test/mitmproxy/net/http/http2/test_utils.py b/test/mitmproxy/net/http/http2/test_utils.py deleted file mode 100644 index 41d49b6f6..000000000 --- a/test/mitmproxy/net/http/http2/test_utils.py +++ /dev/null @@ -1,70 +0,0 @@ -import pytest - -from mitmproxy.net.http.http2 import parse_headers - - -class TestHttp2ParseHeaders: - - def test_relative(self): - h = dict([ - (':authority', "127.0.0.1:1234"), - (':method', 'GET'), - (':scheme', 'https'), - (':path', '/'), - ]) - first_line_format, method, scheme, host, port, path = parse_headers(h) - assert first_line_format == 'relative' - assert method == b'GET' - assert scheme == b'https' - assert host == b'127.0.0.1' - assert port == 1234 - assert path == b'/' - - def test_absolute(self): - h = dict([ - (':authority', "127.0.0.1:1234"), - (':method', 'GET'), - (':scheme', 'https'), - (':path', 'https://127.0.0.1:4321'), - ]) - first_line_format, method, scheme, host, port, path = parse_headers(h) - assert first_line_format == 'absolute' - assert method == b'GET' - assert scheme == b'https' - assert host == b'127.0.0.1' - assert port == 1234 - assert path == b'https://127.0.0.1:4321' - - @pytest.mark.parametrize("scheme, expected_port", [ - ('http', 80), - ('https', 443), - ]) - def test_without_port(self, scheme, expected_port): - h = dict([ - (':authority', "127.0.0.1"), - (':method', 'GET'), - (':scheme', scheme), - (':path', '/'), - ]) - _, _, _, _, port, _ = parse_headers(h) - assert port == expected_port - - def test_without_authority(self): - h = dict([ - (':method', 'GET'), - (':scheme', 'https'), - (':path', '/'), - ]) - _, _, _, host, _, _ = parse_headers(h) - assert host == b'localhost' - - def test_connect(self): - h = dict([ - (':authority', "127.0.0.1"), - (':method', 'CONNECT'), - (':scheme', 'https'), - (':path', '/'), - ]) - - with pytest.raises(NotImplementedError): - parse_headers(h) diff --git a/test/mitmproxy/net/http/test_http2.py b/test/mitmproxy/net/http/test_http2.py new file mode 100644 index 000000000..bbbf98aaf --- /dev/null +++ b/test/mitmproxy/net/http/test_http2.py @@ -0,0 +1,37 @@ +import pytest +import codecs +from io import BytesIO + +import hyperframe + +from mitmproxy import exceptions +from mitmproxy.net.http import http2 + + +def test_read_frame(): + raw = codecs.decode('000006000101234567666f6f626172', 'hex_codec') + bio = BytesIO(raw) + bio.safe_read = bio.read + + frame, consumed_bytes = http2.read_frame(bio) + assert isinstance(frame, hyperframe.frame.DataFrame) + assert frame.stream_id == 19088743 + assert 'END_STREAM' in frame.flags + assert len(frame.flags) == 1 + assert frame.data == b'foobar' + assert consumed_bytes == raw + + bio = BytesIO(raw) + bio.safe_read = bio.read + frame, consumed_bytes = http2.read_frame(bio, False) + assert frame is None + assert consumed_bytes == raw + + +def test_read_frame_failed(): + raw = codecs.decode('485454000000000000', 'hex_codec') + bio = BytesIO(raw) + bio.safe_read = bio.read + + with pytest.raises(exceptions.HttpException): + _ = http2.read_frame(bio, False) diff --git a/test/mitmproxy/net/test_websocket_utils.py b/test/mitmproxy/net/test_websocket.py similarity index 60% rename from test/mitmproxy/net/test_websocket_utils.py rename to test/mitmproxy/net/test_websocket.py index 3c0dbbe4a..c38f9375d 100644 --- a/test/mitmproxy/net/test_websocket_utils.py +++ b/test/mitmproxy/net/test_websocket.py @@ -4,8 +4,7 @@ from unittest import mock from wsproto.frame_protocol import Opcode, RsvBits, Header, Frame -from mitmproxy.net.http import Headers -from mitmproxy.net import websocket_utils +from mitmproxy.net import http, websocket @pytest.mark.parametrize("input,masking_key,payload_length", [ @@ -14,11 +13,11 @@ from mitmproxy.net import websocket_utils (b'\x01~\x04\x00server-foobar', None, 1024), (b'\x01\x7f\x00\x00\x00\x00\x00\x02\x00\x00server-foobar', None, 131072), ]) -def test_read_raw_frame(input, masking_key, payload_length): +def test_read_frame(input, masking_key, payload_length): bio = BytesIO(input) bio.safe_read = bio.read - header, frame, consumed_bytes = websocket_utils.read_raw_frame(bio) + header, frame, consumed_bytes = websocket.read_frame(bio) assert header == \ Header( fin=False, @@ -36,18 +35,25 @@ def test_read_raw_frame(input, masking_key, payload_length): ) assert consumed_bytes == input + bio = BytesIO(input) + bio.safe_read = bio.read + header, frame, consumed_bytes = websocket.read_frame(bio, False) + assert header is None + assert frame is None + assert consumed_bytes == input + @mock.patch('os.urandom', return_value=b'pumpkinspumpkins') def test_client_handshake_headers(_): - assert websocket_utils.client_handshake_headers() == \ - Headers([ + assert websocket.client_handshake_headers() == \ + http.Headers([ (b'connection', b'upgrade'), (b'upgrade', b'websocket'), (b'sec-websocket-version', b'13'), (b'sec-websocket-key', b'cHVtcGtpbnNwdW1wa2lucw=='), ]) - assert websocket_utils.client_handshake_headers(b"13", b"foobar", b"foo", b"bar") == \ - Headers([ + assert websocket.client_handshake_headers(b"13", b"foobar", b"foo", b"bar") == \ + http.Headers([ (b'connection', b'upgrade'), (b'upgrade', b'websocket'), (b'sec-websocket-version', b'13'), @@ -58,8 +64,8 @@ def test_client_handshake_headers(_): def test_server_handshake_headers(): - assert websocket_utils.server_handshake_headers("foobar", "foo", "bar") == \ - Headers([ + assert websocket.server_handshake_headers("foobar", "foo", "bar") == \ + http.Headers([ (b'connection', b'upgrade'), (b'upgrade', b'websocket'), (b'sec-websocket-accept', b'AzhRPA4TNwR6I/riJheN0TfR7+I='), @@ -69,17 +75,17 @@ def test_server_handshake_headers(): def test_check_handshake(): - assert not websocket_utils.check_handshake({ + assert not websocket.check_handshake({ "connection": "upgrade", "upgrade": "webFOOsocket", "sec-websocket-key": "foo", }) - assert websocket_utils.check_handshake({ + assert websocket.check_handshake({ "connection": "upgrade", "upgrade": "websocket", "sec-websocket-key": "foo", }) - assert websocket_utils.check_handshake({ + assert websocket.check_handshake({ "connection": "upgrade", "upgrade": "websocket", "sec-websocket-accept": "bar", @@ -87,30 +93,30 @@ def test_check_handshake(): def test_create_server_nonce(): - assert websocket_utils.create_server_nonce(b"foobar") == b"AzhRPA4TNwR6I/riJheN0TfR7+I=" + assert websocket.create_server_nonce(b"foobar") == b"AzhRPA4TNwR6I/riJheN0TfR7+I=" def test_check_client_version(): - assert not websocket_utils.check_client_version({}) - assert not websocket_utils.check_client_version({"sec-websocket-version": b"42"}) - assert websocket_utils.check_client_version({"sec-websocket-version": b"13"}) + assert not websocket.check_client_version({}) + assert not websocket.check_client_version({"sec-websocket-version": b"42"}) + assert websocket.check_client_version({"sec-websocket-version": b"13"}) def test_get_extensions(): - assert websocket_utils.get_extensions({}) is None - assert websocket_utils.get_extensions({"sec-websocket-extensions": "foo"}) == "foo" + assert websocket.get_extensions({}) is None + assert websocket.get_extensions({"sec-websocket-extensions": "foo"}) == "foo" def test_get_protocol(): - assert websocket_utils.get_protocol({}) is None - assert websocket_utils.get_protocol({"sec-websocket-protocol": "foo"}) == "foo" + assert websocket.get_protocol({}) is None + assert websocket.get_protocol({"sec-websocket-protocol": "foo"}) == "foo" def test_get_client_key(): - assert websocket_utils.get_client_key({}) is None - assert websocket_utils.get_client_key({"sec-websocket-key": "foo"}) == "foo" + assert websocket.get_client_key({}) is None + assert websocket.get_client_key({"sec-websocket-key": "foo"}) == "foo" def test_get_server_accept(): - assert websocket_utils.get_server_accept({}) is None - assert websocket_utils.get_server_accept({"sec-websocket-accept": "foo"}) == "foo" + assert websocket.get_server_accept({}) is None + assert websocket.get_server_accept({"sec-websocket-accept": "foo"}) == "foo" diff --git a/test/mitmproxy/proxy/protocol/test_http2.py b/test/mitmproxy/proxy/protocol/test_http2.py index 4db9f1adb..76fa180c4 100644 --- a/test/mitmproxy/proxy/protocol/test_http2.py +++ b/test/mitmproxy/proxy/protocol/test_http2.py @@ -56,8 +56,8 @@ class _Http2ServerBase(net_tservers.ServerTestBase): done = False while not done: try: - raw = b''.join(http2.read_raw_frame(self.rfile)) - events = h2_conn.receive_data(raw) + _, consumed_bytes = http2.read_frame(self.rfile, False) + events = h2_conn.receive_data(consumed_bytes) except exceptions.HttpException: print(traceback.format_exc()) assert False @@ -246,8 +246,8 @@ class TestSimpleRequestWithBody(_Http2Test): done = False while not done: try: - raw = b''.join(http2.read_raw_frame(self.client.rfile)) - events = h2_conn.receive_data(raw) + _, consumed_bytes = http2.read_frame(self.client.rfile, False) + events = h2_conn.receive_data(consumed_bytes) except exceptions.HttpException: print(traceback.format_exc()) assert False @@ -320,8 +320,8 @@ class TestSimpleRequestWithoutBody(_Http2Test): done = False while not done: try: - raw = b''.join(http2.read_raw_frame(self.client.rfile)) - events = h2_conn.receive_data(raw) + _, consumed_bytes = http2.read_frame(self.client.rfile, False) + events = h2_conn.receive_data(consumed_bytes) except exceptions.HttpException: print(traceback.format_exc()) assert False @@ -401,8 +401,8 @@ class TestRequestWithPriority(_Http2Test): done = False while not done: try: - raw = b''.join(http2.read_raw_frame(self.client.rfile)) - events = h2_conn.receive_data(raw) + _, consumed_bytes = http2.read_frame(self.client.rfile, False) + events = h2_conn.receive_data(consumed_bytes) except exceptions.HttpException: print(traceback.format_exc()) assert False @@ -489,8 +489,8 @@ class TestPriority(_Http2Test): done = False while not done: try: - raw = b''.join(http2.read_raw_frame(self.client.rfile)) - events = h2_conn.receive_data(raw) + _, consumed_bytes = http2.read_frame(self.client.rfile, False) + events = h2_conn.receive_data(consumed_bytes) except exceptions.HttpException: print(traceback.format_exc()) assert False @@ -539,8 +539,8 @@ class TestStreamResetFromServer(_Http2Test): done = False while not done: try: - raw = b''.join(http2.read_raw_frame(self.client.rfile)) - events = h2_conn.receive_data(raw) + _, consumed_bytes = http2.read_frame(self.client.rfile, False) + events = h2_conn.receive_data(consumed_bytes) except exceptions.HttpException: print(traceback.format_exc()) assert False @@ -606,8 +606,8 @@ class TestAllStreamResetsFromServer(_Http2Test): done = False while not done: try: - raw = b''.join(http2.read_raw_frame(self.client.rfile)) - events = h2_conn.receive_data(raw) + _, consumed_bytes = http2.read_frame(self.client.rfile, False) + events = h2_conn.receive_data(consumed_bytes) except exceptions.HttpException: print(traceback.format_exc()) assert False @@ -653,8 +653,8 @@ class TestBodySizeLimit(_Http2Test): done = False while not done: try: - raw = b''.join(http2.read_raw_frame(self.client.rfile)) - events = h2_conn.receive_data(raw) + _, consumed_bytes = http2.read_frame(self.client.rfile, False) + events = h2_conn.receive_data(consumed_bytes) except exceptions.HttpException: print(traceback.format_exc()) assert False @@ -750,8 +750,8 @@ class TestPushPromise(_Http2Test): responses = 0 while not done: try: - raw = b''.join(http2.read_raw_frame(self.client.rfile)) - events = h2_conn.receive_data(raw) + _, consumed_bytes = http2.read_frame(self.client.rfile, False) + events = h2_conn.receive_data(consumed_bytes) except exceptions.HttpException: print(traceback.format_exc()) assert False @@ -806,8 +806,8 @@ class TestPushPromise(_Http2Test): responses = 0 while not done: try: - raw = b''.join(http2.read_raw_frame(self.client.rfile)) - events = h2_conn.receive_data(raw) + _, consumed_bytes = http2.read_frame(self.client.rfile, False) + events = h2_conn.receive_data(consumed_bytes) except exceptions.HttpException: print(traceback.format_exc()) assert False @@ -865,8 +865,8 @@ class TestConnectionLost(_Http2Test): done = False while not done: try: - raw = b''.join(http2.read_raw_frame(self.client.rfile)) - h2_conn.receive_data(raw) + _, consumed_bytes = http2.read_frame(self.client.rfile, False) + h2_conn.receive_data(consumed_bytes) except exceptions.HttpException: print(traceback.format_exc()) assert False @@ -921,8 +921,8 @@ class TestMaxConcurrentStreams(_Http2Test): ended_streams = 0 while ended_streams != len(new_streams): try: - header, body = http2.read_raw_frame(self.client.rfile) - events = h2_conn.receive_data(b''.join([header, body])) + _, consumed_bytes = http2.read_frame(self.client.rfile, False) + events = h2_conn.receive_data(consumed_bytes) except: break self.client.wfile.write(h2_conn.data_to_send()) @@ -966,8 +966,8 @@ class TestConnectionTerminated(_Http2Test): connection_terminated_event = None while not done: try: - raw = b''.join(http2.read_raw_frame(self.client.rfile)) - events = h2_conn.receive_data(raw) + _, consumed_bytes = http2.read_frame(self.client.rfile, False) + events = h2_conn.receive_data(consumed_bytes) for event in events: if isinstance(event, h2.events.ConnectionTerminated): connection_terminated_event = event @@ -991,7 +991,6 @@ class TestRequestStreaming(_Http2Test): elif isinstance(event, h2.events.DataReceived): data = event.data assert data - print(event) h2_conn.close_connection(error_code=5, last_stream_id=42, additional_data=data) wfile.write(h2_conn.data_to_send()) wfile.flush() @@ -1025,16 +1024,21 @@ class TestRequestStreaming(_Http2Test): self.client.rfile.o.settimeout(2) while not done: try: - raw = b''.join(http2.read_raw_frame(self.client.rfile)) - events = h2_conn.receive_data(raw) + _, consumed_bytes = http2.read_frame(self.client.rfile, False) + events = h2_conn.receive_data(consumed_bytes) for event in events: if isinstance(event, h2.events.ConnectionTerminated): connection_terminated_event = event done = True + except mitmproxy.exceptions.TcpTimeout: + if not streaming: + break # this is expected for this test case + else: + assert False except: print(traceback.format_exc()) - break + assert False if streaming: assert connection_terminated_event.additional_data == body @@ -1083,8 +1087,8 @@ class TestResponseStreaming(_Http2Test): data = None while not done: try: - raw = b''.join(http2.read_raw_frame(self.client.rfile)) - events = h2_conn.receive_data(raw) + _, consumed_bytes = http2.read_frame(self.client.rfile, False) + events = h2_conn.receive_data(consumed_bytes) for event in events: if isinstance(event, h2.events.DataReceived): @@ -1150,8 +1154,8 @@ class TestRequestTrailers(_Http2Test): done = False while not done: try: - raw = b''.join(http2.read_raw_frame(self.client.rfile)) - events = h2_conn.receive_data(raw) + _, consumed_bytes = http2.read_frame(self.client.rfile, False) + events = h2_conn.receive_data(consumed_bytes) except exceptions.HttpException: print(traceback.format_exc()) assert False @@ -1214,8 +1218,8 @@ class TestResponseTrailers(_Http2Test): done = False while not done: try: - raw = b''.join(http2.read_raw_frame(self.client.rfile)) - events = h2_conn.receive_data(raw) + _, consumed_bytes = http2.read_frame(self.client.rfile, False) + events = h2_conn.receive_data(consumed_bytes) except exceptions.HttpException: print(traceback.format_exc()) assert False diff --git a/test/mitmproxy/proxy/protocol/test_websocket.py b/test/mitmproxy/proxy/protocol/test_websocket.py index 7e28cdfb0..bbaeb1853 100644 --- a/test/mitmproxy/proxy/protocol/test_websocket.py +++ b/test/mitmproxy/proxy/protocol/test_websocket.py @@ -9,7 +9,7 @@ from wsproto.frame_protocol import Opcode from mitmproxy import exceptions, options from mitmproxy.http import HTTPFlow, make_connect_request from mitmproxy.websocket import WebSocketFlow -from mitmproxy.net import http, tcp, websocket_utils +from mitmproxy.net import http, tcp, websocket from pathod.language import websockets_frame @@ -24,7 +24,7 @@ class _WebSocketServerBase(net_tservers.ServerTestBase): def handle(self): try: request = http.http1.read_request(self.rfile) - assert websocket_utils.check_handshake(request.headers) + assert websocket.check_handshake(request.headers) response = http.Response( http_version=b"HTTP/1.1", @@ -123,7 +123,7 @@ class _WebSocketTestBase: self.client.wfile.flush() response = http.http1.read_response(self.client.rfile, request) - assert websocket_utils.check_handshake(response.headers) + assert websocket.check_handshake(response.headers) class _WebSocketTest(_WebSocketTestBase, _WebSocketServerBase): @@ -146,11 +146,11 @@ class TestSimple(_WebSocketTest): wfile.write(bytes(websockets_frame.Frame(fin=1, opcode=Opcode.TEXT, payload=b'server-foobar'))) wfile.flush() - header, frame, _ = websocket_utils.read_raw_frame(rfile) + header, frame, _ = websocket.read_frame(rfile) wfile.write(bytes(websockets_frame.Frame(fin=1, opcode=header.opcode, payload=frame.payload))) wfile.flush() - header, frame, _ = websocket_utils.read_raw_frame(rfile) + header, frame, _ = websocket.read_frame(rfile) wfile.write(bytes(websockets_frame.Frame(fin=1, opcode=header.opcode, payload=frame.payload))) wfile.flush() @@ -163,19 +163,19 @@ class TestSimple(_WebSocketTest): self.proxy.set_addons(Stream()) self.setup_connection() - _, frame, _ = websocket_utils.read_raw_frame(self.client.rfile) + _, frame, _ = websocket.read_frame(self.client.rfile) assert frame.payload == b'server-foobar' self.client.wfile.write(bytes(websockets_frame.Frame(fin=1, mask=1, opcode=Opcode.TEXT, payload=b'self.client-foobar'))) self.client.wfile.flush() - _, frame, _ = websocket_utils.read_raw_frame(self.client.rfile) + _, frame, _ = websocket.read_frame(self.client.rfile) assert frame.payload == b'self.client-foobar' self.client.wfile.write(bytes(websockets_frame.Frame(fin=1, mask=1, opcode=Opcode.BINARY, payload=b'\xde\xad\xbe\xef'))) self.client.wfile.flush() - _, frame, _ = websocket_utils.read_raw_frame(self.client.rfile) + _, frame, _ = websocket.read_frame(self.client.rfile) assert frame.payload == b'\xde\xad\xbe\xef' self.client.wfile.write(bytes(websockets_frame.Frame(fin=1, mask=1, opcode=Opcode.CLOSE))) @@ -204,19 +204,19 @@ class TestSimple(_WebSocketTest): self.proxy.set_addons(Addon()) self.setup_connection() - _, frame, _ = websocket_utils.read_raw_frame(self.client.rfile) + _, frame, _ = websocket.read_frame(self.client.rfile) assert frame.payload == b'foo' self.client.wfile.write(bytes(websockets_frame.Frame(fin=1, mask=1, opcode=Opcode.TEXT, payload=b'self.client-foobar'))) self.client.wfile.flush() - _, frame, _ = websocket_utils.read_raw_frame(self.client.rfile) + _, frame, _ = websocket.read_frame(self.client.rfile) assert frame.payload == b'foo' self.client.wfile.write(bytes(websockets_frame.Frame(fin=1, mask=1, opcode=Opcode.BINARY, payload=b'\xde\xad\xbe\xef'))) self.client.wfile.flush() - _, frame, _ = websocket_utils.read_raw_frame(self.client.rfile) + _, frame, _ = websocket.read_frame(self.client.rfile) assert frame.payload == b'foo' @@ -236,7 +236,7 @@ class TestKillFlow(_WebSocketTest): self.setup_connection() with pytest.raises(exceptions.TcpDisconnect): - _, _, _ = websocket_utils.read_raw_frame(self.client.rfile) + _ = websocket.read_frame(self.client.rfile, False) class TestSimpleTLS(_WebSocketTest): @@ -247,20 +247,20 @@ class TestSimpleTLS(_WebSocketTest): wfile.write(bytes(websockets_frame.Frame(fin=1, opcode=Opcode.TEXT, payload=b'server-foobar'))) wfile.flush() - header, frame, _ = websocket_utils.read_raw_frame(rfile) + header, frame, _ = websocket.read_frame(rfile) wfile.write(bytes(websockets_frame.Frame(fin=1, opcode=header.opcode, payload=frame.payload))) wfile.flush() def test_simple_tls(self): self.setup_connection() - _, frame, _ = websocket_utils.read_raw_frame(self.client.rfile) + _, frame, _ = websocket.read_frame(self.client.rfile) assert frame.payload == b'server-foobar' self.client.wfile.write(bytes(websockets_frame.Frame(fin=1, mask=1, opcode=Opcode.TEXT, payload=b'self.client-foobar'))) self.client.wfile.flush() - _, frame, _ = websocket_utils.read_raw_frame(self.client.rfile) + _, frame, _ = websocket.read_frame(self.client.rfile) assert frame.payload == b'self.client-foobar' self.client.wfile.write(bytes(websockets_frame.Frame(fin=1, mask=1, opcode=Opcode.CLOSE))) @@ -274,7 +274,7 @@ class TestPing(_WebSocketTest): wfile.write(bytes(websockets_frame.Frame(fin=1, opcode=Opcode.PING, payload=b'foobar'))) wfile.flush() - header, frame, _ = websocket_utils.read_raw_frame(rfile) + header, frame, _ = websocket.read_frame(rfile) assert header.opcode == Opcode.PONG assert frame.payload == b'foobar' @@ -283,14 +283,14 @@ class TestPing(_WebSocketTest): wfile.write(bytes(websockets_frame.Frame(fin=1, opcode=Opcode.CLOSE))) wfile.flush() - _, _, _ = websocket_utils.read_raw_frame(rfile) + _ = websocket.read_frame(rfile, False) @pytest.mark.asyncio async def test_ping(self): self.setup_connection() - header, frame, _ = websocket_utils.read_raw_frame(self.client.rfile) - _ = websocket_utils.read_raw_frame(self.client.rfile) + header, frame, _ = websocket.read_frame(self.client.rfile) + _ = websocket.read_frame(self.client.rfile, False) self.client.wfile.write(bytes(websockets_frame.Frame(fin=1, mask=1, opcode=Opcode.CLOSE))) self.client.wfile.flush() assert header.opcode == Opcode.PING @@ -303,7 +303,7 @@ class TestPong(_WebSocketTest): @classmethod def handle_websockets(cls, rfile, wfile): - header, frame, _ = websocket_utils.read_raw_frame(rfile) + header, frame, _ = websocket.read_frame(rfile) assert header.opcode == Opcode.PING assert frame.payload == b'' @@ -312,7 +312,7 @@ class TestPong(_WebSocketTest): wfile.write(bytes(websockets_frame.Frame(fin=1, opcode=Opcode.CLOSE))) wfile.flush() - _ = websocket_utils.read_raw_frame(rfile) + _ = websocket.read_frame(rfile) @pytest.mark.asyncio async def test_pong(self): @@ -321,8 +321,8 @@ class TestPong(_WebSocketTest): self.client.wfile.write(bytes(websockets_frame.Frame(fin=1, mask=1, opcode=Opcode.PING, payload=b'foobar'))) self.client.wfile.flush() - header, frame, _ = websocket_utils.read_raw_frame(self.client.rfile) - _ = websocket_utils.read_raw_frame(self.client.rfile) + header, frame, _ = websocket.read_frame(self.client.rfile) + _ = websocket.read_frame(self.client.rfile) self.client.wfile.write(bytes(websockets_frame.Frame(fin=1, mask=1, opcode=Opcode.CLOSE))) self.client.wfile.flush() @@ -335,13 +335,13 @@ class TestClose(_WebSocketTest): @classmethod def handle_websockets(cls, rfile, wfile): - header, frame, _ = websocket_utils.read_raw_frame(rfile) + header, frame, _ = websocket.read_frame(rfile) wfile.write(bytes(websockets_frame.Frame(fin=1, opcode=header.opcode, payload=frame.payload))) wfile.write(bytes(websockets_frame.Frame(fin=1, opcode=Opcode.CLOSE))) wfile.flush() with pytest.raises(exceptions.TcpDisconnect): - _, _, _ = websocket_utils.read_raw_frame(rfile) + _ = websocket.read_frame(rfile) def test_close(self): self.setup_connection() @@ -349,9 +349,9 @@ class TestClose(_WebSocketTest): self.client.wfile.write(bytes(websockets_frame.Frame(fin=1, mask=1, opcode=Opcode.CLOSE))) self.client.wfile.flush() - _ = websocket_utils.read_raw_frame(self.client.rfile) + _ = websocket.read_frame(self.client.rfile) with pytest.raises(exceptions.TcpDisconnect): - _ = websocket_utils.read_raw_frame(self.client.rfile) + _ = websocket.read_frame(self.client.rfile) def test_close_payload_1(self): self.setup_connection() @@ -359,9 +359,9 @@ class TestClose(_WebSocketTest): self.client.wfile.write(bytes(websockets_frame.Frame(fin=1, mask=1, opcode=Opcode.CLOSE, payload=b'\00\42'))) self.client.wfile.flush() - _ = websocket_utils.read_raw_frame(self.client.rfile) + _ = websocket.read_frame(self.client.rfile) with pytest.raises(exceptions.TcpDisconnect): - _ = websocket_utils.read_raw_frame(self.client.rfile) + _ = websocket.read_frame(self.client.rfile) def test_close_payload_2(self): self.setup_connection() @@ -369,9 +369,9 @@ class TestClose(_WebSocketTest): self.client.wfile.write(bytes(websockets_frame.Frame(fin=1, mask=1, opcode=Opcode.CLOSE, payload=b'\00\42foobar'))) self.client.wfile.flush() - _ = websocket_utils.read_raw_frame(self.client.rfile) + _ = websocket.read_frame(self.client.rfile) with pytest.raises(exceptions.TcpDisconnect): - _ = websocket_utils.read_raw_frame(self.client.rfile) + _ = websocket.read_frame(self.client.rfile) class TestInvalidFrame(_WebSocketTest): @@ -384,7 +384,7 @@ class TestInvalidFrame(_WebSocketTest): def test_invalid_frame(self): self.setup_connection() - _, frame, _ = websocket_utils.read_raw_frame(self.client.rfile) + _, frame, _ = websocket.read_frame(self.client.rfile) code, = struct.unpack('!H', frame.payload[:2]) assert code == 1002 assert frame.payload[2:].startswith(b'Invalid opcode') @@ -409,11 +409,11 @@ class TestStreaming(_WebSocketTest): frame = None if not streaming: with pytest.raises(exceptions.TcpDisconnect): # Reader.safe_read get nothing as result - _, frame, _ = websocket_utils.read_raw_frame(self.client.rfile) + _, frame, _ = websocket.read_frame(self.client.rfile) assert frame is None else: - _, frame, _ = websocket_utils.read_raw_frame(self.client.rfile) + _, frame, _ = websocket.read_frame(self.client.rfile) assert frame assert self.master.state.flows[1].messages == [] # Message not appended as the final frame isn't received @@ -426,12 +426,12 @@ class TestExtension(_WebSocketTest): wfile.write(b'\xc1\x0f*N-*K-\xd2M\xcb\xcfOJ,\x02\x00') wfile.flush() - header, _, _ = websocket_utils.read_raw_frame(rfile) + header, _, _ = websocket.read_frame(rfile) assert header.rsv.rsv1 wfile.write(b'\xc1\nJ\xce\xc9L\xcd+\x81r\x00\x00') wfile.flush() - header, _, _ = websocket_utils.read_raw_frame(rfile) + header, _, _ = websocket.read_frame(rfile) assert header.rsv.rsv1 wfile.write(b'\xc2\x07\xba\xb7v\xdf{\x00\x00') wfile.flush() @@ -439,19 +439,19 @@ class TestExtension(_WebSocketTest): def test_extension(self): self.setup_connection(True) - header, _, _ = websocket_utils.read_raw_frame(self.client.rfile) + header, _, _ = websocket.read_frame(self.client.rfile) assert header.rsv.rsv1 self.client.wfile.write(b'\xc1\x8fQ\xb7vX\x1by\xbf\x14\x9c\x9c\xa7\x15\x9ax9\x12}\xb5v') self.client.wfile.flush() - header, _, _ = websocket_utils.read_raw_frame(self.client.rfile) + header, _, _ = websocket.read_frame(self.client.rfile) assert header.rsv.rsv1 self.client.wfile.write(b'\xc2\x87\xeb\xbb\x0csQ\x0cz\xac\x90\xbb\x0c') self.client.wfile.flush() - header, _, _ = websocket_utils.read_raw_frame(self.client.rfile) + header, _, _ = websocket.read_frame(self.client.rfile) assert header.rsv.rsv1 assert len(self.master.state.flows[1].messages) == 5 @@ -481,7 +481,7 @@ class TestInjectMessageClient(_WebSocketTest): self.proxy.set_addons(Inject()) self.setup_connection() - header, frame, _ = websocket_utils.read_raw_frame(self.client.rfile) + header, frame, _ = websocket.read_frame(self.client.rfile) assert header.opcode == Opcode.TEXT assert frame.payload == b'This is an injected message!' @@ -490,7 +490,7 @@ class TestInjectMessageServer(_WebSocketTest): @classmethod def handle_websockets(cls, rfile, wfile): - header, frame, _ = websocket_utils.read_raw_frame(rfile) + header, frame, _ = websocket.read_frame(rfile) assert header.opcode == Opcode.TEXT success = frame.payload == b'This is an injected message!' @@ -505,6 +505,6 @@ class TestInjectMessageServer(_WebSocketTest): self.proxy.set_addons(Inject()) self.setup_connection() - header, frame, _ = websocket_utils.read_raw_frame(self.client.rfile) + header, frame, _ = websocket.read_frame(self.client.rfile) assert header.opcode == Opcode.TEXT assert frame.payload == b'True' diff --git a/test/pathod/protocols/test_http2.py b/test/pathod/protocols/test_http2.py index 63a13c881..b0fffe731 100644 --- a/test/pathod/protocols/test_http2.py +++ b/test/pathod/protocols/test_http2.py @@ -106,12 +106,12 @@ class TestPerformServerConnectionPreface(net_tservers.ServerTestBase): self.wfile.flush() # check empty settings frame - raw = http2.read_raw_frame(self.rfile) - assert raw == bytes.fromhex("00000c040000000000000200000000000300000001") + _, consumed_bytes = http2.read_frame(self.rfile, False) + assert consumed_bytes == bytes.fromhex("00000c040000000000000200000000000300000001") # check settings acknowledgement - raw = http2.read_raw_frame(self.rfile) - assert raw == bytes.fromhex("000000040100000000") + _, consumed_bytes = http2.read_frame(self.rfile, False) + assert consumed_bytes == bytes.fromhex("000000040100000000") # send settings acknowledgement self.wfile.write(bytes.fromhex("000000040100000000")) @@ -126,7 +126,7 @@ class TestPerformServerConnectionPreface(net_tservers.ServerTestBase): protocol.perform_server_connection_preface() assert protocol.connection_preface_performed - with pytest.raises(exceptions.TcpDisconnect): + with pytest.raises(exceptions.TcpReadIncomplete): protocol.perform_server_connection_preface(force=True)