raise if server address is updated on non-closed connections

This commit is contained in:
Maximilian Hils 2021-06-19 12:28:40 +02:00
parent 15adb2cd79
commit 6178b4b72a
3 changed files with 16 additions and 2 deletions

View File

@ -293,6 +293,11 @@ class Server(Connection):
local_port = "" local_port = ""
return f"Server({human.format_address(self.address)}, state={self.state.name.lower()}{tls_state}{local_port})" return f"Server({human.format_address(self.address)}, state={self.state.name.lower()}{tls_state}{local_port})"
def __setattr__(self, name, value):
if name == "address" and self.__dict__.get("state", ConnectionState.CLOSED) is ConnectionState.OPEN:
raise RuntimeError("Cannot change server address on open connection.")
return super().__setattr__(name, value)
def get_state(self): def get_state(self):
return { return {
'address': self.address, 'address': self.address,

View File

@ -226,8 +226,8 @@ class TestServerTLS:
def test_simple(self, tctx): def test_simple(self, tctx):
playbook = tutils.Playbook(tls.ServerTLSLayer(tctx)) playbook = tutils.Playbook(tls.ServerTLSLayer(tctx))
tctx.server.state = ConnectionState.OPEN
tctx.server.address = ("example.mitmproxy.org", 443) tctx.server.address = ("example.mitmproxy.org", 443)
tctx.server.state = ConnectionState.OPEN
tctx.server.sni = "example.mitmproxy.org" tctx.server.sni = "example.mitmproxy.org"
tssl = SSLTest(server_side=True) tssl = SSLTest(server_side=True)
@ -345,7 +345,7 @@ def make_client_tls_layer(
playbook = tutils.Playbook(server_layer) playbook = tutils.Playbook(server_layer)
# Add some server config, this is needed anyways. # Add some server config, this is needed anyways.
tctx.server.address = ("example.mitmproxy.org", 443) tctx.server.__dict__["address"] = ("example.mitmproxy.org", 443) # .address fails because connection is open
tctx.server.sni = "example.mitmproxy.org" tctx.server.sni = "example.mitmproxy.org"
tssl_client = SSLTest(**kwargs) tssl_client = SSLTest(**kwargs)

View File

@ -1,3 +1,5 @@
import pytest
from mitmproxy.connection import Server, Client, ConnectionState from mitmproxy.connection import Server, Client, ConnectionState
from mitmproxy.test.tflow import tclient_conn, tserver_conn from mitmproxy.test.tflow import tclient_conn, tserver_conn
@ -76,3 +78,10 @@ class TestServer:
assert c2.get_state() != c.get_state() assert c2.get_state() != c.get_state()
c.id = c2.id = "foo" c.id = c2.id = "foo"
assert c2.get_state() == c.get_state() assert c2.get_state() == c.get_state()
def test_address(self):
s = Server(("address", 22))
s.address = ("example.com", 443)
s.state = ConnectionState.OPEN
with pytest.raises(RuntimeError):
s.address = ("example.com", 80)