python-benedict/benedict/dicts/keypath.py

157 lines
4.9 KiB
Python
Raw Normal View History

# -*- coding: utf-8 -*-
from benedict.utils import keypath_util
class KeypathDict(dict):
    """
    A dict whose keys may be keypaths addressing values in nested dicts.

    Every key passed to the public API is split into a list of keys by
    ``keypath_util.split_keys`` using the configured separator (``'.'``
    unless a ``separator`` keyword argument is given). A key that yields
    a single part behaves exactly like a plain ``dict`` key; a key that
    yields several parts is resolved by walking the nested containers.
    """

    # Marks "no default supplied" in pop(), so that an explicit
    # ``default=None`` can be told apart from an omitted default.
    _MISSING = object()

    def __init__(self, *args, **kwargs):
        """
        Build the dict; the optional ``separator`` keyword (default ``'.'``)
        is consumed here and not stored as an item.
        """
        self._separator = kwargs.pop('separator', '.')
        super(KeypathDict, self).__init__(*args, **kwargs)

    def _get_value_by_keys(self, keys):
        """
        Walk ``keys`` (already split) from the root and return the value
        found, or ``None`` if any step cannot be resolved.

        An empty ``keys`` returns the dict itself.
        """
        val = self
        for key in keys:
            try:
                if val is self:
                    # The key is already split: use dict's own lookup on
                    # the root instead of re-entering __getitem__.
                    val = super(KeypathDict, self).__getitem__(key)
                else:
                    val = val[key]
            except (KeyError, IndexError, TypeError):
                # Missing key, or an intermediate value (int, str, None...)
                # that cannot be indexed with this key: path does not exist.
                return None
        return val

    def _get_value_context_by_keys(self, keys):
        """
        Return ``(parent, last_key)`` for a non-empty list of ``keys``.

        ``parent`` is whatever ``_get_value_by_keys`` finds for all keys
        but the last one (``None`` when that path does not resolve).
        """
        item_parent = self._get_value_by_keys(keys[:-1])
        return (item_parent, keys[-1], )

    def _has_value_by_keys(self, keys):
        """
        Return ``True`` if the parent addressed by ``keys`` is a dict that
        contains the last key, ``False`` otherwise.
        """
        item_parent, item_key = self._get_value_context_by_keys(keys)
        return isinstance(item_parent, dict) and item_key in item_parent

    def _set_value_by_keys(self, keys, value):
        """
        Store ``value`` at the nested location described by ``keys``.

        Every intermediate key whose current value is not a dict (missing
        included) is replaced by a new empty dict. An empty ``keys`` list
        is a no-op.
        """
        if not keys:
            return
        # Start with dict's own accessors for the root: the keys are
        # already split, so the keypath-aware overrides are not needed.
        root = super(KeypathDict, self)
        getter, setter = root.get, root.__setitem__
        for key in keys[:-1]:
            child = getter(key, None)
            if not isinstance(child, dict):
                child = {}
                setter(key, child)
            getter, setter = child.get, child.__setitem__
        setter(keys[-1], value)

    def __contains__(self, key):
        """Return whether ``key`` (plain key or keypath) is present."""
        keys = keypath_util.split_keys(key, self._separator)
        if len(keys) > 1:
            return self._has_value_by_keys(keys)
        return super(KeypathDict, self).__contains__(key)

    def __delitem__(self, key):
        """
        Delete the item addressed by ``key`` (plain key or keypath).

        Raises ``KeyError`` if the item does not exist.
        """
        keys = keypath_util.split_keys(key, self._separator)
        if len(keys) > 1:
            item_parent, item_key = self._get_value_context_by_keys(keys)
            if not isinstance(item_parent, dict):
                raise KeyError(key)
            del item_parent[item_key]
        else:
            super(KeypathDict, self).__delitem__(key)

    def __getitem__(self, key):
        """
        Return the value addressed by ``key`` (plain key or keypath).

        Raises ``KeyError`` if the item does not exist.
        """
        keys = keypath_util.split_keys(key, self._separator)
        if len(keys) > 1:
            item_parent, item_key = self._get_value_context_by_keys(keys)
            if not isinstance(item_parent, dict):
                raise KeyError(key)
            return item_parent[item_key]
        return super(KeypathDict, self).__getitem__(key)

    def __setitem__(self, key, value):
        """
        Set ``value`` at ``key``; for a keypath, intermediate dicts are
        created (or non-dict intermediates replaced) as needed.
        """
        keys = keypath_util.split_keys(key, self._separator)
        if len(keys) > 1:
            self._set_value_by_keys(keys, value)
        else:
            super(KeypathDict, self).__setitem__(key, value)

    @classmethod
    def fromkeys(cls, sequence, value=None):
        """
        Return a new instance of ``cls`` with every key of ``sequence``
        set to ``value`` (keys go through keypath-aware assignment).
        """
        # Instantiate cls, not KeypathDict, so subclasses get their own type.
        d = cls()
        for key in sequence:
            d[key] = value
        return d

    def get(self, key, default=None):
        """
        Return the value addressed by ``key`` (plain key or keypath), or
        ``default`` if it does not exist.
        """
        keys = keypath_util.split_keys(key, self._separator)
        if len(keys) > 1:
            item_parent, item_key = self._get_value_context_by_keys(keys)
            if isinstance(item_parent, dict):
                return item_parent.get(item_key, default)
            return default
        return super(KeypathDict, self).get(key, default)

    def keypaths(self):
        """
        Return a sorted list with the keypath of every item, at every
        nesting level, joined by ``keypath_util.join_keys`` with the
        configured separator.
        """
        separator = self._separator

        def walk_keypaths(root, path):
            found = []
            for key, val in root.items():
                key_path = path + [key]
                found.append(keypath_util.join_keys(key_path, separator))
                if isinstance(val, dict):
                    found.extend(walk_keypaths(val, key_path))
            return found

        return sorted(walk_keypaths(self, []))

    def pop(self, key, default=_MISSING):
        """
        Remove the item addressed by ``key`` (plain key or keypath) and
        return its value.

        If the item does not exist, ``default`` is returned when one was
        supplied (``None`` included); otherwise ``KeyError`` is raised.
        """
        # Forward the default only when the caller actually gave one.
        if default is KeypathDict._MISSING:
            fallback = ()
        else:
            fallback = (default, )
        keys = keypath_util.split_keys(key, self._separator)
        if len(keys) > 1:
            item_parent, item_key = self._get_value_context_by_keys(keys)
            if isinstance(item_parent, dict):
                return item_parent.pop(item_key, *fallback)
            if fallback:
                return default
            raise KeyError(key)
        return super(KeypathDict, self).pop(key, *fallback)

    def set(self, key, value):
        """Alias for ``self[key] = value``."""
        self[key] = value

    def setdefault(self, key, default=None):
        """
        Return the value at ``key`` if present; otherwise store ``default``
        there and return it.
        """
        if key not in self:
            self[key] = default
            return default
        return self[key]