alpn: str -> bytes
This commit is contained in:
parent
8ac5af62f5
commit
dfba6e81a6
|
@ -8,7 +8,6 @@ from mitmproxy.net import tls as net_tls
|
|||
from mitmproxy.options import CONF_BASENAME
|
||||
from mitmproxy.proxy import context
|
||||
from mitmproxy.proxy.layers import tls
|
||||
from mitmproxy.utils.strutils import always_bytes
|
||||
|
||||
# We manually need to specify this, otherwise OpenSSL may select a non-HTTP2 cipher by default.
|
||||
# https://ssl-config.mozilla.org/#config=old
|
||||
|
@ -36,7 +35,7 @@ def alpn_select_callback(conn: SSL.Connection, options: List[bytes]) -> Any:
|
|||
return server_alpn
|
||||
http_alpns = tls.HTTP_ALPNS if http2 else tls.HTTP1_ALPNS
|
||||
for alpn in options: # client sends in order of preference, so we are nice and respect that.
|
||||
if alpn.decode(errors="replace") in http_alpns:
|
||||
if alpn in http_alpns:
|
||||
return alpn
|
||||
else:
|
||||
return SSL.NO_OVERLAPPING_PROTOCOLS
|
||||
|
@ -138,7 +137,7 @@ class TlsConfig:
|
|||
)
|
||||
tls_start.ssl_conn = SSL.Connection(ssl_ctx)
|
||||
tls_start.ssl_conn.set_app_data(AppData(
|
||||
server_alpn=always_bytes(server.alpn, "utf8", "replace"),
|
||||
server_alpn=server.alpn,
|
||||
http2=ctx.options.http2,
|
||||
))
|
||||
tls_start.ssl_conn.set_accept_state()
|
||||
|
@ -155,14 +154,13 @@ class TlsConfig:
|
|||
|
||||
if server.sni is True:
|
||||
server.sni = client.sni or server.address[0]
|
||||
sni: Optional[bytes] = server.sni.encode("ascii") if server.sni else None
|
||||
|
||||
if not server.alpn_offers:
|
||||
if client.alpn_offers:
|
||||
if ctx.options.http2:
|
||||
server.alpn_offers = tuple(client.alpn_offers)
|
||||
else:
|
||||
server.alpn_offers = tuple(x for x in client.alpn_offers if x != "h2")
|
||||
server.alpn_offers = tuple(x for x in client.alpn_offers if x != b"h2")
|
||||
elif client.tls_established:
|
||||
# We would perfectly support HTTP/1 -> HTTP/2, but we want to keep things on the same protocol version.
|
||||
# There are some edge cases where we want to mirror the regular server's behavior accurately,
|
||||
|
@ -172,7 +170,6 @@ class TlsConfig:
|
|||
server.alpn_offers = tls.HTTP_ALPNS
|
||||
else:
|
||||
server.alpn_offers = tls.HTTP1_ALPNS
|
||||
alpn_offers: List[bytes] = [alpn.encode() for alpn in server.alpn_offers]
|
||||
|
||||
if not server.cipher_list and ctx.options.ciphers_server:
|
||||
server.cipher_list = ctx.options.ciphers_server.split(":")
|
||||
|
@ -195,15 +192,16 @@ class TlsConfig:
|
|||
max_version=net_tls.Version[ctx.options.tls_version_client_max],
|
||||
cipher_list=cipher_list,
|
||||
verify=verify,
|
||||
sni=sni,
|
||||
sni=server.sni,
|
||||
ca_path=ctx.options.ssl_verify_upstream_trusted_confdir,
|
||||
ca_pemfile=ctx.options.ssl_verify_upstream_trusted_ca,
|
||||
client_cert=client_cert,
|
||||
alpn_protos=alpn_offers,
|
||||
alpn_protos=server.alpn_offers,
|
||||
)
|
||||
|
||||
tls_start.ssl_conn = SSL.Connection(ssl_ctx)
|
||||
tls_start.ssl_conn.set_tlsext_host_name(sni)
|
||||
if server.sni:
|
||||
tls_start.ssl_conn.set_tlsext_host_name(server.sni.encode())
|
||||
tls_start.ssl_conn.set_connect_state()
|
||||
|
||||
def running(self):
|
||||
|
|
|
@ -235,11 +235,9 @@ def convert_10_11(data):
|
|||
|
||||
def conv_conn(conn):
|
||||
conn["sni"] = strutils.always_str(conn["sni"], "ascii", "backslashreplace")
|
||||
conn["alpn"] = strutils.always_str(conn.pop("alpn_proto_negotiated"), "utf8", "backslashreplace")
|
||||
conn["alpn_offers"] = [
|
||||
strutils.always_str(alpn, "utf8", "backslashreplace")
|
||||
for alpn in (conn["alpn_offers"] or [])
|
||||
]
|
||||
conn["alpn"] = conn.pop("alpn_proto_negotiated")
|
||||
conn["alpn_offers"] = conn["alpn_offers"] or []
|
||||
conn["cipher_list"] = conn["cipher_list"] or []
|
||||
|
||||
conv_conn(data["client_conn"])
|
||||
conv_conn(data["server_conn"])
|
||||
|
|
|
@ -130,7 +130,7 @@ def create_proxy_server_context(
|
|||
max_version: Version,
|
||||
cipher_list: Optional[Iterable[str]],
|
||||
verify: Verify,
|
||||
sni: Optional[bytes],
|
||||
sni: Optional[str],
|
||||
ca_path: Optional[str],
|
||||
ca_pemfile: Optional[str],
|
||||
client_cert: Optional[str],
|
||||
|
@ -148,6 +148,7 @@ def create_proxy_server_context(
|
|||
|
||||
context.set_verify(verify.value, None)
|
||||
if sni is not None:
|
||||
assert isinstance(sni, str)
|
||||
# Manually enable hostname verification on the context object.
|
||||
# https://wiki.openssl.org/index.php/Hostname_validation
|
||||
param = SSL._lib.SSL_CTX_get0_param(context._context)
|
||||
|
@ -158,7 +159,7 @@ def create_proxy_server_context(
|
|||
SSL._lib.X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS | SSL._lib.X509_CHECK_FLAG_NEVER_CHECK_SUBJECT
|
||||
)
|
||||
SSL._openssl_assert(
|
||||
SSL._lib.X509_VERIFY_PARAM_set1_host(param, sni, 0) == 1
|
||||
SSL._lib.X509_VERIFY_PARAM_set1_host(param, sni.encode(), 0) == 1
|
||||
)
|
||||
|
||||
if ca_path is None and ca_pemfile is None:
|
||||
|
@ -293,14 +294,11 @@ class ClientHello:
|
|||
return None
|
||||
|
||||
@property
|
||||
def alpn_protocols(self) -> List[str]:
|
||||
def alpn_protocols(self) -> List[bytes]:
|
||||
if self._client_hello.extensions:
|
||||
for extension in self._client_hello.extensions.extensions:
|
||||
if extension.type == 0x10:
|
||||
try:
|
||||
return [x.name.decode() for x in extension.body.alpn_protocols]
|
||||
except UnicodeDecodeError:
|
||||
return []
|
||||
return list(x.name for x in extension.body.alpn_protocols)
|
||||
return []
|
||||
|
||||
@property
|
||||
|
|
|
@ -56,8 +56,8 @@ class Connection(serializable.Serializable, metaclass=ABCMeta):
|
|||
TLS version, with the exception of the end-entity certificate which
|
||||
MUST be first.
|
||||
"""
|
||||
alpn: Optional[str] = None
|
||||
alpn_offers: Sequence[str] = ()
|
||||
alpn: Optional[bytes] = None
|
||||
alpn_offers: Sequence[bytes] = ()
|
||||
|
||||
# we may want to add SSL_CIPHER_description here, but that's currently not exposed by cryptography
|
||||
cipher: Optional[str] = None
|
||||
|
@ -98,9 +98,7 @@ class Connection(serializable.Serializable, metaclass=ABCMeta):
|
|||
@property
|
||||
def alpn_proto_negotiated(self) -> Optional[bytes]: # pragma: no cover
|
||||
warnings.warn("Server.alpn_proto_negotiated is deprecated, use Server.alpn instead.", DeprecationWarning)
|
||||
if self.alpn is not None:
|
||||
return self.alpn.encode()
|
||||
return None
|
||||
return self.alpn
|
||||
|
||||
|
||||
class Client(Connection):
|
||||
|
|
|
@ -351,7 +351,7 @@ class HttpStream(layer.Layer):
|
|||
yield HttpErrorHook(self.flow)
|
||||
# For HTTP/2 we only want to kill the specific stream, for HTTP/1 we want to kill the connection
|
||||
# *without* sending an HTTP response (that could be achieved by the user by setting flow.response).
|
||||
if self.context.client.alpn == "h2":
|
||||
if self.context.client.alpn == b"h2":
|
||||
yield SendHttp(ResponseProtocolError(self.stream_id, "killed"), self.context.client)
|
||||
else:
|
||||
if self.context.client.state & ConnectionState.CAN_WRITE:
|
||||
|
@ -532,7 +532,7 @@ class HttpLayer(layer.Layer):
|
|||
self.command_sources = {}
|
||||
|
||||
http_conn: HttpConnection
|
||||
if self.context.client.alpn == "h2":
|
||||
if self.context.client.alpn == b"h2":
|
||||
http_conn = Http2Server(context.fork())
|
||||
else:
|
||||
http_conn = Http1Server(context.fork())
|
||||
|
@ -606,10 +606,10 @@ class HttpLayer(layer.Layer):
|
|||
for connection in self.connections:
|
||||
# see "tricky multiplexing edge case" in make_http_connection for an explanation
|
||||
conn_is_pending_or_h2 = (
|
||||
connection.alpn == "h2"
|
||||
connection.alpn == b"h2"
|
||||
or connection in self.waiting_for_establishment
|
||||
)
|
||||
h2_to_h1 = self.context.client.alpn == "h2" and not conn_is_pending_or_h2
|
||||
h2_to_h1 = self.context.client.alpn == b"h2" and not conn_is_pending_or_h2
|
||||
connection_suitable = (
|
||||
event.connection_spec_matches(connection)
|
||||
and not h2_to_h1
|
||||
|
@ -679,7 +679,7 @@ class HttpLayer(layer.Layer):
|
|||
# that neither have a content-length specified nor a chunked transfer encoding.
|
||||
# We can't process these two flows to the same h1 connection as they would both have
|
||||
# "read until eof" semantics. The only workaround left is to open a separate connection for each flow.
|
||||
if not command.err and self.context.client.alpn == "h2" and command.connection.alpn != "h2":
|
||||
if not command.err and self.context.client.alpn == b"h2" and command.connection.alpn != b"h2":
|
||||
for cmd in waiting[1:]:
|
||||
yield from self.get_connection(cmd, reuse=False)
|
||||
break
|
||||
|
@ -695,7 +695,7 @@ class HttpClient(layer.Layer):
|
|||
err = yield commands.OpenConnection(self.context.server)
|
||||
if not err:
|
||||
child_layer: layer.Layer
|
||||
if self.context.server.alpn == "h2":
|
||||
if self.context.server.alpn == b"h2":
|
||||
child_layer = Http2Client(self.context)
|
||||
else:
|
||||
child_layer = Http1Client(self.context)
|
||||
|
|
|
@ -91,8 +91,8 @@ def parse_client_hello(data: bytes) -> Optional[net_tls.ClientHello]:
|
|||
return None
|
||||
|
||||
|
||||
HTTP1_ALPNS = ("http/1.1", "http/1.0", "http/0.9")
|
||||
HTTP_ALPNS = ("h2",) + HTTP1_ALPNS
|
||||
HTTP1_ALPNS = (b"http/1.1", b"http/1.0", b"http/0.9")
|
||||
HTTP_ALPNS = (b"h2",) + HTTP1_ALPNS
|
||||
|
||||
|
||||
# We need these classes as hooks can only have one argument at the moment.
|
||||
|
@ -196,7 +196,7 @@ class _TLSLayer(tunnel.TunnelLayer):
|
|||
all_certs.insert(0, cert)
|
||||
|
||||
self.conn.timestamp_tls_setup = time.time()
|
||||
self.conn.alpn = self.tls.get_alpn_proto_negotiated().decode()
|
||||
self.conn.alpn = self.tls.get_alpn_proto_negotiated()
|
||||
self.conn.certificate_list = [certs.Cert.from_pyopenssl(x) for x in all_certs]
|
||||
self.conn.cipher = self.tls.get_cipher_name()
|
||||
self.conn.tls_version = self.tls.get_protocol_version_name()
|
||||
|
|
|
@ -158,7 +158,7 @@ def tclient_conn() -> context.Client:
|
|||
timestamp_end=946681206,
|
||||
sni="address",
|
||||
cipher_name="cipher",
|
||||
alpn="http/1.1",
|
||||
alpn=b"http/1.1",
|
||||
tls_version="TLSv1.2",
|
||||
tls_extensions=[(0x00, bytes.fromhex("000e00000b6578616d"))],
|
||||
state=0,
|
||||
|
|
|
@ -4,7 +4,7 @@ import urwid
|
|||
import mitmproxy.flow
|
||||
from mitmproxy import http
|
||||
from mitmproxy.tools.console import common, searchable
|
||||
from mitmproxy.utils import human
|
||||
from mitmproxy.utils import human, strutils
|
||||
|
||||
|
||||
def maybe_timestamp(base, attr):
|
||||
|
@ -49,7 +49,7 @@ def flowdetails(state, flow: mitmproxy.flow.Flow):
|
|||
if resp:
|
||||
parts.append(("HTTP Version", resp.http_version))
|
||||
if sc.alpn:
|
||||
parts.append(("ALPN", sc.alpn))
|
||||
parts.append(("ALPN", strutils.bytes_to_escaped_str(sc.alpn)))
|
||||
|
||||
text.extend(
|
||||
common.format_keyvals(parts, indent=4)
|
||||
|
@ -69,7 +69,7 @@ def flowdetails(state, flow: mitmproxy.flow.Flow):
|
|||
]
|
||||
|
||||
if c.altnames:
|
||||
parts.append(("Alt names", ", ".join(c.altnames)))
|
||||
parts.append(("Alt names", ", ".join(strutils.bytes_to_escaped_str(x) for x in c.altnames)))
|
||||
text.extend(
|
||||
common.format_keyvals(parts, indent=4)
|
||||
)
|
||||
|
@ -89,7 +89,7 @@ def flowdetails(state, flow: mitmproxy.flow.Flow):
|
|||
if cc.cipher:
|
||||
parts.append(("Cipher Name", cc.cipher))
|
||||
if cc.alpn:
|
||||
parts.append(("ALPN", cc.alpn))
|
||||
parts.append(("ALPN", strutils.bytes_to_escaped_str(cc.alpn)))
|
||||
|
||||
text.extend(
|
||||
common.format_keyvals(parts, indent=4)
|
||||
|
|
|
@ -20,6 +20,7 @@ from mitmproxy import io
|
|||
from mitmproxy import log
|
||||
from mitmproxy import optmanager
|
||||
from mitmproxy import version
|
||||
from mitmproxy.utils.strutils import always_str
|
||||
|
||||
|
||||
def flow_to_json(flow: mitmproxy.flow.Flow) -> dict:
|
||||
|
@ -48,7 +49,7 @@ def flow_to_json(flow: mitmproxy.flow.Flow) -> dict:
|
|||
"timestamp_end": flow.client_conn.timestamp_end,
|
||||
"sni": flow.client_conn.sni,
|
||||
"cipher_name": flow.client_conn.cipher,
|
||||
"alpn_proto_negotiated": flow.client_conn.alpn,
|
||||
"alpn_proto_negotiated": always_str(flow.client_conn.alpn, "ascii", "backslashreplace"),
|
||||
"tls_version": flow.client_conn.tls_version,
|
||||
}
|
||||
|
||||
|
@ -60,7 +61,7 @@ def flow_to_json(flow: mitmproxy.flow.Flow) -> dict:
|
|||
"source_address": flow.server_conn.sockname,
|
||||
"tls_established": flow.server_conn.tls_established,
|
||||
"sni": flow.server_conn.sni,
|
||||
"alpn_proto_negotiated": flow.server_conn.alpn,
|
||||
"alpn_proto_negotiated": always_str(flow.client_conn.alpn, "ascii", "backslashreplace"),
|
||||
"tls_version": flow.server_conn.tls_version,
|
||||
"timestamp_start": flow.server_conn.timestamp_start,
|
||||
"timestamp_tcp_setup": flow.server_conn.timestamp_tcp_setup,
|
||||
|
|
|
@ -127,7 +127,7 @@ class TestTlsConfig:
|
|||
ta = tlsconfig.TlsConfig()
|
||||
with taddons.context(ta) as tctx:
|
||||
ctx = context.Context(context.Client(("client", 1234), ("127.0.0.1", 8080), 1605699329), tctx.options)
|
||||
ctx.client.alpn_offers = ["h2"]
|
||||
ctx.client.alpn_offers = [b"h2"]
|
||||
ctx.client.cipher_list = ["TLS_AES_256_GCM_SHA384", "ECDHE-RSA-AES128-SHA"]
|
||||
ctx.server.address = ("example.mitmproxy.org", 443)
|
||||
|
||||
|
@ -185,8 +185,8 @@ class TestTlsConfig:
|
|||
ta.tls_start(tls_start)
|
||||
assert ctx.server.alpn_offers == expected
|
||||
|
||||
assert_alpn(True, tls.HTTP_ALPNS + ("foo",), tls.HTTP_ALPNS + ("foo",))
|
||||
assert_alpn(False, tls.HTTP_ALPNS + ("foo",), tls.HTTP1_ALPNS + ("foo",))
|
||||
assert_alpn(True, tls.HTTP_ALPNS + (b"foo",), tls.HTTP_ALPNS + (b"foo",))
|
||||
assert_alpn(False, tls.HTTP_ALPNS + (b"foo",), tls.HTTP1_ALPNS + (b"foo",))
|
||||
assert_alpn(True, [], tls.HTTP_ALPNS)
|
||||
assert_alpn(False, [], tls.HTTP1_ALPNS)
|
||||
ctx.client.timestamp_tls_setup = time.time()
|
||||
|
|
|
@ -110,7 +110,7 @@ class TestClientHello:
|
|||
49195, 49199, 49196, 49200, 52393, 52392, 52244, 52243, 49161,
|
||||
49171, 49162, 49172, 156, 157, 47, 53, 10
|
||||
]
|
||||
assert c.alpn_protocols == ['h2', 'http/1.1']
|
||||
assert c.alpn_protocols == [b'h2', b'http/1.1']
|
||||
assert c.extensions == [
|
||||
(65281, b'\x00'),
|
||||
(0, b'\x00\x0e\x00\x00\x0bexample.com'),
|
||||
|
|
|
@ -44,7 +44,7 @@ def decode_frames(data: bytes) -> List[hyperframe.frame.Frame]:
|
|||
|
||||
|
||||
def start_h2_client(tctx: Context) -> Tuple[Playbook, FrameFactory]:
|
||||
tctx.client.alpn = "h2"
|
||||
tctx.client.alpn = b"h2"
|
||||
frame_factory = FrameFactory()
|
||||
|
||||
playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular))
|
||||
|
@ -58,7 +58,7 @@ def start_h2_client(tctx: Context) -> Tuple[Playbook, FrameFactory]:
|
|||
|
||||
|
||||
def make_h2(open_connection: OpenConnection) -> None:
|
||||
open_connection.connection.alpn = "h2"
|
||||
open_connection.connection.alpn = b"h2"
|
||||
|
||||
|
||||
def test_simple(tctx):
|
||||
|
|
|
@ -208,7 +208,7 @@ def h2_frames(draw):
|
|||
|
||||
def h2_layer(opts):
|
||||
tctx = context.Context(context.Client(("client", 1234), ("127.0.0.1", 8080), 1605699329), opts)
|
||||
tctx.client.alpn = "h2"
|
||||
tctx.client.alpn = b"h2"
|
||||
|
||||
layer = http.HttpLayer(tctx, HTTPMode.regular)
|
||||
for _ in layer.handle_event(Start()):
|
||||
|
|
|
@ -22,7 +22,7 @@ def event_types(events):
|
|||
|
||||
|
||||
def h2_client(tctx: Context) -> Tuple[h2.connection.H2Connection, Playbook]:
|
||||
tctx.client.alpn = "h2"
|
||||
tctx.client.alpn = b"h2"
|
||||
|
||||
playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular))
|
||||
conn = h2.connection.H2Connection()
|
||||
|
|
|
@ -188,7 +188,7 @@ def reply_tls_start(alpn: typing.Optional[bytes] = None, *args, **kwargs) -> tut
|
|||
tls_start.ssl_conn = SSL.Connection(ssl_context)
|
||||
tls_start.ssl_conn.set_connect_state()
|
||||
# Set SNI
|
||||
tls_start.ssl_conn.set_tlsext_host_name(tls_start.conn.sni.encode("ascii"))
|
||||
tls_start.ssl_conn.set_tlsext_host_name(tls_start.conn.sni.encode())
|
||||
|
||||
# Manually enable hostname verification.
|
||||
# Recent OpenSSL versions provide slightly nicer ways to do this, but they are not exposed in
|
||||
|
@ -202,7 +202,7 @@ def reply_tls_start(alpn: typing.Optional[bytes] = None, *args, **kwargs) -> tut
|
|||
SSL._lib.X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS | SSL._lib.X509_CHECK_FLAG_NEVER_CHECK_SUBJECT
|
||||
)
|
||||
SSL._openssl_assert(
|
||||
SSL._lib.X509_VERIFY_PARAM_set1_host(param, tls_start.conn.sni.encode("ascii"), 0) == 1
|
||||
SSL._lib.X509_VERIFY_PARAM_set1_host(param, tls_start.conn.sni.encode(), 0) == 1
|
||||
)
|
||||
|
||||
return tutils.reply(*args, side_effect=make_conn, **kwargs)
|
||||
|
@ -446,8 +446,8 @@ class TestClientTLS:
|
|||
assert tctx.client.tls_established
|
||||
assert tctx.server.tls_established
|
||||
assert tctx.server.sni == tctx.client.sni
|
||||
assert tctx.client.alpn == "quux"
|
||||
assert tctx.server.alpn == "quux"
|
||||
assert tctx.client.alpn == b"quux"
|
||||
assert tctx.server.alpn == b"quux"
|
||||
_test_echo(playbook, tssl_server, tctx.server)
|
||||
_test_echo(playbook, tssl_client, tctx.client)
|
||||
|
||||
|
|
Loading…
Reference in New Issue