Rewrite Headers object to preserve order and case.

This commit is contained in:
Aldo Cortesi 2011-07-14 15:59:27 +12:00
parent b6e1bf63c3
commit 1c9e7b982a
7 changed files with 95 additions and 203 deletions

View File

@ -93,7 +93,7 @@ def format_flow(f, focus, extended=False, padding=2):
txt.append(("goodcode", str(f.response.code)))
else:
txt.append(("error", str(f.response.code)))
t = f.response.headers.get("content-type")
t = f.response.headers["content-type"]
if t:
t = t[0].split(";")[0]
txt.append(("text", " %s"%t))
@ -295,7 +295,11 @@ class ConnectionView(WWrap):
def _conn_text(self, conn, viewmode):
if conn:
return self.master._cached_conn_text(conn.content, tuple(conn.headers.itemPairs()), viewmode)
return self.master._cached_conn_text(
conn.content,
tuple([tuple(i) for i in conn.headers.lst]),
viewmode
)
else:
return urwid.ListBox([])
@ -485,7 +489,7 @@ class ConnectionView(WWrap):
else:
conn = self.flow.response
if conn.content:
t = conn.headers.get("content-type", [None])
t = conn.headers["content-type"] or [None]
t = t[0]
self.master.spawn_external_viewer(conn.content, t)
elif key == "b":

View File

@ -79,7 +79,7 @@ class _Rex(_Action):
raise ValueError, "Cannot compile expression."
def _check_content_type(expr, o):
val = o.headers.get("content-type")
val = o.headers["content-type"]
if val and re.search(expr, val[0]):
return True
return False

View File

@ -101,7 +101,7 @@ class ServerPlaybackState:
if self.headers:
hdrs = []
for i in self.headers:
v = r.headers.get(i, [])
v = r.headers[i]
# Slightly subtle: we need to convert everything to strings
# to prevent a mismatch between unicode/non-unicode.
v = [str(x) for x in v]
@ -139,7 +139,7 @@ class StickyCookieState:
)
def handle_response(self, f):
for i in f.response.headers.get("set-cookie", []):
for i in f.response.headers["set-cookie"]:
# FIXME: We now know that Cookie.py screws up some cookies with
# valid RFC 822/1123 datetime specifications for expiry. Sigh.
c = Cookie.SimpleCookie(i)
@ -158,9 +158,10 @@ class StickyCookieState:
f.request.path.startswith(i[2])
]
if all(match):
l = f.request.headers.setdefault("cookie", [])
l = f.request.headers["cookie"]
f.request.stickycookie = True
l.append(self.jar[i].output(header="").strip())
f.request.headers["cookie"] = l
class StickyAuthState:

View File

@ -56,11 +56,11 @@ def read_chunked(fp):
def read_http_body(rfile, connection, headers, all):
if headers.has_key('transfer-encoding'):
if 'transfer-encoding' in headers:
if not ",".join(headers["transfer-encoding"]) == "chunked":
raise IOError('Invalid transfer-encoding')
content = read_chunked(rfile)
elif headers.has_key("content-length"):
elif "content-length" in headers:
content = rfile.read(int(headers["content-length"][0]))
elif all:
content = rfile.read()
@ -152,8 +152,7 @@ class Request(controller.Msg):
"if-none-match",
]
for i in delheaders:
if i in self.headers:
del self.headers[i]
del self.headers[i]
def set_replay(self):
self.client_conn = None
@ -251,7 +250,7 @@ class Request(controller.Msg):
utils.try_del(headers, 'connection')
utils.try_del(headers, 'content-length')
utils.try_del(headers, 'transfer-encoding')
if not headers.has_key('host'):
if not 'host' in headers:
headers["host"] = [self.hostport()]
content = self.content
if content is not None:
@ -321,7 +320,7 @@ class Response(controller.Msg):
new = mktime_tz(d) + delta
self.headers[i] = [formatdate(new)]
c = []
for i in self.headers.get("set-cookie", []):
for i in self.headers["set-cookie"]:
c.append(self._refresh_cookie(i, delta))
if c:
self.headers["set-cookie"] = c
@ -656,7 +655,7 @@ class ProxyHandler(SocketServer.StreamRequestHandler):
scheme = "https"
headers = utils.Headers()
headers.read(self.rfile)
if host is None and headers.has_key("host"):
if host is None and "host" in headers:
netloc = headers["host"][0]
if ':' in netloc:
host, port = string.split(netloc, ':')
@ -670,7 +669,7 @@ class ProxyHandler(SocketServer.StreamRequestHandler):
port = int(port)
if host is None:
raise ProxyError(400, 'Invalid request: %s'%request)
if headers.has_key('expect'):
if "expect" in headers:
expect = ",".join(headers['expect'])
if expect == "100-continue" and httpminor >= 1:
self.wfile.write('HTTP/1.1 100 Continue\r\n')
@ -681,7 +680,7 @@ class ProxyHandler(SocketServer.StreamRequestHandler):
raise ProxyError(417, 'Unmet expect: %s'%expect)
if httpminor == 0:
client_conn.close = True
if headers.has_key('connection'):
if "connection" in headers:
for value in ",".join(headers['connection']).split(","):
value = value.strip()
if value == "close":

View File

@ -12,7 +12,7 @@
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import re, os, subprocess, datetime, textwrap, errno, sys, time, functools
import re, os, subprocess, datetime, textwrap, errno, sys, time, functools, copy
import json
CERT_SLEEP_TIME = 1
@ -164,10 +164,6 @@ def isSequenceLike(anobj):
return 1
def _caseless(s):
return s.lower()
def try_del(dict, key):
try:
del dict[key]
@ -175,108 +171,72 @@ def try_del(dict, key):
pass
class MultiDict:
"""
Simple wrapper around a dictionary to make holding multiple objects per
key easier.
class Headers:
def __init__(self, lst=None):
if lst:
self.lst = lst
else:
self.lst = []
Note that this class assumes that keys are strings.
Keys have no order, but the order in which values are added to a key is
preserved.
"""
# This ridiculous bit of subterfuge is needed to prevent the class from
# treating this as a bound method.
_helper = (str,)
def __init__(self):
self._d = dict()
def copy(self):
m = self.__class__()
m._d = self._d.copy()
return m
def clear(self):
return self._d.clear()
def get(self, key, d=None):
key = self._helper[0](key)
return self._d.get(key, d)
def __contains__(self, key):
key = self._helper[0](key)
return self._d.__contains__(key)
def _kconv(self, s):
return s.lower()
def __eq__(self, other):
return dict(self) == dict(other)
return self.lst == other.lst
def __delitem__(self, key):
self._d.__delitem__(key)
def __getitem__(self, k):
ret = []
k = self._kconv(k)
for i in self.lst:
if self._kconv(i[0]) == k:
ret.append(i[1])
return ret
def __getitem__(self, key):
key = self._helper[0](key)
return self._d.__getitem__(key)
def __setitem__(self, key, value):
if not isSequenceLike(value):
raise ValueError, "Cannot insert non-sequence."
key = self._helper[0](key)
return self._d.__setitem__(key, value)
def _filter_lst(self, k, lst):
new = []
for i in lst:
if self._kconv(i[0]) != k:
new.append(i)
return new
def has_key(self, key):
key = self._helper[0](key)
return self._d.has_key(key)
def __setitem__(self, k, hdrs):
k = self._kconv(k)
first = None
new = self._filter_lst(k, self.lst)
for i in hdrs:
new.append((k, i))
self.lst = new
def setdefault(self, key, default=None):
key = self._helper[0](key)
return self._d.setdefault(key, default)
def __delitem__(self, k):
self.lst = self._filter_lst(k, self.lst)
def keys(self):
return self._d.keys()
def __contains__(self, k):
for i in self.lst:
if self._kconv(i[0]) == k:
return True
return False
def extend(self, key, value):
if not self.has_key(key):
self[key] = []
self[key].extend(value)
def append(self, key, value):
self.extend(key, [value])
def itemPairs(self):
"""
Yield all possible pairs of items.
"""
for i in self.keys():
for j in self[i]:
yield (i, j)
def add(self, key, value):
self.lst.append([key, str(value)])
def get_state(self):
return list(self.itemPairs())
return [tuple(i) for i in self.lst]
@classmethod
def from_state(klass, state):
md = klass()
for i in state:
md.append(*i)
return md
return klass([list(i) for i in state])
def copy(self):
lst = copy.deepcopy(self.lst)
return Headers(lst)
class Headers(MultiDict):
"""
A dictionary-like class for keeping track of HTTP headers.
It is case insensitive, and __repr__ formats the headers correcty for
output to the server.
"""
_helper = (_caseless,)
def __repr__(self):
"""
Returns a string containing a formatted header string.
"""
headerElements = []
for key in sorted(self.keys()):
for val in self[key]:
headerElements.append(key + ": " + val)
for itm in self.lst:
headerElements.append(itm[0] + ": " + itm[1])
headerElements.append("")
return "\r\n".join(headerElements)
@ -284,7 +244,7 @@ class Headers(MultiDict):
"""
Match the regular expression against each header (key, value) pair.
"""
for k, v in self.itemPairs():
for k, v in self.lst:
s = "%s: %s"%(k, v)
if re.search(expr, s):
return True
@ -295,6 +255,7 @@ class Headers(MultiDict):
Read a set of headers from a file pointer. Stop once a blank line
is reached.
"""
ret = []
name = ''
while 1:
line = fp.readline()
@ -302,18 +263,15 @@ class Headers(MultiDict):
break
if line[0] in ' \t':
# continued header
self[name][-1] = self[name][-1] + '\r\n ' + line.strip()
ret[-1][1] = ret[-1][1] + '\r\n ' + line.strip()
else:
i = line.find(':')
# We're being liberal in what we accept, here.
if i > 0:
name = line[:i]
value = line[i+1:].strip()
if self.has_key(name):
# merge value
self.append(name, value)
else:
self[name] = [value]
ret.append([name, value])
self.lst = ret
def pretty_size(size):

View File

@ -44,7 +44,7 @@ class uProxy(tutils.ProxTest):
l = self.log()
assert l[0].address
assert l[1].headers.has_key("host")
assert "host" in l[1].headers
assert l[2].code == 200
def test_https(self):
@ -55,7 +55,7 @@ class uProxy(tutils.ProxTest):
l = self.log()
assert l[0].address
assert l[1].headers.has_key("host")
assert "host" in l[1].headers
assert l[2].code == 200
# Disable these two for now: they take a long time.

View File

@ -43,88 +43,6 @@ class uData(libpry.AutoTree):
libpry.raises("does not exist", utils.data.path, "nonexistent")
class uMultiDict(libpry.AutoTree):
def setUp(self):
self.md = utils.MultiDict()
def test_setget(self):
assert not self.md.has_key("foo")
self.md.append("foo", 1)
assert self.md["foo"] == [1]
assert self.md.has_key("foo")
def test_del(self):
self.md.append("foo", 1)
del self.md["foo"]
assert not self.md.has_key("foo")
def test_extend(self):
self.md.append("foo", 1)
self.md.extend("foo", [2, 3])
assert self.md["foo"] == [1, 2, 3]
def test_extend_err(self):
self.md.append("foo", 1)
libpry.raises("not iterable", self.md.extend, "foo", 2)
def test_get(self):
self.md.append("foo", 1)
self.md.append("foo", 2)
assert self.md.get("foo") == [1, 2]
assert self.md.get("bar") == None
def test_caseSensitivity(self):
self.md._helper = (utils._caseless,)
self.md["foo"] = [1]
self.md.append("FOO", 2)
assert self.md["foo"] == [1, 2]
assert self.md["FOO"] == [1, 2]
assert self.md.has_key("FoO")
def test_dict(self):
self.md.append("foo", 1)
self.md.append("foo", 2)
self.md["bar"] = [3]
assert self.md == self.md
assert dict(self.md) == self.md
def test_copy(self):
self.md["foo"] = [1, 2]
self.md["bar"] = [3, 4]
md2 = self.md.copy()
assert md2 == self.md
assert id(md2) != id(self.md)
def test_clear(self):
self.md["foo"] = [1, 2]
self.md["bar"] = [3, 4]
self.md.clear()
assert not self.md.keys()
def test_setitem(self):
libpry.raises(ValueError, self.md.__setitem__, "foo", "bar")
self.md["foo"] = ["bar"]
assert self.md["foo"] == ["bar"]
def test_itemPairs(self):
self.md.append("foo", 1)
self.md.append("foo", 2)
self.md.append("bar", 3)
l = list(self.md.itemPairs())
assert len(l) == 3
assert ("foo", 1) in l
assert ("foo", 2) in l
assert ("bar", 3) in l
def test_getset_state(self):
self.md.append("foo", 1)
self.md.append("foo", 2)
self.md.append("bar", 3)
state = self.md.get_state()
nd = utils.MultiDict.from_state(state)
assert nd == self.md
class uHeaders(libpry.AutoTree):
def setUp(self):
self.hd = utils.Headers()
@ -168,9 +86,9 @@ class uHeaders(libpry.AutoTree):
assert self.hd["header"] == ['one\r\n two']
def test_dictToHeader1(self):
self.hd.append("one", "uno")
self.hd.append("two", "due")
self.hd.append("two", "tre")
self.hd.add("one", "uno")
self.hd.add("two", "due")
self.hd.add("two", "tre")
expected = [
"one: uno\r\n",
"two: due\r\n",
@ -191,21 +109,34 @@ class uHeaders(libpry.AutoTree):
def test_match_re(self):
h = utils.Headers()
h.append("one", "uno")
h.append("two", "due")
h.append("two", "tre")
h.add("one", "uno")
h.add("two", "due")
h.add("two", "tre")
assert h.match_re("uno")
assert h.match_re("two: due")
assert not h.match_re("nonono")
def test_getset_state(self):
self.hd.append("foo", 1)
self.hd.append("foo", 2)
self.hd.append("bar", 3)
self.hd.add("foo", 1)
self.hd.add("foo", 2)
self.hd.add("bar", 3)
state = self.hd.get_state()
nd = utils.Headers.from_state(state)
assert nd == self.hd
def test_copy(self):
self.hd.add("foo", 1)
self.hd.add("foo", 2)
self.hd.add("bar", 3)
assert self.hd == self.hd.copy()
def test_del(self):
self.hd.add("foo", 1)
self.hd.add("Foo", 2)
self.hd.add("bar", 3)
del self.hd["foo"]
assert len(self.hd.lst) == 1
class uisStringLike(libpry.AutoTree):
def test_all(self):
@ -371,7 +302,6 @@ tests = [
upretty_size(),
uisStringLike(),
uisSequenceLike(),
uMultiDict(),
uHeaders(),
uData(),
upretty_xmlish(),