Clean up interfaces by making some methods pseudo-private.

This commit is contained in:
Aldo Cortesi 2011-08-03 22:48:40 +12:00
parent 57c653be5f
commit 9042d3f3b9
2 changed files with 74 additions and 71 deletions

View File

@ -62,11 +62,11 @@ class Headers:
def add(self, key, value): def add(self, key, value):
self.lst.append([key, str(value)]) self.lst.append([key, str(value)])
def get_state(self): def _get_state(self):
return [tuple(i) for i in self.lst] return [tuple(i) for i in self.lst]
@classmethod @classmethod
def from_state(klass, state): def _from_state(klass, state):
return klass([list(i) for i in state]) return klass([list(i) for i in state])
def copy(self): def copy(self):
@ -85,7 +85,10 @@ class Headers:
def match_re(self, expr): def match_re(self, expr):
""" """
Match the regular expression against each header (key, value) pair. Match the regular expression against each header. For each (key,
value) pair a string of the following format is matched against:
"key: value"
""" """
for k, v in self.lst: for k, v in self.lst:
s = "%s: %s"%(k, v) s = "%s: %s"%(k, v)
@ -211,12 +214,12 @@ class Request(HTTPMsg):
else: else:
return True return True
def load_state(self, state): def _load_state(self, state):
if state["client_conn"]: if state["client_conn"]:
if self.client_conn: if self.client_conn:
self.client_conn.load_state(state["client_conn"]) self.client_conn._load_state(state["client_conn"])
else: else:
self.client_conn = ClientConnect.from_state(state["client_conn"]) self.client_conn = ClientConnect._from_state(state["client_conn"])
else: else:
self.client_conn = None self.client_conn = None
self.host = state["host"] self.host = state["host"]
@ -224,33 +227,33 @@ class Request(HTTPMsg):
self.scheme = state["scheme"] self.scheme = state["scheme"]
self.method = state["method"] self.method = state["method"]
self.path = state["path"] self.path = state["path"]
self.headers = Headers.from_state(state["headers"]) self.headers = Headers._from_state(state["headers"])
self.content = base64.decodestring(state["content"]) self.content = base64.decodestring(state["content"])
self.timestamp = state["timestamp"] self.timestamp = state["timestamp"]
def get_state(self): def _get_state(self):
return dict( return dict(
client_conn = self.client_conn.get_state() if self.client_conn else None, client_conn = self.client_conn._get_state() if self.client_conn else None,
host = self.host, host = self.host,
port = self.port, port = self.port,
scheme = self.scheme, scheme = self.scheme,
method = self.method, method = self.method,
path = self.path, path = self.path,
headers = self.headers.get_state(), headers = self.headers._get_state(),
content = base64.encodestring(self.content), content = base64.encodestring(self.content),
timestamp = self.timestamp, timestamp = self.timestamp,
) )
@classmethod @classmethod
def from_state(klass, state): def _from_state(klass, state):
return klass( return klass(
ClientConnect.from_state(state["client_conn"]), ClientConnect._from_state(state["client_conn"]),
str(state["host"]), str(state["host"]),
state["port"], state["port"],
str(state["scheme"]), str(state["scheme"]),
str(state["method"]), str(state["method"]),
str(state["path"]), str(state["path"]),
Headers.from_state(state["headers"]), Headers._from_state(state["headers"]),
base64.decodestring(state["content"]), base64.decodestring(state["content"]),
state["timestamp"] state["timestamp"]
) )
@ -259,7 +262,7 @@ class Request(HTTPMsg):
return id(self) return id(self)
def __eq__(self, other): def __eq__(self, other):
return self.get_state() == other.get_state() return self._get_state() == other._get_state()
def copy(self): def copy(self):
c = copy.copy(self) c = copy.copy(self)
@ -395,35 +398,35 @@ class Response(HTTPMsg):
def is_replay(self): def is_replay(self):
return self.replay return self.replay
def load_state(self, state): def _load_state(self, state):
self.code = state["code"] self.code = state["code"]
self.msg = state["msg"] self.msg = state["msg"]
self.headers = Headers.from_state(state["headers"]) self.headers = Headers._from_state(state["headers"])
self.content = base64.decodestring(state["content"]) self.content = base64.decodestring(state["content"])
self.timestamp = state["timestamp"] self.timestamp = state["timestamp"]
def get_state(self): def _get_state(self):
return dict( return dict(
code = self.code, code = self.code,
msg = self.msg, msg = self.msg,
headers = self.headers.get_state(), headers = self.headers._get_state(),
timestamp = self.timestamp, timestamp = self.timestamp,
content = base64.encodestring(self.content) content = base64.encodestring(self.content)
) )
@classmethod @classmethod
def from_state(klass, request, state): def _from_state(klass, request, state):
return klass( return klass(
request, request,
state["code"], state["code"],
str(state["msg"]), str(state["msg"]),
Headers.from_state(state["headers"]), Headers._from_state(state["headers"]),
base64.decodestring(state["content"]), base64.decodestring(state["content"]),
state["timestamp"], state["timestamp"],
) )
def __eq__(self, other): def __eq__(self, other):
return self.get_state() == other.get_state() return self._get_state() == other._get_state()
def copy(self): def copy(self):
c = copy.copy(self) c = copy.copy(self)
@ -484,16 +487,16 @@ class ClientConnect(controller.Msg):
controller.Msg.__init__(self) controller.Msg.__init__(self)
def __eq__(self, other): def __eq__(self, other):
return self.get_state() == other.get_state() return self._get_state() == other._get_state()
def load_state(self, state): def _load_state(self, state):
self.address = state self.address = state
def get_state(self): def _get_state(self):
return list(self.address) if self.address else None return list(self.address) if self.address else None
@classmethod @classmethod
def from_state(klass, state): def _from_state(klass, state):
if state: if state:
return klass(state) return klass(state)
else: else:
@ -509,21 +512,21 @@ class Error(controller.Msg):
self.timestamp = timestamp or utils.timestamp() self.timestamp = timestamp or utils.timestamp()
controller.Msg.__init__(self) controller.Msg.__init__(self)
def load_state(self, state): def _load_state(self, state):
self.msg = state["msg"] self.msg = state["msg"]
self.timestamp = state["timestamp"] self.timestamp = state["timestamp"]
def copy(self): def copy(self):
return copy.copy(self) return copy.copy(self)
def get_state(self): def _get_state(self):
return dict( return dict(
msg = self.msg, msg = self.msg,
timestamp = self.timestamp, timestamp = self.timestamp,
) )
@classmethod @classmethod
def from_state(klass, state): def _from_state(klass, state):
return klass( return klass(
None, None,
state["msg"], state["msg"],
@ -531,7 +534,7 @@ class Error(controller.Msg):
) )
def __eq__(self, other): def __eq__(self, other):
return self.get_state() == other.get_state() return self._get_state() == other._get_state()
def replace(self, pattern, repl, *args, **kwargs): def replace(self, pattern, repl, *args, **kwargs):
""" """
@ -708,9 +711,9 @@ class Flow:
self._backup = None self._backup = None
@classmethod @classmethod
def from_state(klass, state): def _from_state(klass, state):
f = klass(None) f = klass(None)
f.load_state(state) f._load_state(state)
return f return f
@classmethod @classmethod
@ -719,13 +722,13 @@ class Flow:
data = json.loads(data) data = json.loads(data)
except Exception: except Exception:
return None return None
return klass.from_state(data) return klass._from_state(data)
def get_state(self, nobackup=False): def _get_state(self, nobackup=False):
d = dict( d = dict(
request = self.request.get_state() if self.request else None, request = self.request._get_state() if self.request else None,
response = self.response.get_state() if self.response else None, response = self.response._get_state() if self.response else None,
error = self.error.get_state() if self.error else None, error = self.error._get_state() if self.error else None,
version = version.IVERSION version = version.IVERSION
) )
if nobackup: if nobackup:
@ -734,26 +737,26 @@ class Flow:
d["backup"] = self._backup d["backup"] = self._backup
return d return d
def load_state(self, state): def _load_state(self, state):
self._backup = state["backup"] self._backup = state["backup"]
if self.request: if self.request:
self.request.load_state(state["request"]) self.request._load_state(state["request"])
else: else:
self.request = Request.from_state(state["request"]) self.request = Request._from_state(state["request"])
if state["response"]: if state["response"]:
if self.response: if self.response:
self.response.load_state(state["response"]) self.response._load_state(state["response"])
else: else:
self.response = Response.from_state(self.request, state["response"]) self.response = Response._from_state(self.request, state["response"])
else: else:
self.response = None self.response = None
if state["error"]: if state["error"]:
if self.error: if self.error:
self.error.load_state(state["error"]) self.error._load_state(state["error"])
else: else:
self.error = Error.from_state(state["error"]) self.error = Error._from_state(state["error"])
else: else:
self.error = None self.error = None
@ -766,11 +769,11 @@ class Flow:
return False return False
def backup(self): def backup(self):
self._backup = self.get_state(nobackup=True) self._backup = self._get_state(nobackup=True)
def revert(self): def revert(self):
if self._backup: if self._backup:
self.load_state(self._backup) self._load_state(self._backup)
self._backup = None self._backup = None
def match(self, pattern): def match(self, pattern):
@ -1041,7 +1044,7 @@ class FlowMaster(controller.Master):
rflow = self.server_playback.next_flow(flow) rflow = self.server_playback.next_flow(flow)
if not rflow: if not rflow:
return None return None
response = Response.from_state(flow.request, rflow.response.get_state()) response = Response._from_state(flow.request, rflow.response._get_state())
response.set_replay() response.set_replay()
flow.response = response flow.response = response
if self.refresh_server_playback: if self.refresh_server_playback:
@ -1178,7 +1181,7 @@ class FlowWriter:
self.ns = netstring.FileEncoder(fo) self.ns = netstring.FileEncoder(fo)
def add(self, flow): def add(self, flow):
d = flow.get_state() d = flow._get_state()
s = json.dumps(d) s = json.dumps(d)
self.ns.write(s) self.ns.write(s)
@ -1201,7 +1204,7 @@ class FlowReader:
try: try:
for i in self.ns: for i in self.ns:
data = json.loads(i) data = json.loads(i)
yield Flow.from_state(data) yield Flow._from_state(data)
except netstring.DecoderError: except netstring.DecoderError:
raise FlowReadError("Invalid data format.") raise FlowReadError("Invalid data format.")

View File

@ -153,19 +153,19 @@ class uFlow(libpry.AutoTree):
def test_getset_state(self): def test_getset_state(self):
f = tutils.tflow() f = tutils.tflow()
f.response = tutils.tresp(f.request) f.response = tutils.tresp(f.request)
state = f.get_state() state = f._get_state()
assert f.get_state() == flow.Flow.from_state(state).get_state() assert f._get_state() == flow.Flow._from_state(state)._get_state()
f.response = None f.response = None
f.error = flow.Error(f.request, "error") f.error = flow.Error(f.request, "error")
state = f.get_state() state = f._get_state()
assert f.get_state() == flow.Flow.from_state(state).get_state() assert f._get_state() == flow.Flow._from_state(state)._get_state()
f2 = tutils.tflow() f2 = tutils.tflow()
f2.error = flow.Error(f.request, "e2") f2.error = flow.Error(f.request, "e2")
assert not f == f2 assert not f == f2
f.load_state(f2.get_state()) f._load_state(f2._get_state())
assert f.get_state() == f2.get_state() assert f._get_state() == f2._get_state()
def test_kill(self): def test_kill(self):
s = flow.State() s = flow.State()
@ -410,7 +410,7 @@ class uSerialize(libpry.AutoTree):
assert len(l) == 1 assert len(l) == 1
f2 = l[0] f2 = l[0]
assert f2.get_state() == f.get_state() assert f2._get_state() == f._get_state()
assert f2.request.assemble() == f.request.assemble() assert f2.request.assemble() == f.request.assemble()
def test_load_flows(self): def test_load_flows(self):
@ -594,20 +594,20 @@ class uRequest(libpry.AutoTree):
h["test"] = ["test"] h["test"] = ["test"]
c = flow.ClientConnect(("addr", 2222)) c = flow.ClientConnect(("addr", 2222))
r = flow.Request(c, "host", 22, "https", "GET", "/", h, "content") r = flow.Request(c, "host", 22, "https", "GET", "/", h, "content")
state = r.get_state() state = r._get_state()
assert flow.Request.from_state(state) == r assert flow.Request._from_state(state) == r
r.client_conn = None r.client_conn = None
state = r.get_state() state = r._get_state()
assert flow.Request.from_state(state) == r assert flow.Request._from_state(state) == r
r2 = flow.Request(c, "testing", 20, "http", "PUT", "/foo", h, "test") r2 = flow.Request(c, "testing", 20, "http", "PUT", "/foo", h, "test")
assert not r == r2 assert not r == r2
r.load_state(r2.get_state()) r._load_state(r2._get_state())
assert r == r2 assert r == r2
r2.client_conn = None r2.client_conn = None
r.load_state(r2.get_state()) r._load_state(r2._get_state())
assert not r.client_conn assert not r.client_conn
def test_replace(self): def test_replace(self):
@ -694,12 +694,12 @@ class uResponse(libpry.AutoTree):
req = flow.Request(c, "host", 22, "https", "GET", "/", h, "content") req = flow.Request(c, "host", 22, "https", "GET", "/", h, "content")
resp = flow.Response(req, 200, "msg", h.copy(), "content") resp = flow.Response(req, 200, "msg", h.copy(), "content")
state = resp.get_state() state = resp._get_state()
assert flow.Response.from_state(req, state) == resp assert flow.Response._from_state(req, state) == resp
resp2 = flow.Response(req, 220, "foo", h.copy(), "test") resp2 = flow.Response(req, 220, "foo", h.copy(), "test")
assert not resp == resp2 assert not resp == resp2
resp.load_state(resp2.get_state()) resp._load_state(resp2._get_state())
assert resp == resp2 assert resp == resp2
def test_replace(self): def test_replace(self):
@ -739,14 +739,14 @@ class uResponse(libpry.AutoTree):
class uError(libpry.AutoTree): class uError(libpry.AutoTree):
def test_getset_state(self): def test_getset_state(self):
e = flow.Error(None, "Error") e = flow.Error(None, "Error")
state = e.get_state() state = e._get_state()
assert flow.Error.from_state(state) == e assert flow.Error._from_state(state) == e
assert e.copy() assert e.copy()
e2 = flow.Error(None, "bar") e2 = flow.Error(None, "bar")
assert not e == e2 assert not e == e2
e.load_state(e2.get_state()) e._load_state(e2._get_state())
assert e == e2 assert e == e2
@ -762,12 +762,12 @@ class uError(libpry.AutoTree):
class uClientConnect(libpry.AutoTree): class uClientConnect(libpry.AutoTree):
def test_state(self): def test_state(self):
c = flow.ClientConnect(("a", 22)) c = flow.ClientConnect(("a", 22))
assert flow.ClientConnect.from_state(c.get_state()) == c assert flow.ClientConnect._from_state(c._get_state()) == c
c2 = flow.ClientConnect(("a", 25)) c2 = flow.ClientConnect(("a", 25))
assert not c == c2 assert not c == c2
c.load_state(c2.get_state()) c._load_state(c2._get_state())
assert c == c2 assert c == c2
c3 = c.copy() c3 = c.copy()
@ -851,8 +851,8 @@ class uHeaders(libpry.AutoTree):
self.hd.add("foo", 1) self.hd.add("foo", 1)
self.hd.add("foo", 2) self.hd.add("foo", 2)
self.hd.add("bar", 3) self.hd.add("bar", 3)
state = self.hd.get_state() state = self.hd._get_state()
nd = flow.Headers.from_state(state) nd = flow.Headers._from_state(state)
assert nd == self.hd assert nd == self.hd
def test_copy(self): def test_copy(self):