diff --git a/mitmproxy/addons/tlsconfig.py b/mitmproxy/addons/tlsconfig.py index 18d4fc73e..c20b6ee27 100644 --- a/mitmproxy/addons/tlsconfig.py +++ b/mitmproxy/addons/tlsconfig.py @@ -24,14 +24,21 @@ DEFAULT_CIPHERS = ( class AppData(TypedDict): + client_alpn: Optional[bytes] server_alpn: Optional[bytes] http2: bool def alpn_select_callback(conn: SSL.Connection, options: List[bytes]) -> Any: app_data: AppData = conn.get_app_data() + client_alpn = app_data["client_alpn"] server_alpn = app_data["server_alpn"] http2 = app_data["http2"] + if client_alpn is not None: + if client_alpn in options: + return client_alpn + else: + return SSL.NO_OVERLAPPING_PROTOCOLS if server_alpn and server_alpn in options: return server_alpn if server_alpn == b"": @@ -148,6 +155,7 @@ class TlsConfig: ) tls_start.ssl_conn = SSL.Connection(ssl_ctx) tls_start.ssl_conn.set_app_data(AppData( + client_alpn=client.alpn, server_alpn=server.alpn, http2=ctx.options.http2, )) diff --git a/test/mitmproxy/addons/test_tlsconfig.py b/test/mitmproxy/addons/test_tlsconfig.py index 92237b67f..731530071 100644 --- a/test/mitmproxy/addons/test_tlsconfig.py +++ b/test/mitmproxy/addons/test_tlsconfig.py @@ -17,13 +17,19 @@ from test.mitmproxy.proxy.layers import test_tls def test_alpn_select_callback(): ctx = SSL.Context(SSL.SSLv23_METHOD) conn = SSL.Connection(ctx) - conn.set_app_data(tlsconfig.AppData(server_alpn=b"h2", http2=True)) + + # Test that we respect addons setting `client.alpn`. + conn.set_app_data(tlsconfig.AppData(server_alpn=b"h2", http2=True, client_alpn=b"qux")) + assert tlsconfig.alpn_select_callback(conn, [b"http/1.1", b"qux", b"h2"]) == b"qux" + conn.set_app_data(tlsconfig.AppData(server_alpn=b"h2", http2=True, client_alpn=b"")) + assert tlsconfig.alpn_select_callback(conn, [b"http/1.1", b"qux", b"h2"]) == SSL.NO_OVERLAPPING_PROTOCOLS # Test that we try to mirror the server connection's ALPN + conn.set_app_data(tlsconfig.AppData(server_alpn=b"h2", http2=True, client_alpn=None)) assert tlsconfig.alpn_select_callback(conn, [b"http/1.1", b"qux", b"h2"]) == b"h2" # Test that we respect the client's preferred HTTP ALPN. - conn.set_app_data(tlsconfig.AppData(server_alpn=None, http2=True)) + conn.set_app_data(tlsconfig.AppData(server_alpn=None, http2=True, client_alpn=None)) assert tlsconfig.alpn_select_callback(conn, [b"qux", b"http/1.1", b"h2"]) == b"http/1.1" assert tlsconfig.alpn_select_callback(conn, [b"qux", b"h2", b"http/1.1"]) == b"h2" @@ -31,7 +37,7 @@ def test_alpn_select_callback(): assert tlsconfig.alpn_select_callback(conn, [b"qux", b"quux"]) == SSL.NO_OVERLAPPING_PROTOCOLS # Test that we don't select an ALPN if the server refused to select one. - conn.set_app_data(tlsconfig.AppData(server_alpn=b"", http2=True)) + conn.set_app_data(tlsconfig.AppData(server_alpn=b"", http2=True, client_alpn=None)) assert tlsconfig.alpn_select_callback(conn, [b"http/1.1"]) == SSL.NO_OVERLAPPING_PROTOCOLS