From b786f68d60a3eec6292a10eafd58363c10130090 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 23 Jan 2019 11:19:21 +0000 Subject: [PATCH] Add ImmutableMultiDict (#343) --- starlette/datastructures.py | 173 +++++++++++++++++------------------ starlette/formparsers.py | 4 +- tests/test_datastructures.py | 36 +++----- 3 files changed, 100 insertions(+), 113 deletions(-) diff --git a/starlette/datastructures.py b/starlette/datastructures.py index c9c3fbb0..b4f94ace 100644 --- a/starlette/datastructures.py +++ b/starlette/datastructures.py @@ -222,36 +222,28 @@ class CommaSeparatedStrings(Sequence): return ", ".join([repr(item) for item in self]) -class QueryParams(typing.Mapping[str, str]): - """ - An immutable multidict. - """ - +class ImmutableMultiDict(typing.Mapping): def __init__( self, - params: typing.Union["QueryParams", typing.Mapping[str, str]] = None, - items: typing.List[typing.Tuple[str, str]] = None, - query_string: str = None, - scope: Scope = None, + value: typing.Union[ + "ImmutableMultiDict", + typing.Mapping, + typing.List[typing.Tuple[typing.Any, typing.Any]], + ] = None, ) -> None: - _items = [] # type: typing.List[typing.Tuple[str, str]] - if params is not None: - assert items is None, "Cannot set both 'params' and 'items'" - assert query_string is None, "Cannot set both 'params' and 'query_string'" - assert scope is None, "Cannot set both 'params' and 'scope'" - if isinstance(params, QueryParams): - _items = list(params.multi_items()) - else: - _items = list(params.items()) - elif items is not None: - assert query_string is None, "Cannot set both 'items' and 'query_string'" - assert scope is None, "Cannot set both 'items' and 'scope'" - _items = list(items) - elif query_string is not None: - assert scope is None, "Cannot set both 'query_string' and 'scope'" - _items = parse_qsl(query_string) - elif scope is not None: - _items = parse_qsl(scope["query_string"].decode("latin-1")) + if value is None: + _items = [] # type: typing.List[typing.Tuple[typing.Any, typing.Any]] + elif hasattr(value, "multi_items"): + value = typing.cast(ImmutableMultiDict, value) + _items = list(value.multi_items()) + elif hasattr(value, "items"): + value = typing.cast(typing.Mapping, value) + _items = list(value.items()) + else: + value = typing.cast( + typing.List[typing.Tuple[typing.Any, typing.Any]], value + ) + _items = list(value) self._dict = {k: v for k, v in _items} self._list = _items @@ -259,14 +251,14 @@ class QueryParams(typing.Mapping[str, str]): def getlist(self, key: typing.Any) -> typing.List[str]: return [item_value for item_key, item_value in self._list if item_key == key] - def keys(self) -> typing.List[str]: # type: ignore - return list(self._dict.keys()) + def keys(self) -> typing.KeysView: + return self._dict.keys() - def values(self) -> typing.List[str]: # type: ignore - return list(self._dict.values()) + def values(self) -> typing.ValuesView: + return self._dict.values() - def items(self) -> typing.List[typing.Tuple[str, str]]: # type: ignore - return list(self._dict.items()) + def items(self) -> typing.ItemsView: + return self._dict.items() def multi_items(self) -> typing.List[typing.Tuple[str, str]]: return list(self._list) @@ -289,15 +281,55 @@ class QueryParams(typing.Mapping[str, str]): return len(self._dict) def __eq__(self, other: typing.Any) -> bool: - if not isinstance(other, QueryParams): + if not isinstance(other, self.__class__): return False return sorted(self._list) == sorted(other._list) + def __repr__(self) -> str: + items = self.multi_items() + return f"{self.__class__.__name__}({repr(items)})" + + +class QueryParams(ImmutableMultiDict): + """ + An immutable multidict. + """ + + def __init__( + self, + value: typing.Union[ + "ImmutableMultiDict", + typing.Mapping, + typing.List[typing.Tuple[typing.Any, typing.Any]], + str, + ] = None, + scope: Scope = None, + **kwargs: typing.Any, + ) -> None: + if kwargs: # pragma: no cover + # Backwards compatability. We now just use a single argument to + # cover all cases, except for the initialize-by-ASGI-scope case. + # + # This compat case should be removed in 0.10.x + value = kwargs.pop("params", value) + value = kwargs.pop("items", value) + value = kwargs.pop("query_string", value) + assert not kwargs, "Unknown parameter" + + if scope is not None: + assert value is None, "Cannot set both 'value' and 'scope'" + value = scope["query_string"].decode("latin-1") + + if isinstance(value, str): + super().__init__(parse_qsl(value)) + else: + super().__init__(value) + def __str__(self) -> str: return urlencode(self._list) def __repr__(self) -> str: - return f"{self.__class__.__name__}(query_string={repr(str(self))})" + return f"{self.__class__.__name__}({repr(str(self))})" class UploadFile: @@ -323,69 +355,30 @@ class UploadFile: FormValue = typing.Union[str, UploadFile] -class FormData(typing.Mapping[str, FormValue]): +class FormData(ImmutableMultiDict): """ An immutable multidict, containing both file uploads and text input. """ def __init__( self, - form: typing.Union["FormData", typing.Mapping[str, FormValue]] = None, - items: typing.List[typing.Tuple[str, FormValue]] = None, + value: typing.Union[ + "FormData", + typing.Mapping[str, FormValue], + typing.List[typing.Tuple[str, FormValue]], + ] = None, + **kwargs: typing.Any, ) -> None: - _items = [] # type: typing.List[typing.Tuple[str, FormValue]] - if form is not None: - assert items is None, "Cannot set both 'form' and 'items'" - if isinstance(form, FormData): - _items = list(form.multi_items()) - else: - _items = list(form.items()) - elif items is not None: - _items = list(items) + if kwargs: # pragma: no cover + # Backwards compatability. We now just use a single argument to + # cover all cases. + # + # This compat case should be removed in 0.10.x + value = kwargs.pop("form", value) + value = kwargs.pop("items", value) + assert not kwargs, "Unknown parameter" - self._dict = {k: v for k, v in _items} - self._list = _items - - def getlist(self, key: typing.Any) -> typing.List[FormValue]: - return [item_value for item_key, item_value in self._list if item_key == key] - - def keys(self) -> typing.List[str]: # type: ignore - return list(self._dict.keys()) - - def values(self) -> typing.List[FormValue]: # type: ignore - return list(self._dict.values()) - - def items(self) -> typing.List[typing.Tuple[str, FormValue]]: # type: ignore - return list(self._dict.items()) - - def multi_items(self) -> typing.List[typing.Tuple[str, FormValue]]: - return list(self._list) - - def get(self, key: typing.Any, default: typing.Any = None) -> typing.Any: - if key in self._dict: - return self._dict[key] - return default - - def __getitem__(self, key: typing.Any) -> FormValue: - return self._dict[key] - - def __contains__(self, key: typing.Any) -> bool: - return key in self._dict - - def __iter__(self) -> typing.Iterator[typing.Any]: - return iter(self.keys()) - - def __len__(self) -> int: - return len(self._dict) - - def __eq__(self, other: typing.Any) -> bool: - if not isinstance(other, FormData): - return False - return sorted(self._list) == sorted(other._list) - - def __repr__(self) -> str: - items = self.multi_items() - return f"{self.__class__.__name__}(items={repr(items)})" + super().__init__(value) class Headers(typing.Mapping[str, str]): diff --git a/starlette/formparsers.py b/starlette/formparsers.py index 223119ff..74ec7075 100644 --- a/starlette/formparsers.py +++ b/starlette/formparsers.py @@ -102,7 +102,7 @@ class FormParser: elif message_type == FormMessage.END: pass - return FormData(items=items) + return FormData(items) class MultiPartParser: @@ -218,4 +218,4 @@ class MultiPartParser: pass parser.finalize() - return FormData(items=items) + return FormData(items) diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py index c811bac9..bd33909d 100644 --- a/tests/test_datastructures.py +++ b/tests/test_datastructures.py @@ -170,7 +170,7 @@ def test_headers_mutablecopy(): def test_queryparams(): - q = QueryParams(query_string="a=123&a=456&b=789") + q = QueryParams("a=123&a=456&b=789") assert "a" in q assert "A" not in q assert "c" not in q @@ -178,36 +178,32 @@ def test_queryparams(): assert q.get("a") == "456" assert q.get("nope", default=None) is None assert q.getlist("a") == ["123", "456"] - assert q.keys() == ["a", "b"] - assert q.values() == ["456", "789"] - assert q.items() == [("a", "456"), ("b", "789")] + assert list(q.keys()) == ["a", "b"] + assert list(q.values()) == ["456", "789"] + assert list(q.items()) == [("a", "456"), ("b", "789")] assert len(q) == 2 assert list(q) == ["a", "b"] assert dict(q) == {"a": "456", "b": "789"} assert str(q) == "a=123&a=456&b=789" - assert repr(q) == "QueryParams(query_string='a=123&a=456&b=789')" + assert repr(q) == "QueryParams('a=123&a=456&b=789')" assert QueryParams({"a": "123", "b": "456"}) == QueryParams( - items=[("a", "123"), ("b", "456")] - ) - assert QueryParams({"a": "123", "b": "456"}) == QueryParams( - query_string="a=123&b=456" + [("a", "123"), ("b", "456")] ) + assert QueryParams({"a": "123", "b": "456"}) == QueryParams("a=123&b=456") assert QueryParams({"a": "123", "b": "456"}) == QueryParams( {"b": "456", "a": "123"} ) assert QueryParams() == QueryParams({}) - assert QueryParams(items=[("a", "123"), ("a", "456")]) == QueryParams( - query_string="a=123&a=456" - ) + assert QueryParams([("a", "123"), ("a", "456")]) == QueryParams("a=123&a=456") assert QueryParams({"a": "123", "b": "456"}) != "invalid" - q = QueryParams(items=[("a", "123"), ("a", "456")]) + q = QueryParams([("a", "123"), ("a", "456")]) assert QueryParams(q) == q def test_formdata(): upload = io.BytesIO(b"test") - form = FormData(items=[("a", "123"), ("a", "456"), ("b", upload)]) + form = FormData([("a", "123"), ("a", "456"), ("b", upload)]) assert "a" in form assert "A" not in form assert "c" not in form @@ -215,18 +211,16 @@ def test_formdata(): assert form.get("a") == "456" assert form.get("nope", default=None) is None assert form.getlist("a") == ["123", "456"] - assert form.keys() == ["a", "b"] - assert form.values() == ["456", upload] - assert form.items() == [("a", "456"), ("b", upload)] + assert list(form.keys()) == ["a", "b"] + assert list(form.values()) == ["456", upload] + assert list(form.items()) == [("a", "456"), ("b", upload)] assert len(form) == 2 assert list(form) == ["a", "b"] assert dict(form) == {"a": "456", "b": upload} assert ( repr(form) - == "FormData(items=[('a', '123'), ('a', '456'), ('b', " + repr(upload) + ")])" + == "FormData([('a', '123'), ('a', '456'), ('b', " + repr(upload) + ")])" ) assert FormData(form) == form - assert FormData({"a": "123", "b": "789"}) == FormData( - items=[("a", "123"), ("b", "789")] - ) + assert FormData({"a": "123", "b": "789"}) == FormData([("a", "123"), ("b", "789")]) assert FormData({"a": "123", "b": "789"}) != {"a": "123", "b": "789"}