From b1ae0c3621034f1531b9983389ce90be8d140bc6 Mon Sep 17 00:00:00 2001 From: manlix Date: Wed, 16 Feb 2022 13:58:56 +0300 Subject: [PATCH] Add union operators to MutableHeaders (#1240) * Add union operators to MutableHeaders (#1239) * Apply suggestions from code review * Use `TypeError`, not `NotImplemented`. * Add `# type: ignore` to deliberate incorrect usage of types in tests. * Apply suggestions from code review Co-authored-by: Marcelo Trylesinski Co-authored-by: Tom Christie --- starlette/datastructures.py | 15 +++++++++++- tests/test_datastructures.py | 44 ++++++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 1 deletion(-) diff --git a/starlette/datastructures.py b/starlette/datastructures.py index 1a8b965e..59863282 100644 --- a/starlette/datastructures.py +++ b/starlette/datastructures.py @@ -618,6 +618,19 @@ class MutableHeaders(Headers): for idx in reversed(pop_indexes): del self._list[idx] + def __ior__(self, other: typing.Mapping) -> "MutableHeaders": + if not isinstance(other, typing.Mapping): + raise TypeError(f"Expected a mapping but got {other.__class__.__name__}") + self.update(other) + return self + + def __or__(self, other: typing.Mapping) -> "MutableHeaders": + if not isinstance(other, typing.Mapping): + raise TypeError(f"Expected a mapping but got {other.__class__.__name__}") + new = self.mutablecopy() + new.update(other) + return new + @property def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]: return self._list @@ -636,7 +649,7 @@ class MutableHeaders(Headers): self._list.append((set_key, set_value)) return value - def update(self, other: dict) -> None: + def update(self, other: typing.Mapping) -> None: for key, val in other.items(): self[key] = val diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py index b110aa8b..22e377c9 100644 --- a/tests/test_datastructures.py +++ b/tests/test_datastructures.py @@ -162,6 +162,50 @@ def test_mutable_headers(): assert h.raw == [(b"b", b"4")] +def test_mutable_headers_merge(): + h = MutableHeaders() + h = h | MutableHeaders({"a": "1"}) + assert isinstance(h, MutableHeaders) + assert dict(h) == {"a": "1"} + assert h.items() == [("a", "1")] + assert h.raw == [(b"a", b"1")] + + +def test_mutable_headers_merge_dict(): + h = MutableHeaders() + h = h | {"a": "1"} + assert isinstance(h, MutableHeaders) + assert dict(h) == {"a": "1"} + assert h.items() == [("a", "1")] + assert h.raw == [(b"a", b"1")] + + +def test_mutable_headers_update(): + h = MutableHeaders() + h |= MutableHeaders({"a": "1"}) + assert isinstance(h, MutableHeaders) + assert dict(h) == {"a": "1"} + assert h.items() == [("a", "1")] + assert h.raw == [(b"a", b"1")] + + +def test_mutable_headers_update_dict(): + h = MutableHeaders() + h |= {"a": "1"} + assert isinstance(h, MutableHeaders) + assert dict(h) == {"a": "1"} + assert h.items() == [("a", "1")] + assert h.raw == [(b"a", b"1")] + + +def test_mutable_headers_merge_not_mapping(): + h = MutableHeaders() + with pytest.raises(TypeError): + h |= {"not_mapping"} # type: ignore + with pytest.raises(TypeError): + h | {"not_mapping"} # type: ignore + + def test_headers_mutablecopy(): h = Headers(raw=[(b"a", b"123"), (b"a", b"456"), (b"b", b"789")]) c = h.mutablecopy()