diff --git a/starlette/datastructures.py b/starlette/datastructures.py index c9ffd4df..22ee32ee 100644 --- a/starlette/datastructures.py +++ b/starlette/datastructures.py @@ -1,3 +1,4 @@ +import itertools import tempfile import typing from collections import namedtuple @@ -290,6 +291,83 @@ class ImmutableMultiDict(typing.Mapping): return f"{self.__class__.__name__}({repr(items)})" +class MultiDict(ImmutableMultiDict): + def __setitem__(self, key: typing.Any, value: typing.Any) -> None: + self.setlist(key, [value]) + + def __delitem__(self, key: typing.Any) -> None: + self._list = [(k, v) for k, v in self._list if k != key] + del self._dict[key] + + def pop(self, key: typing.Any, default: typing.Any = None) -> typing.Any: + self._list = [(k, v) for k, v in self._list if k != key] + return self._dict.pop(key, default) + + def popitem(self) -> typing.Tuple: + key, value = self._dict.popitem() + self._list = [(k, v) for k, v in self._list if k != key] + return key, value + + def poplist(self, key: typing.Any) -> typing.List: + values = [v for k, v in self._list if k == key] + self.pop(key) + return values + + def clear(self) -> None: + self._dict.clear() + self._list.clear() + + def setdefault(self, key: typing.Any, default: typing.Any = None) -> typing.Any: + if key not in self: + self._dict[key] = default + self._list.append((key, default)) + + return self[key] + + def setlist(self, key: typing.Any, values: typing.List) -> None: + self.pop(key, None) + if not values: + values = [] + else: + self._dict[key] = values[-1] + self._list.extend(((key, value) for value in values)) + + def appendlist(self, key: typing.Any, value: typing.Any) -> None: + self._list.append((key, value)) + self._dict[key] = value + + def update( + self, + values: typing.Union[ + "MultiDict", + typing.Mapping, + typing.List[typing.Tuple[typing.Any, typing.Any]], + ] = None, + **kwargs: typing.Any, + ) -> None: + if values is None: + items_ = [] # type: typing.List + elif hasattr(values, "multi_items"): + values = typing.cast(MultiDict, values) + items_ = list(values.multi_items()) + elif hasattr(values, "items"): + values = typing.cast(typing.Mapping, values) + items_ = list(values.items()) + else: + values = typing.cast( + typing.List[typing.Tuple[typing.Any, typing.Any]], values + ) + items_ = values + + keys = {k for k, _ in itertools.chain(items_, kwargs.items())} + self._list = [ + *((k, v) for k, v in self._list if k not in keys), + *items_, + *list(kwargs.items()), + ] + self._dict.update(itertools.chain(items_, kwargs.items())) + + class QueryParams(ImmutableMultiDict): """ An immutable multidict. diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py index bd33909d..4a96efb8 100644 --- a/tests/test_datastructures.py +++ b/tests/test_datastructures.py @@ -6,6 +6,7 @@ from starlette.datastructures import ( DatabaseURL, FormData, Headers, + MultiDict, MutableHeaders, QueryParams, ) @@ -224,3 +225,100 @@ def test_formdata(): assert FormData(form) == form assert FormData({"a": "123", "b": "789"}) == FormData([("a", "123"), ("b", "789")]) assert FormData({"a": "123", "b": "789"}) != {"a": "123", "b": "789"} + + +def test_multidict(): + q = MultiDict([("a", "123"), ("a", "456"), ("b", "789")]) + assert "a" in q + assert "A" not in q + assert "c" not in q + assert q["a"] == "456" + assert q.get("a") == "456" + assert q.get("nope", default=None) is None + assert q.getlist("a") == ["123", "456"] + assert list(q.keys()) == ["a", "b"] + assert list(q.values()) == ["456", "789"] + assert list(q.items()) == [("a", "456"), ("b", "789")] + assert len(q) == 2 + assert list(q) == ["a", "b"] + assert dict(q) == {"a": "456", "b": "789"} + assert str(q) == "MultiDict([('a', '123'), ('a', '456'), ('b', '789')])" + assert repr(q) == "MultiDict([('a', '123'), ('a', '456'), ('b', '789')])" + assert MultiDict({"a": "123", "b": "456"}) == MultiDict( + [("a", "123"), ("b", "456")] + ) + assert MultiDict({"a": "123", "b": "456"}) == MultiDict( + [("a", "123"), ("b", "456")] + ) + assert MultiDict({"a": "123", "b": "456"}) == MultiDict({"b": "456", "a": "123"}) + assert MultiDict() == MultiDict({}) + assert MultiDict({"a": "123", "b": "456"}) != "invalid" + + q = MultiDict([("a", "123"), ("a", "456")]) + assert MultiDict(q) == q + + q = MultiDict([("a", "123"), ("a", "456")]) + q["a"] = "789" + assert q["a"] == "789" + assert q.getlist("a") == ["789"] + + q = MultiDict([("a", "123"), ("a", "456")]) + del q["a"] + assert q.get("a") is None + assert repr(q) == "MultiDict([])" + + q = MultiDict([("a", "123"), ("a", "456"), ("b", "789")]) + assert q.pop("a") == "456" + assert q.get("a", None) is None + assert repr(q) == "MultiDict([('b', '789')])" + + q = MultiDict([("a", "123"), ("a", "456"), ("b", "789")]) + item = q.popitem() + assert q.get(item[0]) is None + + q = MultiDict([("a", "123"), ("a", "456"), ("b", "789")]) + assert q.poplist("a") == ["123", "456"] + assert q.get("a") is None + assert repr(q) == "MultiDict([('b', '789')])" + + q = MultiDict([("a", "123"), ("a", "456"), ("b", "789")]) + q.clear() + assert q.get("a") is None + assert repr(q) == "MultiDict([])" + + q = MultiDict([("a", "123")]) + q.setlist("a", ["456", "789"]) + assert q.getlist("a") == ["456", "789"] + q.setlist("b", []) + assert q.get("b") is None + + q = MultiDict([("a", "123")]) + assert q.setdefault("a", "456") == "123" + assert q.getlist("a") == ["123"] + assert q.setdefault("b", "456") == "456" + assert q.getlist("b") == ["456"] + assert repr(q) == "MultiDict([('a', '123'), ('b', '456')])" + + q = MultiDict([("a", "123")]) + q.appendlist("a", "456") + assert q.getlist("a") == ["123", "456"] + assert repr(q) == "MultiDict([('a', '123'), ('a', '456')])" + + q = MultiDict([("a", "123"), ("b", "456")]) + q.update({"a": "789"}) + assert q.getlist("a") == ["789"] + q == MultiDict([("a", "789"), ("b", "456")]) + + q = MultiDict([("a", "123"), ("b", "456")]) + q.update(q) + assert repr(q) == "MultiDict([('a', '123'), ('b', '456')])" + + q = MultiDict([("a", "123"), ("b", "456")]) + q.update(None) + assert repr(q) == "MultiDict([('a', '123'), ('b', '456')])" + + q = MultiDict([("a", "123"), ("a", "456")]) + q.update([("a", "123")]) + assert q.getlist("a") == ["123"] + q.update([("a", "456")], a="789", b="123") + assert q == MultiDict([("a", "456"), ("a", "789"), ("b", "123")])