""" Tests for `attr._funcs`. """ from __future__ import absolute_import, division, print_function from collections import OrderedDict, Sequence, Mapping import pytest from hypothesis import assume, given, strategies as st, settings, HealthCheck from .utils import simple_classes, nested_classes from attr import ( attr, attributes, asdict, assoc, astuple, evolve, fields, has, ) from attr.exceptions import AttrsAttributeNotFoundError from attr.validators import instance_of from attr._compat import TYPE MAPPING_TYPES = (dict, OrderedDict) SEQUENCE_TYPES = (list, tuple) class TestAsDict(object): """ Tests for `asdict`. """ @given(st.sampled_from(MAPPING_TYPES)) def test_shallow(self, C, dict_factory): """ Shallow asdict returns correct dict. """ assert { "x": 1, "y": 2, } == asdict(C(x=1, y=2), False, dict_factory=dict_factory) @given(st.sampled_from(MAPPING_TYPES)) def test_recurse(self, C, dict_class): """ Deep asdict returns correct dict. """ assert { "x": {"x": 1, "y": 2}, "y": {"x": 3, "y": 4}, } == asdict(C( C(1, 2), C(3, 4), ), dict_factory=dict_class) @given(nested_classes, st.sampled_from(MAPPING_TYPES)) @settings(suppress_health_check=[HealthCheck.too_slow]) def test_recurse_property(self, cls, dict_class): """ Property tests for recursive asdict. """ obj = cls() obj_dict = asdict(obj, dict_factory=dict_class) def assert_proper_dict_class(obj, obj_dict): assert isinstance(obj_dict, dict_class) for field in fields(obj.__class__): field_val = getattr(obj, field.name) if has(field_val.__class__): # This field holds a class, recurse the assertions. assert_proper_dict_class(field_val, obj_dict[field.name]) elif isinstance(field_val, Sequence): dict_val = obj_dict[field.name] for item, item_dict in zip(field_val, dict_val): if has(item.__class__): assert_proper_dict_class(item, item_dict) elif isinstance(field_val, Mapping): # This field holds a dictionary. assert isinstance(obj_dict[field.name], dict_class) for key, val in field_val.items(): if has(val.__class__): assert_proper_dict_class(val, obj_dict[field.name][key]) assert_proper_dict_class(obj, obj_dict) @given(st.sampled_from(MAPPING_TYPES)) def test_filter(self, C, dict_factory): """ Attributes that are supposed to be skipped are skipped. """ assert { "x": {"x": 1}, } == asdict(C( C(1, 2), C(3, 4), ), filter=lambda a, v: a.name != "y", dict_factory=dict_factory) @given(container=st.sampled_from(SEQUENCE_TYPES)) def test_lists_tuples(self, container, C): """ If recurse is True, also recurse into lists. """ assert { "x": 1, "y": [{"x": 2, "y": 3}, {"x": 4, "y": 5}, "a"], } == asdict(C(1, container([C(2, 3), C(4, 5), "a"]))) @given(container=st.sampled_from(SEQUENCE_TYPES)) def test_lists_tuples_retain_type(self, container, C): """ If recurse and retain_collection_types are True, also recurse into lists and do not convert them into list. """ assert { "x": 1, "y": container([{"x": 2, "y": 3}, {"x": 4, "y": 5}, "a"]), } == asdict(C(1, container([C(2, 3), C(4, 5), "a"])), retain_collection_types=True) @given(st.sampled_from(MAPPING_TYPES)) def test_dicts(self, C, dict_factory): """ If recurse is True, also recurse into dicts. """ res = asdict(C(1, {"a": C(4, 5)}), dict_factory=dict_factory) assert { "x": 1, "y": {"a": {"x": 4, "y": 5}}, } == res assert isinstance(res, dict_factory) @given(simple_classes(private_attrs=False), st.sampled_from(MAPPING_TYPES)) def test_roundtrip(self, cls, dict_class): """ Test dumping to dicts and back for Hypothesis-generated classes. Private attributes don't round-trip (the attribute name is different than the initializer argument). """ instance = cls() dict_instance = asdict(instance, dict_factory=dict_class) assert isinstance(dict_instance, dict_class) roundtrip_instance = cls(**dict_instance) assert instance == roundtrip_instance @given(simple_classes()) def test_asdict_preserve_order(self, cls): """ Field order should be preserved when dumping to OrderedDicts. """ instance = cls() dict_instance = asdict(instance, dict_factory=OrderedDict) assert [a.name for a in fields(cls)] == list(dict_instance.keys()) class TestAsTuple(object): """ Tests for `astuple`. """ @given(st.sampled_from(SEQUENCE_TYPES)) def test_shallow(self, C, tuple_factory): """ Shallow astuple returns correct dict. """ assert (tuple_factory([1, 2]) == astuple(C(x=1, y=2), False, tuple_factory=tuple_factory)) @given(st.sampled_from(SEQUENCE_TYPES)) def test_recurse(self, C, tuple_factory): """ Deep astuple returns correct tuple. """ assert (tuple_factory([tuple_factory([1, 2]), tuple_factory([3, 4])]) == astuple(C( C(1, 2), C(3, 4), ), tuple_factory=tuple_factory)) @given(nested_classes, st.sampled_from(SEQUENCE_TYPES)) @settings(suppress_health_check=[HealthCheck.too_slow]) def test_recurse_property(self, cls, tuple_class): """ Property tests for recursive astuple. """ obj = cls() obj_tuple = astuple(obj, tuple_factory=tuple_class) def assert_proper_tuple_class(obj, obj_tuple): assert isinstance(obj_tuple, tuple_class) for index, field in enumerate(fields(obj.__class__)): field_val = getattr(obj, field.name) if has(field_val.__class__): # This field holds a class, recurse the assertions. assert_proper_tuple_class(field_val, obj_tuple[index]) assert_proper_tuple_class(obj, obj_tuple) @given(nested_classes, st.sampled_from(SEQUENCE_TYPES)) @settings(suppress_health_check=[HealthCheck.too_slow]) def test_recurse_retain(self, cls, tuple_class): """ Property tests for asserting collection types are retained. """ obj = cls() obj_tuple = astuple(obj, tuple_factory=tuple_class, retain_collection_types=True) def assert_proper_col_class(obj, obj_tuple): # Iterate over all attributes, and if they are lists or mappings # in the original, assert they are the same class in the dumped. for index, field in enumerate(fields(obj.__class__)): field_val = getattr(obj, field.name) if has(field_val.__class__): # This field holds a class, recurse the assertions. assert_proper_col_class(field_val, obj_tuple[index]) elif isinstance(field_val, (list, tuple)): # This field holds a sequence of something. expected_type = type(obj_tuple[index]) assert type(field_val) is expected_type # noqa: E721 for obj_e, obj_tuple_e in zip(field_val, obj_tuple[index]): if has(obj_e.__class__): assert_proper_col_class(obj_e, obj_tuple_e) elif isinstance(field_val, dict): orig = field_val tupled = obj_tuple[index] assert type(orig) is type(tupled) # noqa: E721 for obj_e, obj_tuple_e in zip(orig.items(), tupled.items()): if has(obj_e[0].__class__): # Dict key assert_proper_col_class(obj_e[0], obj_tuple_e[0]) if has(obj_e[1].__class__): # Dict value assert_proper_col_class(obj_e[1], obj_tuple_e[1]) assert_proper_col_class(obj, obj_tuple) @given(st.sampled_from(SEQUENCE_TYPES)) def test_filter(self, C, tuple_factory): """ Attributes that are supposed to be skipped are skipped. """ assert tuple_factory([tuple_factory([1, ]), ]) == astuple(C( C(1, 2), C(3, 4), ), filter=lambda a, v: a.name != "y", tuple_factory=tuple_factory) @given(container=st.sampled_from(SEQUENCE_TYPES)) def test_lists_tuples(self, container, C): """ If recurse is True, also recurse into lists. """ assert ((1, [(2, 3), (4, 5), "a"]) == astuple(C(1, container([C(2, 3), C(4, 5), "a"]))) ) @given(st.sampled_from(SEQUENCE_TYPES)) def test_dicts(self, C, tuple_factory): """ If recurse is True, also recurse into dicts. """ res = astuple(C(1, {"a": C(4, 5)}), tuple_factory=tuple_factory) assert tuple_factory([1, {"a": tuple_factory([4, 5])}]) == res assert isinstance(res, tuple_factory) @given(container=st.sampled_from(SEQUENCE_TYPES)) def test_lists_tuples_retain_type(self, container, C): """ If recurse and retain_collection_types are True, also recurse into lists and do not convert them into list. """ assert ( (1, container([(2, 3), (4, 5), "a"])) == astuple(C(1, container([C(2, 3), C(4, 5), "a"])), retain_collection_types=True)) @given(container=st.sampled_from(MAPPING_TYPES)) def test_dicts_retain_type(self, container, C): """ If recurse and retain_collection_types are True, also recurse into lists and do not convert them into list. """ assert ( (1, container({"a": (4, 5)})) == astuple(C(1, container({"a": C(4, 5)})), retain_collection_types=True)) @given(simple_classes(), st.sampled_from(SEQUENCE_TYPES)) def test_roundtrip(self, cls, tuple_class): """ Test dumping to tuple and back for Hypothesis-generated classes. """ instance = cls() tuple_instance = astuple(instance, tuple_factory=tuple_class) assert isinstance(tuple_instance, tuple_class) roundtrip_instance = cls(*tuple_instance) assert instance == roundtrip_instance class TestHas(object): """ Tests for `has`. """ def test_positive(self, C): """ Returns `True` on decorated classes. """ assert has(C) def test_positive_empty(self): """ Returns `True` on decorated classes even if there are no attributes. """ @attributes class D(object): pass assert has(D) def test_negative(self): """ Returns `False` on non-decorated classes. """ assert not has(object) class TestAssoc(object): """ Tests for `assoc`. """ @given(slots=st.booleans(), frozen=st.booleans()) def test_empty(self, slots, frozen): """ Empty classes without changes get copied. """ @attributes(slots=slots, frozen=frozen) class C(object): pass i1 = C() i2 = assoc(i1) assert i1 is not i2 assert i1 == i2 @given(simple_classes()) def test_no_changes(self, C): """ No changes means a verbatim copy. """ i1 = C() i2 = assoc(i1) assert i1 is not i2 assert i1 == i2 @given(simple_classes(), st.data()) def test_change(self, C, data): """ Changes work. """ # Take the first attribute, and change it. assume(fields(C)) # Skip classes with no attributes. field_names = [a.name for a in fields(C)] original = C() chosen_names = data.draw(st.sets(st.sampled_from(field_names))) change_dict = {name: data.draw(st.integers()) for name in chosen_names} changed = assoc(original, **change_dict) for k, v in change_dict.items(): assert getattr(changed, k) == v @given(simple_classes()) def test_unknown(self, C): """ Wanting to change an unknown attribute raises an AttrsAttributeNotFoundError. """ # No generated class will have a four letter attribute. with pytest.raises(AttrsAttributeNotFoundError) as e: assoc(C(), aaaa=2) assert ( "aaaa is not an attrs attribute on {cls!r}.".format(cls=C), ) == e.value.args def test_frozen(self): """ Works on frozen classes. """ @attributes(frozen=True) class C(object): x = attr() y = attr() assert C(3, 2) == assoc(C(1, 2), x=3) class TestEvolve(object): """ Tests for `evolve`. """ @given(slots=st.booleans(), frozen=st.booleans()) def test_empty(self, slots, frozen): """ Empty classes without changes get copied. """ @attributes(slots=slots, frozen=frozen) class C(object): pass i1 = C() i2 = evolve(i1) assert i1 is not i2 assert i1 == i2 @given(simple_classes()) def test_no_changes(self, C): """ No changes means a verbatim copy. """ i1 = C() i2 = evolve(i1) assert i1 is not i2 assert i1 == i2 @given(simple_classes(), st.data()) def test_change(self, C, data): """ Changes work. """ # Take the first attribute, and change it. assume(fields(C)) # Skip classes with no attributes. field_names = [a.name for a in fields(C)] original = C() chosen_names = data.draw(st.sets(st.sampled_from(field_names))) # We pay special attention to private attributes, they should behave # like in `__init__`. change_dict = {name.replace('_', ''): data.draw(st.integers()) for name in chosen_names} changed = evolve(original, **change_dict) for name in chosen_names: assert getattr(changed, name) == change_dict[name.replace('_', '')] @given(simple_classes()) def test_unknown(self, C): """ Wanting to change an unknown attribute raises an AttrsAttributeNotFoundError. """ # No generated class will have a four letter attribute. with pytest.raises(TypeError) as e: evolve(C(), aaaa=2) expected = "__init__() got an unexpected keyword argument 'aaaa'" assert (expected,) == e.value.args def test_validator_failure(self): """ TypeError isn't swallowed when validation fails within evolve. """ @attributes class C(object): a = attr(validator=instance_of(int)) with pytest.raises(TypeError) as e: evolve(C(a=1), a="some string") m = e.value.args[0] assert m.startswith("'a' must be <{type} 'int'>".format(type=TYPE)) def test_private(self): """ evolve() acts as `__init__` with regards to private attributes. """ @attributes class C(object): _a = attr() assert evolve(C(1), a=2)._a == 2 with pytest.raises(TypeError): evolve(C(1), _a=2) with pytest.raises(TypeError): evolve(C(1), a=3, _a=2) def test_non_init_attrs(self): """ evolve() handles `init=False` attributes. """ @attributes class C(object): a = attr() b = attr(init=False, default=0) assert evolve(C(1), a=2).a == 2