diff --git a/benedict/dicts/base/base_dict.py b/benedict/dicts/base/base_dict.py index 917010c..5788b70 100644 --- a/benedict/dicts/base/base_dict.py +++ b/benedict/dicts/base/base_dict.py @@ -4,118 +4,128 @@ class BaseDict(dict): _dict = None + _pointer = False def __init__(self, *args, **kwargs): if len(args) == 1 and isinstance(args[0], dict): - self._dict = args[0] + self._dict = args[0].dict() if issubclass( + type(args[0]), BaseDict) else args[0] + self._pointer = True super(BaseDict, self).__init__(self._dict) return self._dict = None + self._pointer = False super(BaseDict, self).__init__(*args, **kwargs) - def _is_pointer(self): - return self._dict is not self and self._dict is not None + def __bool__(self): + if self._pointer: + return bool(self._dict) + return len(self.keys()) > 0 def __contains__(self, key): - if self._is_pointer(): + if self._pointer: return key in self._dict return super(BaseDict, self).__contains__(key) def __delitem__(self, key): - if self._is_pointer(): + if self._pointer: del self._dict[key] return super(BaseDict, self).__delitem__(key) def __eq__(self, other): - if self._is_pointer(): + if self._pointer: return self._dict == other return super(BaseDict, self).__eq__(other) def __getitem__(self, key): - if self._is_pointer(): + if self._pointer: return self._dict[key] return super(BaseDict, self).__getitem__(key) def __iter__(self): - if self._is_pointer(): + if self._pointer: return iter(self._dict) return super(BaseDict, self).__iter__() def __len__(self): - if self._is_pointer(): + if self._pointer: return len(self._dict) return super(BaseDict, self).__len__() + def __nonzero__(self): + # python 2 + return self.__bool__() + def __repr__(self): - if self._is_pointer(): + if self._pointer: return repr(self._dict) return super(BaseDict, self).__repr__() def __setitem__(self, key, value): - if self._is_pointer(): + if self._pointer: self._dict[key] = value return super(BaseDict, self).__setitem__(key, value) def __str__(self): - if self._is_pointer(): + if self._pointer: return str(self._dict) return super(BaseDict, self).__str__() def __unicode__(self): - if self._is_pointer(): + if self._pointer: return unicode(self._dict) - return super(BaseDict, self).__unicode__() + return '{}'.format(self) def clear(self): - if self._is_pointer(): + if self._pointer: self._dict.clear() return super(BaseDict, self).clear() def copy(self): - if self._is_pointer(): + if self._pointer: return self._dict.copy() return super(BaseDict, self).copy() def dict(self): - if self._is_pointer(): + if self._pointer: return self._dict return self def get(self, key, default=None): - if self._is_pointer(): + if self._pointer: return self._dict.get(key, default) return super(BaseDict, self).get(key, default) def items(self): - if self._is_pointer(): + if self._pointer: return self._dict.items() return super(BaseDict, self).items() def keys(self): - if self._is_pointer(): + if self._pointer: return self._dict.keys() return super(BaseDict, self).keys() def pop(self, key, *args): - if self._is_pointer(): + if self._pointer: return self._dict.pop(key, *args) return super(BaseDict, self).pop(key, *args) def setdefault(self, key, default=None): - if self._is_pointer(): + if self._pointer: return self._dict.setdefault(key, default) return super(BaseDict, self).setdefault(key, default) def update(self, other): - if self._is_pointer(): + if self._pointer: self._dict.update(other) return super(BaseDict, self).update(other) def values(self): - if self._is_pointer(): + if self._pointer: return self._dict.values() return super(BaseDict, self).values() diff --git a/tests/dicts/base/test_base_dict.py b/tests/dicts/base/test_base_dict.py index d571384..c16ec99 100644 --- a/tests/dicts/base/test_base_dict.py +++ b/tests/dicts/base/test_base_dict.py @@ -8,6 +8,7 @@ except ImportError: from collections import Iterable import copy +import sys import unittest @@ -185,6 +186,17 @@ class base_dict_test_case(unittest.TestCase): self.assertEqual(str(d), str(b)) self.assertEqual(b, b.dict()) + @unittest.skipIf(sys.version_info[0] > 2, 'No unicode in Python 3') + def test__unicode__(self): + d = BaseDict() + d['name'] = 'pythòn-bènèdìçt' + print(unicode(d)) + + @unittest.skipIf(sys.version_info[0] > 2, 'No unicode in Python > 2') + def test__unicode__with_pointer(self): + d = BaseDict({ 'name': 'pythòn-bènèdìçt' }) + print(unicode(d)) + def test_clear(self): d = { 'a':1, 'b':2, 'c':3 } b = BaseDict() @@ -337,14 +349,14 @@ class base_dict_test_case(unittest.TestCase): self.assertEqual(b, b.dict()) def test_setdefault(self): - d = { 'a':1, 'b':2, 'c':3 } - b = BaseDict(d) + b = BaseDict() + b['a'] = 1 + b['b'] = 2 + b['c'] = 3 v = b.setdefault('c', 4) self.assertEqual(v, 3) v = b.setdefault('d', 4) self.assertEqual(v, 4) - self.assertEqual(d, { 'a':1, 'b':2, 'c':3, 'd':4 }) - self.assertTrue(b == d) self.assertEqual(b, b.dict()) def test_setdefault_with_pointer(self):