From cd80208187226caea58aafe6c8fbc5f43827f23c Mon Sep 17 00:00:00 2001 From: Tin Tvrtkovic Date: Sat, 7 May 2016 19:55:09 +0200 Subject: [PATCH] Initial dict_factory support for asdict. --- src/attr/_funcs.py | 7 +++++-- tests/__init__.py | 20 ++++++++++++++++++- tests/test_funcs.py | 48 +++++++++++++++++++++++++++++++++++++-------- 3 files changed, 64 insertions(+), 11 deletions(-) diff --git a/src/attr/_funcs.py b/src/attr/_funcs.py index 15daae69..b2e049a2 100644 --- a/src/attr/_funcs.py +++ b/src/attr/_funcs.py @@ -6,7 +6,7 @@ from ._compat import iteritems from ._make import Attribute, NOTHING, fields -def asdict(inst, recurse=True, filter=None): +def asdict(inst, recurse=True, filter=None, dict_factory=dict): """ Return the ``attrs`` attribute values of *i* as a dict. Optionally recurse into other ``attrs``-decorated classes. @@ -22,10 +22,13 @@ def asdict(inst, recurse=True, filter=None): value as the second argument. :type filer: callable + :param dict_factory: A callable to produce dictionaries from. + :type dict_factory: callable + :rtype: :class:`dict` """ attrs = fields(inst.__class__) - rv = {} + rv = dict_factory() for a in attrs: v = getattr(inst, a.name) if filter is not None and not filter(a, v): diff --git a/tests/__init__.py b/tests/__init__.py index 8fef9a96..85e7bc92 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,6 +1,10 @@ from __future__ import absolute_import, division, print_function -from attr import Attribute +import string + +from hypothesis import strategies as st + +from attr import Attribute, ib from attr._make import NOTHING, make_class @@ -40,3 +44,17 @@ class TestSimpleClass(object): Each call returns a completely new class. """ assert simple_class() is not simple_class() + + +def create_class(attrs): + # What if we get more than len(string.ascii_lowercase) attributes? + return make_class('HypClass', dict(zip(string.ascii_lowercase, attrs))) + +bare_attrs = st.just(ib(default=None)) +int_attrs = st.integers().map(lambda i: ib(default=i)) +str_attrs = st.text().map(lambda s: ib(default=s)) +float_attrs = st.floats().map(lambda f: ib(default=f)) + +simple_attrs = st.one_of(bare_attrs, int_attrs, str_attrs, float_attrs) + +simple_classes = st.lists(simple_attrs).map(create_class) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index acb4961c..06cf82db 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -3,13 +3,19 @@ Tests for `attr._funcs`. """ from __future__ import absolute_import, division, print_function +from collections import OrderedDict import pytest +from hypothesis import given, strategies as st + +from . import simple_classes + from attr._funcs import ( asdict, assoc, has, + fields, ) from attr._make import ( attr, @@ -21,16 +27,18 @@ class TestAsDict(object): """ Tests for `asdict`. """ - def test_shallow(self, C): + @given(st.sampled_from([dict, OrderedDict])) + def test_shallow(self, C, dict_factory): """ Shallow asdict returns correct dict. """ assert { "x": 1, "y": 2, - } == asdict(C(x=1, y=2), False) + } == asdict(C(x=1, y=2), False, dict_factory=dict_factory) - def test_recurse(self, C): + @given(st.sampled_from([dict, OrderedDict])) + def test_recurse(self, C, dict_factory): """ Deep asdict returns correct dict. """ @@ -40,9 +48,10 @@ class TestAsDict(object): } == asdict(C( C(1, 2), C(3, 4), - )) + ), dict_factory=dict_factory) - def test_filter(self, C): + @given(st.sampled_from([dict, OrderedDict])) + def test_filter(self, C, dict_factory): """ Attributes that are supposed to be skipped are skipped. """ @@ -51,7 +60,7 @@ class TestAsDict(object): } == asdict(C( C(1, 2), C(3, 4), - ), filter=lambda a, v: a.name != "y") + ), filter=lambda a, v: a.name != "y", dict_factory=dict_factory) @pytest.mark.parametrize("container", [ list, @@ -66,14 +75,37 @@ class TestAsDict(object): "y": [{"x": 2, "y": 3}, {"x": 4, "y": 5}, "a"], } == asdict(C(1, container([C(2, 3), C(4, 5), "a"]))) - def test_dicts(self, C): + @given(st.sampled_from([dict, OrderedDict])) + def test_dicts(self, C, dict_factory): """ If recurse is True, also recurse into dicts. """ + res = asdict(C(1, {"a": C(4, 5)}), dict_factory=dict_factory) assert { "x": 1, "y": {"a": {"x": 4, "y": 5}}, - } == asdict(C(1, {"a": C(4, 5)})) + } == res + assert isinstance(res, dict_factory) + + @given(simple_classes, st.sampled_from([dict, OrderedDict])) + def test_roundtrip(self, cls, dict_factory): + """Test roundtripping for Hypothesis-generated classes.""" + instance = cls() + dict_instance = asdict(instance, dict_factory=dict_factory) + + assert isinstance(dict_instance, dict_factory) + + roundtrip_instance = cls(**dict_instance) + + assert instance == roundtrip_instance + + @given(simple_classes) + def test_asdict_preserve_order(self, cls): + """When dumping to OrderedDict, field order should be preserved.""" + instance = cls() + dict_instance = asdict(instance, dict_factory=OrderedDict) + + assert [a.name for a in fields(cls)] == list(dict_instance.keys()) class TestHas(object):