Refactored keypath dict and util.

This commit is contained in:
Fabio Caccamo 2019-10-04 15:55:10 +02:00
parent 4d9a9b6602
commit a933ee3bdd
2 changed files with 31 additions and 30 deletions

View File

@ -15,35 +15,19 @@ class KeypathDict(dict):
keys = keypath_util.all_keys(d)
keypath_util.check_keys(keys, self._keypath_separator)
def _follow_keys(self, keys):
return keypath_util.follow_keys(self, keys)
def _join_keys(self, keys):
return keypath_util.join_keys(keys, self._keypath_separator)
def _split_keys(self, key):
return keypath_util.split_keys(key, self._keypath_separator)
def _walk_keys(self, keys):
item_keys = keys[:-1]
item_key = keys[-1]
item_parent = self
i = 0
j = len(item_keys)
while i < j:
key = item_keys[i]
try:
if item_parent is self:
item_parent = super(KeypathDict, self).__getitem__(key)
else:
item_parent = item_parent.__getitem__(key)
except KeyError:
item_parent = None
break
i += 1
return (item_parent, item_key, )
def __contains__(self, key):
keys = self._split_keys(key)
if len(keys) > 1:
item_parent, item_key = self._walk_keys(keys)
item_parent, item_key = self._follow_keys(keys)
if isinstance(item_parent, dict):
if item_parent.__contains__(item_key):
return True
@ -57,7 +41,7 @@ class KeypathDict(dict):
def __delitem__(self, key):
keys = self._split_keys(key)
if len(keys) > 1:
item_parent, item_key = self._walk_keys(keys)
item_parent, item_key = self._follow_keys(keys)
if isinstance(item_parent, dict):
item_parent.__delitem__(item_key)
else:
@ -69,7 +53,7 @@ class KeypathDict(dict):
keys = self._split_keys(key)
value = None
if len(keys) > 1:
item_parent, item_key = self._walk_keys(keys)
item_parent, item_key = self._follow_keys(keys)
if isinstance(item_parent, dict):
return item_parent.__getitem__(item_key)
else:
@ -112,7 +96,7 @@ class KeypathDict(dict):
def get(self, key, default=None):
keys = self._split_keys(key)
if len(keys) > 1:
item_parent, item_key = self._walk_keys(keys)
item_parent, item_key = self._follow_keys(keys)
if isinstance(item_parent, dict):
return item_parent.get(item_key, default)
else:
@ -125,11 +109,11 @@ class KeypathDict(dict):
return []
def walk_keypaths(root, path):
keypaths = []
for key, val in root.items():
for key, value in root.items():
keys = path + [key]
keypaths += [self._join_keys(keys)]
if isinstance(val, dict):
keypaths += walk_keypaths(val, keys)
if isinstance(value, dict):
keypaths += walk_keypaths(value, keys)
return keypaths
keypaths = walk_keypaths(self, [])
keypaths.sort()
@ -147,7 +131,7 @@ class KeypathDict(dict):
default = None
keys = self._split_keys(key)
if len(keys) > 1:
item_parent, item_key = self._walk_keys(keys)
item_parent, item_key = self._follow_keys(keys)
if isinstance(item_parent, dict):
if default_arg:
return item_parent.pop(item_key, default)

View File

@ -5,11 +5,11 @@ from six import string_types
def all_keys(d):
keys = []
for key, val in d.items():
for key, value in d.items():
if key not in keys:
keys.append(key)
if isinstance(val, dict):
keys += all_keys(val)
if isinstance(value, dict):
keys += all_keys(value)
return keys
@ -23,6 +23,23 @@ def check_keys(keys, separator):
'\'{}\'.'.format(separator))
def follow_keys(d, keys):
item_keys = keys[:-1]
item_key = keys[-1]
item_parent = d
i = 0
j = len(item_keys)
while i < j:
key = item_keys[i]
try:
item_parent = item_parent[key]
except KeyError:
item_parent = None
break
i += 1
return (item_parent, item_key, )
def join_keys(keys, separator):
return separator.join(keys)