diff --git a/mitmproxy/connection.py b/mitmproxy/connection.py index c0ca423f2..cbb732f54 100644 --- a/mitmproxy/connection.py +++ b/mitmproxy/connection.py @@ -293,6 +293,11 @@ class Server(Connection): local_port = "" return f"Server({human.format_address(self.address)}, state={self.state.name.lower()}{tls_state}{local_port})" + def __setattr__(self, name, value): + if name == "address" and self.__dict__.get("state", ConnectionState.CLOSED) is ConnectionState.OPEN: + raise RuntimeError("Cannot change server address on open connection.") + return super().__setattr__(name, value) + def get_state(self): return { 'address': self.address, diff --git a/test/mitmproxy/proxy/layers/test_tls.py b/test/mitmproxy/proxy/layers/test_tls.py index 1fd77a1d0..08f484b33 100644 --- a/test/mitmproxy/proxy/layers/test_tls.py +++ b/test/mitmproxy/proxy/layers/test_tls.py @@ -226,8 +226,8 @@ class TestServerTLS: def test_simple(self, tctx): playbook = tutils.Playbook(tls.ServerTLSLayer(tctx)) - tctx.server.state = ConnectionState.OPEN tctx.server.address = ("example.mitmproxy.org", 443) + tctx.server.state = ConnectionState.OPEN tctx.server.sni = "example.mitmproxy.org" tssl = SSLTest(server_side=True) @@ -345,7 +345,7 @@ def make_client_tls_layer( playbook = tutils.Playbook(server_layer) # Add some server config, this is needed anyways. - tctx.server.address = ("example.mitmproxy.org", 443) + tctx.server.__dict__["address"] = ("example.mitmproxy.org", 443) # .address fails because connection is open tctx.server.sni = "example.mitmproxy.org" tssl_client = SSLTest(**kwargs) diff --git a/test/mitmproxy/test_connection.py b/test/mitmproxy/test_connection.py index fe04983ec..bf685a361 100644 --- a/test/mitmproxy/test_connection.py +++ b/test/mitmproxy/test_connection.py @@ -1,3 +1,5 @@ +import pytest + from mitmproxy.connection import Server, Client, ConnectionState from mitmproxy.test.tflow import tclient_conn, tserver_conn @@ -76,3 +78,10 @@ class TestServer: assert c2.get_state() != c.get_state() c.id = c2.id = "foo" assert c2.get_state() == c.get_state() + + def test_address(self): + s = Server(("address", 22)) + s.address = ("example.com", 443) + s.state = ConnectionState.OPEN + with pytest.raises(RuntimeError): + s.address = ("example.com", 80)