diff --git a/libmproxy/console.py b/libmproxy/console.py index d99dd8ac5..ffe37fc31 100644 --- a/libmproxy/console.py +++ b/libmproxy/console.py @@ -93,7 +93,7 @@ def format_flow(f, focus, extended=False, padding=2): txt.append(("goodcode", str(f.response.code))) else: txt.append(("error", str(f.response.code))) - t = f.response.headers.get("content-type") + t = f.response.headers["content-type"] if t: t = t[0].split(";")[0] txt.append(("text", " %s"%t)) @@ -295,7 +295,11 @@ class ConnectionView(WWrap): def _conn_text(self, conn, viewmode): if conn: - return self.master._cached_conn_text(conn.content, tuple(conn.headers.itemPairs()), viewmode) + return self.master._cached_conn_text( + conn.content, + tuple([tuple(i) for i in conn.headers.lst]), + viewmode + ) else: return urwid.ListBox([]) @@ -485,7 +489,7 @@ class ConnectionView(WWrap): else: conn = self.flow.response if conn.content: - t = conn.headers.get("content-type", [None]) + t = conn.headers["content-type"] or [None] t = t[0] self.master.spawn_external_viewer(conn.content, t) elif key == "b": diff --git a/libmproxy/filt.py b/libmproxy/filt.py index 31c43581b..40cf7358b 100644 --- a/libmproxy/filt.py +++ b/libmproxy/filt.py @@ -79,7 +79,7 @@ class _Rex(_Action): raise ValueError, "Cannot compile expression." def _check_content_type(expr, o): - val = o.headers.get("content-type") + val = o.headers["content-type"] if val and re.search(expr, val[0]): return True return False diff --git a/libmproxy/flow.py b/libmproxy/flow.py index be77af337..d29b8e2d8 100644 --- a/libmproxy/flow.py +++ b/libmproxy/flow.py @@ -101,7 +101,7 @@ class ServerPlaybackState: if self.headers: hdrs = [] for i in self.headers: - v = r.headers.get(i, []) + v = r.headers[i] # Slightly subtle: we need to convert everything to strings # to prevent a mismatch between unicode/non-unicode. v = [str(x) for x in v] @@ -139,7 +139,7 @@ class StickyCookieState: ) def handle_response(self, f): - for i in f.response.headers.get("set-cookie", []): + for i in f.response.headers["set-cookie"]: # FIXME: We now know that Cookie.py screws up some cookies with # valid RFC 822/1123 datetime specifications for expiry. Sigh. c = Cookie.SimpleCookie(i) @@ -158,9 +158,10 @@ class StickyCookieState: f.request.path.startswith(i[2]) ] if all(match): - l = f.request.headers.setdefault("cookie", []) + l = f.request.headers["cookie"] f.request.stickycookie = True l.append(self.jar[i].output(header="").strip()) + f.request.headers["cookie"] = l class StickyAuthState: diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py index 362d622dd..690df9f44 100644 --- a/libmproxy/proxy.py +++ b/libmproxy/proxy.py @@ -56,11 +56,11 @@ def read_chunked(fp): def read_http_body(rfile, connection, headers, all): - if headers.has_key('transfer-encoding'): + if 'transfer-encoding' in headers: if not ",".join(headers["transfer-encoding"]) == "chunked": raise IOError('Invalid transfer-encoding') content = read_chunked(rfile) - elif headers.has_key("content-length"): + elif "content-length" in headers: content = rfile.read(int(headers["content-length"][0])) elif all: content = rfile.read() @@ -152,8 +152,7 @@ class Request(controller.Msg): "if-none-match", ] for i in delheaders: - if i in self.headers: - del self.headers[i] + del self.headers[i] def set_replay(self): self.client_conn = None @@ -251,7 +250,7 @@ class Request(controller.Msg): utils.try_del(headers, 'connection') utils.try_del(headers, 'content-length') utils.try_del(headers, 'transfer-encoding') - if not headers.has_key('host'): + if not 'host' in headers: headers["host"] = [self.hostport()] content = self.content if content is not None: @@ -321,7 +320,7 @@ class Response(controller.Msg): new = mktime_tz(d) + delta self.headers[i] = [formatdate(new)] c = [] - for i in self.headers.get("set-cookie", []): + for i in self.headers["set-cookie"]: c.append(self._refresh_cookie(i, delta)) if c: self.headers["set-cookie"] = c @@ -656,7 +655,7 @@ class ProxyHandler(SocketServer.StreamRequestHandler): scheme = "https" headers = utils.Headers() headers.read(self.rfile) - if host is None and headers.has_key("host"): + if host is None and "host" in headers: netloc = headers["host"][0] if ':' in netloc: host, port = string.split(netloc, ':') @@ -670,7 +669,7 @@ class ProxyHandler(SocketServer.StreamRequestHandler): port = int(port) if host is None: raise ProxyError(400, 'Invalid request: %s'%request) - if headers.has_key('expect'): + if "expect" in headers: expect = ",".join(headers['expect']) if expect == "100-continue" and httpminor >= 1: self.wfile.write('HTTP/1.1 100 Continue\r\n') @@ -681,7 +680,7 @@ class ProxyHandler(SocketServer.StreamRequestHandler): raise ProxyError(417, 'Unmet expect: %s'%expect) if httpminor == 0: client_conn.close = True - if headers.has_key('connection'): + if "connection" in headers: for value in ",".join(headers['connection']).split(","): value = value.strip() if value == "close": diff --git a/libmproxy/utils.py b/libmproxy/utils.py index 8ac1f5470..38fc61078 100644 --- a/libmproxy/utils.py +++ b/libmproxy/utils.py @@ -12,7 +12,7 @@ # # You should have received a copy of the GNU General Public License # along with this program. If not, see . -import re, os, subprocess, datetime, textwrap, errno, sys, time, functools +import re, os, subprocess, datetime, textwrap, errno, sys, time, functools, copy import json CERT_SLEEP_TIME = 1 @@ -164,10 +164,6 @@ def isSequenceLike(anobj): return 1 -def _caseless(s): - return s.lower() - - def try_del(dict, key): try: del dict[key] @@ -175,108 +171,72 @@ def try_del(dict, key): pass -class MultiDict: - """ - Simple wrapper around a dictionary to make holding multiple objects per - key easier. +class Headers: + def __init__(self, lst=None): + if lst: + self.lst = lst + else: + self.lst = [] - Note that this class assumes that keys are strings. - - Keys have no order, but the order in which values are added to a key is - preserved. - """ - # This ridiculous bit of subterfuge is needed to prevent the class from - # treating this as a bound method. - _helper = (str,) - def __init__(self): - self._d = dict() - - def copy(self): - m = self.__class__() - m._d = self._d.copy() - return m - - def clear(self): - return self._d.clear() - - def get(self, key, d=None): - key = self._helper[0](key) - return self._d.get(key, d) - - def __contains__(self, key): - key = self._helper[0](key) - return self._d.__contains__(key) + def _kconv(self, s): + return s.lower() def __eq__(self, other): - return dict(self) == dict(other) + return self.lst == other.lst - def __delitem__(self, key): - self._d.__delitem__(key) + def __getitem__(self, k): + ret = [] + k = self._kconv(k) + for i in self.lst: + if self._kconv(i[0]) == k: + ret.append(i[1]) + return ret - def __getitem__(self, key): - key = self._helper[0](key) - return self._d.__getitem__(key) - - def __setitem__(self, key, value): - if not isSequenceLike(value): - raise ValueError, "Cannot insert non-sequence." - key = self._helper[0](key) - return self._d.__setitem__(key, value) + def _filter_lst(self, k, lst): + new = [] + for i in lst: + if self._kconv(i[0]) != k: + new.append(i) + return new - def has_key(self, key): - key = self._helper[0](key) - return self._d.has_key(key) + def __setitem__(self, k, hdrs): + k = self._kconv(k) + first = None + new = self._filter_lst(k, self.lst) + for i in hdrs: + new.append((k, i)) + self.lst = new - def setdefault(self, key, default=None): - key = self._helper[0](key) - return self._d.setdefault(key, default) + def __delitem__(self, k): + self.lst = self._filter_lst(k, self.lst) - def keys(self): - return self._d.keys() + def __contains__(self, k): + for i in self.lst: + if self._kconv(i[0]) == k: + return True + return False - def extend(self, key, value): - if not self.has_key(key): - self[key] = [] - self[key].extend(value) - - def append(self, key, value): - self.extend(key, [value]) - - def itemPairs(self): - """ - Yield all possible pairs of items. - """ - for i in self.keys(): - for j in self[i]: - yield (i, j) + def add(self, key, value): + self.lst.append([key, str(value)]) def get_state(self): - return list(self.itemPairs()) + return [tuple(i) for i in self.lst] @classmethod def from_state(klass, state): - md = klass() - for i in state: - md.append(*i) - return md + return klass([list(i) for i in state]) + def copy(self): + lst = copy.deepcopy(self.lst) + return Headers(lst) -class Headers(MultiDict): - """ - A dictionary-like class for keeping track of HTTP headers. - - It is case insensitive, and __repr__ formats the headers correcty for - output to the server. - """ - _helper = (_caseless,) def __repr__(self): """ Returns a string containing a formatted header string. """ headerElements = [] - for key in sorted(self.keys()): - for val in self[key]: - headerElements.append(key + ": " + val) + for itm in self.lst: + headerElements.append(itm[0] + ": " + itm[1]) headerElements.append("") return "\r\n".join(headerElements) @@ -284,7 +244,7 @@ class Headers(MultiDict): """ Match the regular expression against each header (key, value) pair. """ - for k, v in self.itemPairs(): + for k, v in self.lst: s = "%s: %s"%(k, v) if re.search(expr, s): return True @@ -295,6 +255,7 @@ class Headers(MultiDict): Read a set of headers from a file pointer. Stop once a blank line is reached. """ + ret = [] name = '' while 1: line = fp.readline() @@ -302,18 +263,15 @@ class Headers(MultiDict): break if line[0] in ' \t': # continued header - self[name][-1] = self[name][-1] + '\r\n ' + line.strip() + ret[-1][1] = ret[-1][1] + '\r\n ' + line.strip() else: i = line.find(':') # We're being liberal in what we accept, here. if i > 0: name = line[:i] value = line[i+1:].strip() - if self.has_key(name): - # merge value - self.append(name, value) - else: - self[name] = [value] + ret.append([name, value]) + self.lst = ret def pretty_size(size): diff --git a/test/test_server.py b/test/test_server.py index e9b611657..1e3c1df49 100644 --- a/test/test_server.py +++ b/test/test_server.py @@ -44,7 +44,7 @@ class uProxy(tutils.ProxTest): l = self.log() assert l[0].address - assert l[1].headers.has_key("host") + assert "host" in l[1].headers assert l[2].code == 200 def test_https(self): @@ -55,7 +55,7 @@ class uProxy(tutils.ProxTest): l = self.log() assert l[0].address - assert l[1].headers.has_key("host") + assert "host" in l[1].headers assert l[2].code == 200 # Disable these two for now: they take a long time. diff --git a/test/test_utils.py b/test/test_utils.py index b64db918a..d5957872d 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -43,88 +43,6 @@ class uData(libpry.AutoTree): libpry.raises("does not exist", utils.data.path, "nonexistent") -class uMultiDict(libpry.AutoTree): - def setUp(self): - self.md = utils.MultiDict() - - def test_setget(self): - assert not self.md.has_key("foo") - self.md.append("foo", 1) - assert self.md["foo"] == [1] - assert self.md.has_key("foo") - - def test_del(self): - self.md.append("foo", 1) - del self.md["foo"] - assert not self.md.has_key("foo") - - def test_extend(self): - self.md.append("foo", 1) - self.md.extend("foo", [2, 3]) - assert self.md["foo"] == [1, 2, 3] - - def test_extend_err(self): - self.md.append("foo", 1) - libpry.raises("not iterable", self.md.extend, "foo", 2) - - def test_get(self): - self.md.append("foo", 1) - self.md.append("foo", 2) - assert self.md.get("foo") == [1, 2] - assert self.md.get("bar") == None - - def test_caseSensitivity(self): - self.md._helper = (utils._caseless,) - self.md["foo"] = [1] - self.md.append("FOO", 2) - assert self.md["foo"] == [1, 2] - assert self.md["FOO"] == [1, 2] - assert self.md.has_key("FoO") - - def test_dict(self): - self.md.append("foo", 1) - self.md.append("foo", 2) - self.md["bar"] = [3] - assert self.md == self.md - assert dict(self.md) == self.md - - def test_copy(self): - self.md["foo"] = [1, 2] - self.md["bar"] = [3, 4] - md2 = self.md.copy() - assert md2 == self.md - assert id(md2) != id(self.md) - - def test_clear(self): - self.md["foo"] = [1, 2] - self.md["bar"] = [3, 4] - self.md.clear() - assert not self.md.keys() - - def test_setitem(self): - libpry.raises(ValueError, self.md.__setitem__, "foo", "bar") - self.md["foo"] = ["bar"] - assert self.md["foo"] == ["bar"] - - def test_itemPairs(self): - self.md.append("foo", 1) - self.md.append("foo", 2) - self.md.append("bar", 3) - l = list(self.md.itemPairs()) - assert len(l) == 3 - assert ("foo", 1) in l - assert ("foo", 2) in l - assert ("bar", 3) in l - - def test_getset_state(self): - self.md.append("foo", 1) - self.md.append("foo", 2) - self.md.append("bar", 3) - state = self.md.get_state() - nd = utils.MultiDict.from_state(state) - assert nd == self.md - - class uHeaders(libpry.AutoTree): def setUp(self): self.hd = utils.Headers() @@ -168,9 +86,9 @@ class uHeaders(libpry.AutoTree): assert self.hd["header"] == ['one\r\n two'] def test_dictToHeader1(self): - self.hd.append("one", "uno") - self.hd.append("two", "due") - self.hd.append("two", "tre") + self.hd.add("one", "uno") + self.hd.add("two", "due") + self.hd.add("two", "tre") expected = [ "one: uno\r\n", "two: due\r\n", @@ -191,21 +109,34 @@ class uHeaders(libpry.AutoTree): def test_match_re(self): h = utils.Headers() - h.append("one", "uno") - h.append("two", "due") - h.append("two", "tre") + h.add("one", "uno") + h.add("two", "due") + h.add("two", "tre") assert h.match_re("uno") assert h.match_re("two: due") assert not h.match_re("nonono") def test_getset_state(self): - self.hd.append("foo", 1) - self.hd.append("foo", 2) - self.hd.append("bar", 3) + self.hd.add("foo", 1) + self.hd.add("foo", 2) + self.hd.add("bar", 3) state = self.hd.get_state() nd = utils.Headers.from_state(state) assert nd == self.hd + def test_copy(self): + self.hd.add("foo", 1) + self.hd.add("foo", 2) + self.hd.add("bar", 3) + assert self.hd == self.hd.copy() + + def test_del(self): + self.hd.add("foo", 1) + self.hd.add("Foo", 2) + self.hd.add("bar", 3) + del self.hd["foo"] + assert len(self.hd.lst) == 1 + class uisStringLike(libpry.AutoTree): def test_all(self): @@ -371,7 +302,6 @@ tests = [ upretty_size(), uisStringLike(), uisSequenceLike(), - uMultiDict(), uHeaders(), uData(), upretty_xmlish(),