mitmproxy/netlib/multidict.py

249 lines
6.8 KiB
Python
Raw Normal View History

from __future__ import absolute_import, print_function, division
from abc import ABCMeta, abstractmethod
from typing import Tuple, TypeVar
try:
from collections.abc import MutableMapping
except ImportError: # pragma: no cover
from collections import MutableMapping # Workaround for Python < 3.3
import six
from .utils import Serializable
@six.add_metaclass(ABCMeta)
class MultiDict(MutableMapping, Serializable):
    """
    An ordered mapping that may hold several values for the same key.

    All data lives in ``self.fields``, a tuple of ``(key, value)`` pairs kept
    in insertion order. Keys are compared after passing them through
    ``_kconv``, and single-value reads are collapsed with ``_reduce_values``;
    both hooks are abstract and must be provided by subclasses.
    """

    def __init__(self, fields=None):
        # it is important for us that .fields is immutable, so that we can easily
        # detect changes to it.
        self.fields = tuple(fields) if fields else tuple()  # type: Tuple[Tuple[bytes, bytes], ...]

    def __repr__(self):
        return "{cls}[{fields}]".format(
            cls=type(self).__name__,
            fields=", ".join(repr(field) for field in self.fields)
        )

    @staticmethod
    @abstractmethod
    def _reduce_values(values):
        """
        If a user accesses multidict["foo"], this method
        reduces all values for "foo" to a single value that is returned.
        For example, HTTP headers are folded, whereas we will just take
        the first cookie we found with that name.
        """

    @staticmethod
    @abstractmethod
    def _kconv(key):
        """
        This method converts a key to its canonical representation.
        For example, HTTP headers are case-insensitive, so this method returns key.lower().
        """

    def __getitem__(self, key):
        """
        Return the single value for ``key``, as produced by ``_reduce_values``
        from all of its values.

        Raises:
            KeyError: if ``key`` has no values.
        """
        values = self.get_all(key)
        if not values:
            raise KeyError(key)
        return self._reduce_values(values)

    def __setitem__(self, key, value):
        """Replace all values for ``key`` with the single ``value``."""
        self.set_all(key, [value])

    def __delitem__(self, key):
        """
        Remove every field whose key matches ``key`` (after ``_kconv``).

        Raises:
            KeyError: if ``key`` is not present.
        """
        if key not in self:
            raise KeyError(key)
        key = self._kconv(key)
        self.fields = tuple(
            field for field in self.fields
            if key != self._kconv(field[0])
        )

    def __iter__(self):
        # Yield each distinct key once (compared via _kconv), using the
        # spelling of its first occurrence in self.fields.
        seen = set()
        for key, _ in self.fields:
            key_kconv = self._kconv(key)
            if key_kconv not in seen:
                seen.add(key_kconv)
                yield key

    def __len__(self):
        """Return the number of distinct keys (compared via ``_kconv``)."""
        return len({self._kconv(key) for key, _ in self.fields})

    def __eq__(self, other):
        # Compares the raw fields tuples: field order and exact key spelling
        # matter, and instances of different MultiDict subclasses may be equal.
        if isinstance(other, MultiDict):
            return self.fields == other.fields
        return False

    def __ne__(self, other):
        return not self.__eq__(other)

    def get_all(self, key):
        """
        Return the list of all values for a given key.
        If that key is not in the MultiDict, the return value will be an empty list.
        """
        key = self._kconv(key)
        return [
            value
            for k, value in self.fields
            if self._kconv(k) == key
        ]

    def set_all(self, key, values):
        """
        Remove the old values for a key and add new ones.

        Existing fields whose key matches ``key`` (after ``_kconv``) are
        overwritten in order, keeping their positions. Surplus old fields are
        dropped, and surplus new values are appended at the end. Every written
        field uses ``key`` exactly as passed in.

        Args:
            key: The key to set.
            values: An iterable of new values. It is only read, never
                modified. A falsy value (e.g. ``[]`` or ``None``) removes all
                values for ``key``.
        """
        key_kconv = self._kconv(key)
        # Walk a private iterator instead of calling values.pop(0): popping
        # emptied the caller's list as a side effect, failed for tuples and
        # generators, and cost O(n) per value. ``values or ()`` preserves the
        # previous handling of falsy inputs such as None.
        new_values = iter(values or ())
        exhausted = object()  # sentinel: stored values may be any object
        new_fields = []
        for field in self.fields:
            if self._kconv(field[0]) != key_kconv:
                new_fields.append(field)
                continue
            value = next(new_values, exhausted)
            if value is not exhausted:
                new_fields.append((key, value))
            # Otherwise there are more old values than new ones: drop this one.
        # More new values than old ones: append the remainder at the end.
        new_fields.extend((key, value) for value in new_values)
        self.fields = tuple(new_fields)

    def add(self, key, value):
        """
        Add an additional value for the given key at the bottom.
        """
        self.insert(len(self.fields), key, value)

    def insert(self, index, key, value):
        """
        Insert an additional value for the given key at the specified position.

        ``index`` follows Python slice semantics on ``self.fields``.
        """
        item = (key, value)
        self.fields = self.fields[:index] + (item,) + self.fields[index:]

    def keys(self, multi=False):
        """
        Get all keys.

        Args:
            multi(bool):
                If True, one key per value will be returned (every key in
                ``self.fields``, in order, duplicates included).
                If False, duplicate keys will only be returned once. Keys are
                compared after ``_kconv``, and the first spelling seen is used.

        Returns:
            A generator of keys.
        """
        return (
            k
            for k, _ in self.items(multi)
        )

    def values(self, multi=False):
        """
        Get all values.

        Args:
            multi(bool):
                If True, all values will be returned.
                If False, one value per unique key will be returned, namely
                ``self[key]``. That is ``_reduce_values`` applied to all of the
                key's values, which is not necessarily the first value.

        Returns:
            A generator of values.
        """
        return (
            v
            for _, v in self.items(multi)
        )

    def items(self, multi=False):
        """
        Get all (key, value) tuples.

        Args:
            multi(bool):
                If True, all (key, value) pairs will be returned.
                If False, one ``(key, self[key])`` pair per unique key will be
                returned, where ``self[key]`` is the reduced value.

        Returns:
            For multi=True, the ``self.fields`` tuple itself.
            Otherwise, a mapping items view.
        """
        if multi:
            return self.fields
        else:
            return super(MultiDict, self).items()

    def to_dict(self):
        """
        Get the MultiDict as a plain Python dict.
        Keys with multiple values are returned as lists.

        Each key appears once, spelled as its first occurrence.

        Example:

        .. code-block:: python

            # Simple dict with duplicate values.
            >>> d
            MultiDictView[("name", "value"), ("a", "false"), ("a", "42")]
            >>> d.to_dict()
            {
                "name": "value",
                "a": ["false", "42"]
            }
        """
        d = {}
        for key in self:
            values = self.get_all(key)
            if len(values) == 1:
                d[key] = values[0]
            else:
                d[key] = values
        return d

    def get_state(self):
        """Return the serializable state: the ``self.fields`` tuple."""
        return self.fields

    def set_state(self, state):
        """Restore from ``state``, converting each entry to a tuple."""
        self.fields = tuple(tuple(x) for x in state)

    @classmethod
    def from_state(cls, state):
        """Build a new instance from ``state``, converting each entry to a tuple."""
        return cls(tuple(x) for x in state)
@six.add_metaclass(ABCMeta)
class ImmutableMultiDict(MultiDict):
    """
    A MultiDict that refuses in-place modification.

    ``__delitem__``, ``set_all`` and ``insert`` raise TypeError. The
    ``with_*`` methods apply the same change to ``self.copy()`` and return
    that copy instead.
    """

    def _immutable(self, *_):
        # Any positional arguments are accepted and ignored, so this single
        # function can replace mutators with different signatures.
        raise TypeError('{} objects are immutable'.format(self.__class__.__name__))

    __delitem__ = _immutable
    set_all = _immutable
    insert = _immutable

    def _mutated_copy(self, method_name, *args):
        # Copy self, then invoke the next implementation of ``method_name``
        # after ImmutableMultiDict in the MRO on that copy, bypassing the
        # blocking overrides above.
        clone = self.copy()
        getattr(super(ImmutableMultiDict, clone), method_name)(*args)
        return clone

    def with_delitem(self, key):
        """
        Remove all values for ``key`` on a copy.

        A KeyError from the parent ``__delitem__`` propagates if ``key`` is
        absent.

        Returns:
            An updated ImmutableMultiDict. The original object will not be modified.
        """
        return self._mutated_copy("__delitem__", key)

    def with_set_all(self, key, values):
        """
        Replace the values for ``key`` with ``values`` on a copy.

        Returns:
            An updated ImmutableMultiDict. The original object will not be modified.
        """
        return self._mutated_copy("set_all", key, values)

    def with_insert(self, index, key, value):
        """
        Insert ``(key, value)`` at ``index`` on a copy.

        Returns:
            An updated ImmutableMultiDict. The original object will not be modified.
        """
        return self._mutated_copy("insert", index, key, value)