diff --git a/netlib/http/cookies.py b/netlib/http/cookies.py index c5ac45918..88c768706 100644 --- a/netlib/http/cookies.py +++ b/netlib/http/cookies.py @@ -169,8 +169,8 @@ def parse_set_cookie_headers(headers): class CookieAttrs(ImmutableMultiDict): @staticmethod - def _kconv(v): - return v.lower() + def _kconv(key): + return key.lower() @staticmethod def _reduce_values(values): diff --git a/netlib/http/message.py b/netlib/http/message.py index 3c731ea6e..db4054b14 100644 --- a/netlib/http/message.py +++ b/netlib/http/message.py @@ -279,7 +279,7 @@ class MultiDictView(MultiDict): """ def __init__(self, attr, message): - if False: + if False: # pragma: no cover # We do not want to call the parent constructor here as that # would cause an unnecessary parse/unparse pass. # This is here to silence linters. Message diff --git a/netlib/multidict.py b/netlib/multidict.py index 32d5bfc2e..a359d46b0 100644 --- a/netlib/multidict.py +++ b/netlib/multidict.py @@ -35,12 +35,20 @@ class MultiDict(MutableMapping, Serializable): @staticmethod @abstractmethod def _reduce_values(values): - pass + """ + If a user accesses multidict["foo"], this method + reduces all values for "foo" to a single value that is returned. + For example, HTTP headers are folded, whereas we will just take + the first cookie we found with that name. + """ @staticmethod @abstractmethod - def _kconv(v): - pass + def _kconv(key): + """ + This method converts a key to its canonical representation. + For example, HTTP headers are case-insensitive, so this method returns key.lower(). + """ def __getitem__(self, key): values = self.get_all(key) diff --git a/test/netlib/http/test_headers.py b/test/netlib/http/test_headers.py index 48d3b3233..cd2ca9d11 100644 --- a/test/netlib/http/test_headers.py +++ b/test/netlib/http/test_headers.py @@ -41,17 +41,7 @@ class TestHeaders(object): with raises(TypeError): Headers([[b"Host", u"not-bytes"]]) - def test_getitem(self): - headers = Headers(Host="example.com") - assert headers["Host"] == "example.com" - assert headers["host"] == "example.com" - with raises(KeyError): - _ = headers["Accept"] - - headers = self._2host() - assert headers["Host"] == "example.com, example.org" - - def test_str(self): + def test_bytes(self): headers = Headers(Host="example.com") assert bytes(headers) == b"Host: example.com\r\n" @@ -64,93 +54,6 @@ class TestHeaders(object): headers = Headers() assert bytes(headers) == b"" - def test_setitem(self): - headers = Headers() - headers["Host"] = "example.com" - assert "Host" in headers - assert "host" in headers - assert headers["Host"] == "example.com" - - headers["host"] = "example.org" - assert "Host" in headers - assert "host" in headers - assert headers["Host"] == "example.org" - - headers["accept"] = "text/plain" - assert len(headers) == 2 - assert "Accept" in headers - assert "Host" in headers - - headers = self._2host() - assert len(headers.fields) == 2 - headers["Host"] = "example.com" - assert len(headers.fields) == 1 - assert "Host" in headers - - def test_delitem(self): - headers = Headers(Host="example.com") - assert len(headers) == 1 - del headers["host"] - assert len(headers) == 0 - try: - del headers["host"] - except KeyError: - assert True - else: - assert False - - headers = self._2host() - del headers["Host"] - assert len(headers) == 0 - - def test_keys(self): - headers = Headers(Host="example.com") - assert list(headers.keys()) == ["Host"] - - headers = self._2host() - assert list(headers.keys()) == ["Host"] - - def test_eq_ne(self): - headers1 = Headers(Host="example.com") - headers2 = Headers(host="example.com") - assert not (headers1 == headers2) - assert headers1 != headers2 - - headers1 = Headers(Host="example.com") - headers2 = Headers(Host="example.com") - assert headers1 == headers2 - assert not (headers1 != headers2) - - assert headers1 != 42 - - def test_get_all(self): - headers = self._2host() - assert headers.get_all("host") == ["example.com", "example.org"] - assert headers.get_all("accept") == [] - - def test_set_all(self): - headers = Headers(Host="example.com") - headers.set_all("Accept", ["text/plain"]) - assert len(headers) == 2 - assert "accept" in headers - - headers = self._2host() - headers.set_all("Host", ["example.org"]) - assert headers["host"] == "example.org" - - headers.set_all("Host", ["example.org", "example.net"]) - assert headers["host"] == "example.org, example.net" - - def test_state(self): - headers = self._2host() - assert len(headers.get_state()) == 2 - assert headers == Headers.from_state(headers.get_state()) - - headers2 = Headers() - assert headers != headers2 - headers2.set_state(headers.get_state()) - assert headers == headers2 - def test_replace_simple(self): headers = Headers(Host="example.com", Accept="text/plain") replacements = headers.replace("Host: ", "X-Host: ") diff --git a/test/netlib/test_multidict.py b/test/netlib/test_multidict.py new file mode 100644 index 000000000..ceea38064 --- /dev/null +++ b/test/netlib/test_multidict.py @@ -0,0 +1,217 @@ +from netlib import tutils +from netlib.multidict import MultiDict, ImmutableMultiDict + + +class _TMulti(object): + @staticmethod + def _reduce_values(values): + return values[0] + + @staticmethod + def _kconv(key): + return key.lower() + + +class TMultiDict(_TMulti, MultiDict): + pass + + +class TImmutableMultiDict(_TMulti, ImmutableMultiDict): + pass + + +class TestMultiDict(object): + @staticmethod + def _multi(): + return TMultiDict(( + ("foo", "bar"), + ("bar", "baz"), + ("Bar", "bam") + )) + + def test_init(self): + md = TMultiDict() + assert len(md) == 0 + + md = TMultiDict([("foo", "bar")]) + assert len(md) == 1 + assert md.fields == (("foo", "bar"),) + + def test_repr(self): + assert repr(self._multi()) == ( + "TMultiDict[('foo', 'bar'), ('bar', 'baz'), ('Bar', 'bam')]" + ) + + def test_getitem(self): + md = TMultiDict([("foo", "bar")]) + assert "foo" in md + assert "Foo" in md + assert md["foo"] == "bar" + + with tutils.raises(KeyError): + _ = md["bar"] + + md_multi = TMultiDict( + [("foo", "a"), ("foo", "b")] + ) + assert md_multi["foo"] == "a" + + def test_setitem(self): + md = TMultiDict() + md["foo"] = "bar" + assert md.fields == (("foo", "bar"),) + + md["foo"] = "baz" + assert md.fields == (("foo", "baz"),) + + md["bar"] = "bam" + assert md.fields == (("foo", "baz"), ("bar", "bam")) + + def test_delitem(self): + md = self._multi() + del md["foo"] + assert "foo" not in md + assert "bar" in md + + with tutils.raises(KeyError): + del md["foo"] + + del md["bar"] + assert md.fields == () + + def test_iter(self): + md = self._multi() + assert list(md.__iter__()) == ["foo", "bar"] + + def test_len(self): + md = TMultiDict() + assert len(md) == 0 + + md = self._multi() + assert len(md) == 2 + + def test_eq(self): + assert TMultiDict() == TMultiDict() + assert not (TMultiDict() == 42) + + md1 = self._multi() + md2 = self._multi() + assert md1 == md2 + md1.fields = md1.fields[1:] + md1.fields[:1] + assert not (md1 == md2) + + def test_ne(self): + assert not TMultiDict() != TMultiDict() + assert TMultiDict() != self._multi() + assert TMultiDict() != 42 + + def test_get_all(self): + md = self._multi() + assert md.get_all("foo") == ["bar"] + assert md.get_all("bar") == ["baz", "bam"] + assert md.get_all("baz") == [] + + def test_set_all(self): + md = TMultiDict() + md.set_all("foo", ["bar", "baz"]) + assert md.fields == (("foo", "bar"), ("foo", "baz")) + + md = TMultiDict(( + ("a", "b"), + ("x", "x"), + ("c", "d"), + ("X", "x"), + ("e", "f"), + )) + md.set_all("x", ["1", "2", "3"]) + assert md.fields == ( + ("a", "b"), + ("x", "1"), + ("c", "d"), + ("x", "2"), + ("e", "f"), + ("x", "3"), + ) + md.set_all("x", ["4"]) + assert md.fields == ( + ("a", "b"), + ("x", "4"), + ("c", "d"), + ("e", "f"), + ) + + def test_add(self): + md = self._multi() + md.add("foo", "foo") + assert md.fields == ( + ("foo", "bar"), + ("bar", "baz"), + ("Bar", "bam"), + ("foo", "foo") + ) + + def test_insert(self): + md = TMultiDict([("b", "b")]) + md.insert(0, "a", "a") + md.insert(2, "c", "c") + assert md.fields == (("a", "a"), ("b", "b"), ("c", "c")) + + def test_keys(self): + md = self._multi() + assert list(md.keys()) == ["foo", "bar"] + assert list(md.keys(multi=True)) == ["foo", "bar", "Bar"] + + def test_values(self): + md = self._multi() + assert list(md.values()) == ["bar", "baz"] + assert list(md.values(multi=True)) == ["bar", "baz", "bam"] + + def test_items(self): + md = self._multi() + assert list(md.items()) == [("foo", "bar"), ("bar", "baz")] + assert list(md.items(multi=True)) == [("foo", "bar"), ("bar", "baz"), ("Bar", "bam")] + + def test_to_dict(self): + md = self._multi() + assert md.to_dict() == { + "foo": "bar", + "bar": ["baz", "bam"] + } + + def test_state(self): + md = self._multi() + assert len(md.get_state()) == 3 + assert md == TMultiDict.from_state(md.get_state()) + + md2 = TMultiDict() + assert md != md2 + md2.set_state(md.get_state()) + assert md == md2 + + +class TestImmutableMultiDict(object): + def test_modify(self): + md = TImmutableMultiDict() + with tutils.raises(TypeError): + md["foo"] = "bar" + + with tutils.raises(TypeError): + del md["foo"] + + with tutils.raises(TypeError): + md.add("foo", "bar") + + def test_with_delitem(self): + md = TImmutableMultiDict([("foo", "bar")]) + assert md.with_delitem("foo").fields == () + assert md.fields == (("foo", "bar"),) + + def test_with_set_all(self): + md = TImmutableMultiDict() + assert md.with_set_all("foo", ["bar"]).fields == (("foo", "bar"),) + assert md.fields == () + + def test_with_insert(self): + md = TImmutableMultiDict() + assert md.with_insert(0, "foo", "bar").fields == (("foo", "bar"),) + assert md.fields == () \ No newline at end of file