From 76a0e238ae3337cab22ff4523e2cbf257f649565 Mon Sep 17 00:00:00 2001 From: Fabio Caccamo Date: Mon, 10 Jun 2019 14:40:05 +0200 Subject: [PATCH] Added casting to benedict when retrieving dict values from a benedict instance. --- benedict/dicts/__init__.py | 29 ++++++++++++-- benedict/dicts/keypath.py | 9 ++--- tests/test_dicts_benedict.py | 65 ++++++++++++++++++++++++++++++++ tests/test_dicts_keypath_dict.py | 19 +--------- 4 files changed, 96 insertions(+), 26 deletions(-) diff --git a/benedict/dicts/__init__.py b/benedict/dicts/__init__.py index 9d732e6..c18d6b7 100644 --- a/benedict/dicts/__init__.py +++ b/benedict/dicts/__init__.py @@ -11,16 +11,39 @@ class benedict(KeypathDict, ParseDict): def __init__(self, *args, **kwargs): super(benedict, self).__init__(*args, **kwargs) + @staticmethod + def cast(val): + if isinstance(val, dict) and not isinstance(val, benedict): + return benedict(val) + else: + return val + def copy(self): - return benedict( + return benedict.cast( super(benedict, self).copy()) def deepcopy(self): - return benedict( + return benedict.cast( deepcopy(self)) @classmethod def fromkeys(cls, sequence, value=None): - return benedict( + return benedict.cast( KeypathDict.fromkeys(sequence, value)) + def __getitem__(self, key): + return benedict.cast( + super(benedict, self).__getitem__(key)) + + def get(self, key, default=None): + return benedict.cast( + super(benedict, self).get(key, default)) + + def pop(self, key, default=None): + return benedict.cast( + super(benedict, self).pop(key, default)) + + def setdefault(self, key, default=None): + return benedict.cast( + super(benedict, self).setdefault(key, default)) + diff --git a/benedict/dicts/keypath.py b/benedict/dicts/keypath.py index 4349c86..c39535d 100644 --- a/benedict/dicts/keypath.py +++ b/benedict/dicts/keypath.py @@ -67,7 +67,10 @@ class KeypathDict(dict): while i < j: key = keys[i] if i < (j - 1): - subitem = item.get(key, None) + if item == self: + subitem = super(KeypathDict, self).get(key, None) + else: + subitem = item.get(key, None) if not isinstance(subitem, dict): subitem = item[key] = {} item = subitem @@ -113,10 +116,6 @@ class KeypathDict(dict): else: super(KeypathDict, self).__setitem__(key, value) - def copy(self): - return KeypathDict( - super(KeypathDict, self).copy()) - @classmethod def fromkeys(cls, sequence, value=None): d = KeypathDict() diff --git a/tests/test_dicts_benedict.py b/tests/test_dicts_benedict.py index 66042c0..0c5419a 100644 --- a/tests/test_dicts_benedict.py +++ b/tests/test_dicts_benedict.py @@ -97,3 +97,68 @@ class BenedictTestCase(unittest.TestCase): self.assertEqual(b, r) self.assertEqual(type(b), benedict) + def test_get_item(self): + d = { + 'a': 1, + 'b': { + 'c': 2, + 'd': { + 'e': 3, + } + } + } + b = benedict(d) + self.assertEqual(b['a'], 1) + self.assertEqual(b['b.c'], 2) + self.assertTrue(isinstance(b['b'], benedict)) + self.assertTrue(isinstance(b['b.d'], benedict)) + bb = b['b'] + self.assertTrue(isinstance(bb['d'], benedict)) + + def test_get(self): + d = { + 'a': 1, + 'b': { + 'c': 2, + 'd': { + 'e': 3, + } + } + } + b = benedict(d) + self.assertEqual(b.get('a'), 1) + self.assertEqual(b.get('b.c'), 2) + self.assertTrue(isinstance(b.get('b'), benedict)) + self.assertTrue(isinstance(b.get('b.d'), benedict)) + bb = b.get('b') + self.assertTrue(isinstance(bb.get('d'), benedict)) + + def test_pop(self): + d = { + 'a': 1, + 'b': { + 'c': 2, + 'd': { + 'e': 3, + } + } + } + b = benedict(d) + self.assertEqual(b.pop('a'), 1) + self.assertEqual(b.pop('b.c'), 2) + self.assertTrue(isinstance(b.pop('b.d'), benedict)) + + def test_setdefault(self): + d = { + 'a': 1, + 'b': { + 'c': 2, + 'd': { + 'e': 3, + } + } + } + b = benedict(d) + self.assertTrue(isinstance(b.setdefault('b', 1), benedict)) + self.assertTrue(isinstance(b.setdefault('b.d', 1), benedict)) + diff --git a/tests/test_dicts_keypath_dict.py b/tests/test_dicts_keypath_dict.py index 11449ea..d97cac2 100644 --- a/tests/test_dicts_keypath_dict.py +++ b/tests/test_dicts_keypath_dict.py @@ -14,23 +14,6 @@ class KeypathDictTestCase(unittest.TestCase): # print(d) # print(d[keys]) - def test_copy(self): - d = { - 'a': { - 'b': { - 'c': 1 - } - } - } - b = KeypathDict(d) - c = b.copy() - self.assertEqual(type(b), type(c)) - self.assertEqual(b, c) - self.assertFalse(c is b) - c['a.b.c'] = 2 - self.assertEqual(b.get('a.b.c'), 2) - self.assertEqual(c.get('a.b.c'), 2) - def test_fromkeys(self): k = [ 'a', @@ -335,7 +318,7 @@ class KeypathDictTestCase(unittest.TestCase): 'x.y', 'x.z', ] - self.assertEqual(b.get_keypaths(), r) + self.assertEqual(b.keypaths(), r) def test_set_override_existing_item(self): d = {}