from __future__ import absolute_import, print_function, division from abc import ABCMeta, abstractmethod from typing import Tuple, TypeVar try: from collections.abc import MutableMapping except ImportError: # pragma: no cover from collections import MutableMapping # Workaround for Python < 3.3 import six from .utils import Serializable @six.add_metaclass(ABCMeta) class MultiDict(MutableMapping, Serializable): def __init__(self, fields=None): # it is important for us that .fields is immutable, so that we can easily # detect changes to it. self.fields = tuple(fields) if fields else tuple() # type: Tuple[Tuple[bytes, bytes], ...] def __repr__(self): fields = tuple( repr(field) for field in self.fields ) return "{cls}[{fields}]".format( cls=type(self).__name__, fields=", ".join(fields) ) @staticmethod @abstractmethod def _reduce_values(values): """ If a user accesses multidict["foo"], this method reduces all values for "foo" to a single value that is returned. For example, HTTP headers are folded, whereas we will just take the first cookie we found with that name. """ @staticmethod @abstractmethod def _kconv(key): """ This method converts a key to its canonical representation. For example, HTTP headers are case-insensitive, so this method returns key.lower(). """ def __getitem__(self, key): values = self.get_all(key) if not values: raise KeyError(key) return self._reduce_values(values) def __setitem__(self, key, value): self.set_all(key, [value]) def __delitem__(self, key): if key not in self: raise KeyError(key) key = self._kconv(key) self.fields = tuple( field for field in self.fields if key != self._kconv(field[0]) ) def __iter__(self): seen = set() for key, _ in self.fields: key_kconv = self._kconv(key) if key_kconv not in seen: seen.add(key_kconv) yield key def __len__(self): return len(set(self._kconv(key) for key, _ in self.fields)) def __eq__(self, other): if isinstance(other, MultiDict): return self.fields == other.fields return False def __ne__(self, other): return not self.__eq__(other) def get_all(self, key): """ Return the list of all values for a given key. If that key is not in the MultiDict, the return value will be an empty list. """ key = self._kconv(key) return [ value for k, value in self.fields if self._kconv(k) == key ] def set_all(self, key, values): """ Remove the old values for a key and add new ones. """ key_kconv = self._kconv(key) new_fields = [] for field in self.fields: if self._kconv(field[0]) == key_kconv: if values: new_fields.append( (key, values.pop(0)) ) else: new_fields.append(field) while values: new_fields.append( (key, values.pop(0)) ) self.fields = tuple(new_fields) def add(self, key, value): """ Add an additional value for the given key at the bottom. """ self.insert(len(self.fields), key, value) def insert(self, index, key, value): """ Insert an additional value for the given key at the specified position. """ item = (key, value) self.fields = self.fields[:index] + (item,) + self.fields[index:] def keys(self, multi=False): """ Get all keys. Args: multi(bool): If True, one key per value will be returned. If False, duplicate keys will only be returned once. """ return ( k for k, _ in self.items(multi) ) def values(self, multi=False): """ Get all values. Args: multi(bool): If True, all values will be returned. If False, only the first value per key will be returned. """ return ( v for _, v in self.items(multi) ) def items(self, multi=False): """ Get all (key, value) tuples. Args: multi(bool): If True, all (key, value) pairs will be returned If False, only the first (key, value) pair per unique key will be returned. """ if multi: return self.fields else: return super(MultiDict, self).items() def to_dict(self): """ Get the MultiDict as a plain Python dict. Keys with multiple values are returned as lists. Example: .. code-block:: python # Simple dict with duplicate values. >>> d MultiDictView[("name", "value"), ("a", "false"), ("a", "42")] >>> d.to_dict() { "name": "value", "a": ["false", "42"] } """ d = {} for key in self: values = self.get_all(key) if len(values) == 1: d[key] = values[0] else: d[key] = values return d def get_state(self): return self.fields def set_state(self, state): self.fields = tuple(tuple(x) for x in state) @classmethod def from_state(cls, state): return cls(tuple(x) for x in state) @six.add_metaclass(ABCMeta) class ImmutableMultiDict(MultiDict): def _immutable(self, *_): raise TypeError('{} objects are immutable'.format(self.__class__.__name__)) __delitem__ = set_all = insert = _immutable def with_delitem(self, key): """ Returns: An updated ImmutableMultiDict. The original object will not be modified. """ ret = self.copy() super(ImmutableMultiDict, ret).__delitem__(key) return ret def with_set_all(self, key, values): """ Returns: An updated ImmutableMultiDict. The original object will not be modified. """ ret = self.copy() super(ImmutableMultiDict, ret).set_all(key, values) return ret def with_insert(self, index, key, value): """ Returns: An updated ImmutableMultiDict. The original object will not be modified. """ ret = self.copy() super(ImmutableMultiDict, ret).insert(index, key, value) return ret