""" Tests for `attr._funcs`. """ from __future__ import absolute_import, division, print_function from collections import Mapping, OrderedDict, Sequence import pytest from hypothesis import HealthCheck, assume, given, settings from hypothesis import strategies as st import attr from attr import asdict, assoc, astuple, evolve, fields, has from attr._compat import TYPE from attr.exceptions import AttrsAttributeNotFoundError from attr.validators import instance_of from .utils import nested_classes, simple_classes MAPPING_TYPES = (dict, OrderedDict) SEQUENCE_TYPES = (list, tuple) class TestAsDict(object): """ Tests for `asdict`. """ @given(st.sampled_from(MAPPING_TYPES)) def test_shallow(self, C, dict_factory): """ Shallow asdict returns correct dict. """ assert { "x": 1, "y": 2, } == asdict(C(x=1, y=2), False, dict_factory=dict_factory) @given(st.sampled_from(MAPPING_TYPES)) def test_recurse(self, C, dict_class): """ Deep asdict returns correct dict. """ assert { "x": {"x": 1, "y": 2}, "y": {"x": 3, "y": 4}, } == asdict(C( C(1, 2), C(3, 4), ), dict_factory=dict_class) @given(nested_classes, st.sampled_from(MAPPING_TYPES)) @settings(suppress_health_check=[HealthCheck.too_slow]) def test_recurse_property(self, cls, dict_class): """ Property tests for recursive asdict. """ obj = cls() obj_dict = asdict(obj, dict_factory=dict_class) def assert_proper_dict_class(obj, obj_dict): assert isinstance(obj_dict, dict_class) for field in fields(obj.__class__): field_val = getattr(obj, field.name) if has(field_val.__class__): # This field holds a class, recurse the assertions. assert_proper_dict_class(field_val, obj_dict[field.name]) elif isinstance(field_val, Sequence): dict_val = obj_dict[field.name] for item, item_dict in zip(field_val, dict_val): if has(item.__class__): assert_proper_dict_class(item, item_dict) elif isinstance(field_val, Mapping): # This field holds a dictionary. assert isinstance(obj_dict[field.name], dict_class) for key, val in field_val.items(): if has(val.__class__): assert_proper_dict_class(val, obj_dict[field.name][key]) assert_proper_dict_class(obj, obj_dict) @given(st.sampled_from(MAPPING_TYPES)) def test_filter(self, C, dict_factory): """ Attributes that are supposed to be skipped are skipped. """ assert { "x": {"x": 1}, } == asdict(C( C(1, 2), C(3, 4), ), filter=lambda a, v: a.name != "y", dict_factory=dict_factory) @given(container=st.sampled_from(SEQUENCE_TYPES)) def test_lists_tuples(self, container, C): """ If recurse is True, also recurse into lists. """ assert { "x": 1, "y": [{"x": 2, "y": 3}, {"x": 4, "y": 5}, "a"], } == asdict(C(1, container([C(2, 3), C(4, 5), "a"]))) @given(container=st.sampled_from(SEQUENCE_TYPES)) def test_lists_tuples_retain_type(self, container, C): """ If recurse and retain_collection_types are True, also recurse into lists and do not convert them into list. """ assert { "x": 1, "y": container([{"x": 2, "y": 3}, {"x": 4, "y": 5}, "a"]), } == asdict(C(1, container([C(2, 3), C(4, 5), "a"])), retain_collection_types=True) @given(st.sampled_from(MAPPING_TYPES)) def test_dicts(self, C, dict_factory): """ If recurse is True, also recurse into dicts. """ res = asdict(C(1, {"a": C(4, 5)}), dict_factory=dict_factory) assert { "x": 1, "y": {"a": {"x": 4, "y": 5}}, } == res assert isinstance(res, dict_factory) @given(simple_classes(private_attrs=False), st.sampled_from(MAPPING_TYPES)) def test_roundtrip(self, cls, dict_class): """ Test dumping to dicts and back for Hypothesis-generated classes. Private attributes don't round-trip (the attribute name is different than the initializer argument). """ instance = cls() dict_instance = asdict(instance, dict_factory=dict_class) assert isinstance(dict_instance, dict_class) roundtrip_instance = cls(**dict_instance) assert instance == roundtrip_instance @given(simple_classes()) def test_asdict_preserve_order(self, cls): """ Field order should be preserved when dumping to OrderedDicts. """ instance = cls() dict_instance = asdict(instance, dict_factory=OrderedDict) assert [a.name for a in fields(cls)] == list(dict_instance.keys()) class TestAsTuple(object): """ Tests for `astuple`. """ @given(st.sampled_from(SEQUENCE_TYPES)) def test_shallow(self, C, tuple_factory): """ Shallow astuple returns correct dict. """ assert (tuple_factory([1, 2]) == astuple(C(x=1, y=2), False, tuple_factory=tuple_factory)) @given(st.sampled_from(SEQUENCE_TYPES)) def test_recurse(self, C, tuple_factory): """ Deep astuple returns correct tuple. """ assert (tuple_factory([tuple_factory([1, 2]), tuple_factory([3, 4])]) == astuple(C( C(1, 2), C(3, 4), ), tuple_factory=tuple_factory)) @given(nested_classes, st.sampled_from(SEQUENCE_TYPES)) @settings(suppress_health_check=[HealthCheck.too_slow]) def test_recurse_property(self, cls, tuple_class): """ Property tests for recursive astuple. """ obj = cls() obj_tuple = astuple(obj, tuple_factory=tuple_class) def assert_proper_tuple_class(obj, obj_tuple): assert isinstance(obj_tuple, tuple_class) for index, field in enumerate(fields(obj.__class__)): field_val = getattr(obj, field.name) if has(field_val.__class__): # This field holds a class, recurse the assertions. assert_proper_tuple_class(field_val, obj_tuple[index]) assert_proper_tuple_class(obj, obj_tuple) @given(nested_classes, st.sampled_from(SEQUENCE_TYPES)) @settings(suppress_health_check=[HealthCheck.too_slow]) def test_recurse_retain(self, cls, tuple_class): """ Property tests for asserting collection types are retained. """ obj = cls() obj_tuple = astuple(obj, tuple_factory=tuple_class, retain_collection_types=True) def assert_proper_col_class(obj, obj_tuple): # Iterate over all attributes, and if they are lists or mappings # in the original, assert they are the same class in the dumped. for index, field in enumerate(fields(obj.__class__)): field_val = getattr(obj, field.name) if has(field_val.__class__): # This field holds a class, recurse the assertions. assert_proper_col_class(field_val, obj_tuple[index]) elif isinstance(field_val, (list, tuple)): # This field holds a sequence of something. expected_type = type(obj_tuple[index]) assert type(field_val) is expected_type # noqa: E721 for obj_e, obj_tuple_e in zip(field_val, obj_tuple[index]): if has(obj_e.__class__): assert_proper_col_class(obj_e, obj_tuple_e) elif isinstance(field_val, dict): orig = field_val tupled = obj_tuple[index] assert type(orig) is type(tupled) # noqa: E721 for obj_e, obj_tuple_e in zip(orig.items(), tupled.items()): if has(obj_e[0].__class__): # Dict key assert_proper_col_class(obj_e[0], obj_tuple_e[0]) if has(obj_e[1].__class__): # Dict value assert_proper_col_class(obj_e[1], obj_tuple_e[1]) assert_proper_col_class(obj, obj_tuple) @given(st.sampled_from(SEQUENCE_TYPES)) def test_filter(self, C, tuple_factory): """ Attributes that are supposed to be skipped are skipped. """ assert tuple_factory([tuple_factory([1, ]), ]) == astuple(C( C(1, 2), C(3, 4), ), filter=lambda a, v: a.name != "y", tuple_factory=tuple_factory) @given(container=st.sampled_from(SEQUENCE_TYPES)) def test_lists_tuples(self, container, C): """ If recurse is True, also recurse into lists. """ assert ((1, [(2, 3), (4, 5), "a"]) == astuple(C(1, container([C(2, 3), C(4, 5), "a"]))) ) @given(st.sampled_from(SEQUENCE_TYPES)) def test_dicts(self, C, tuple_factory): """ If recurse is True, also recurse into dicts. """ res = astuple(C(1, {"a": C(4, 5)}), tuple_factory=tuple_factory) assert tuple_factory([1, {"a": tuple_factory([4, 5])}]) == res assert isinstance(res, tuple_factory) @given(container=st.sampled_from(SEQUENCE_TYPES)) def test_lists_tuples_retain_type(self, container, C): """ If recurse and retain_collection_types are True, also recurse into lists and do not convert them into list. """ assert ( (1, container([(2, 3), (4, 5), "a"])) == astuple(C(1, container([C(2, 3), C(4, 5), "a"])), retain_collection_types=True)) @given(container=st.sampled_from(MAPPING_TYPES)) def test_dicts_retain_type(self, container, C): """ If recurse and retain_collection_types are True, also recurse into lists and do not convert them into list. """ assert ( (1, container({"a": (4, 5)})) == astuple(C(1, container({"a": C(4, 5)})), retain_collection_types=True)) @given(simple_classes(), st.sampled_from(SEQUENCE_TYPES)) def test_roundtrip(self, cls, tuple_class): """ Test dumping to tuple and back for Hypothesis-generated classes. """ instance = cls() tuple_instance = astuple(instance, tuple_factory=tuple_class) assert isinstance(tuple_instance, tuple_class) roundtrip_instance = cls(*tuple_instance) assert instance == roundtrip_instance class TestHas(object): """ Tests for `has`. """ def test_positive(self, C): """ Returns `True` on decorated classes. """ assert has(C) def test_positive_empty(self): """ Returns `True` on decorated classes even if there are no attributes. """ @attr.s class D(object): pass assert has(D) def test_negative(self): """ Returns `False` on non-decorated classes. """ assert not has(object) class TestAssoc(object): """ Tests for `assoc`. """ @given(slots=st.booleans(), frozen=st.booleans()) def test_empty(self, slots, frozen): """ Empty classes without changes get copied. """ @attr.s(slots=slots, frozen=frozen) class C(object): pass i1 = C() with pytest.deprecated_call(): i2 = assoc(i1) assert i1 is not i2 assert i1 == i2 @given(simple_classes()) def test_no_changes(self, C): """ No changes means a verbatim copy. """ i1 = C() with pytest.deprecated_call(): i2 = assoc(i1) assert i1 is not i2 assert i1 == i2 @given(simple_classes(), st.data()) def test_change(self, C, data): """ Changes work. """ # Take the first attribute, and change it. assume(fields(C)) # Skip classes with no attributes. field_names = [a.name for a in fields(C)] original = C() chosen_names = data.draw(st.sets(st.sampled_from(field_names))) change_dict = {name: data.draw(st.integers()) for name in chosen_names} with pytest.deprecated_call(): changed = assoc(original, **change_dict) for k, v in change_dict.items(): assert getattr(changed, k) == v @given(simple_classes()) def test_unknown(self, C): """ Wanting to change an unknown attribute raises an AttrsAttributeNotFoundError. """ # No generated class will have a four letter attribute. with pytest.raises(AttrsAttributeNotFoundError) as e, \ pytest.deprecated_call(): assoc(C(), aaaa=2) assert ( "aaaa is not an attrs attribute on {cls!r}.".format(cls=C), ) == e.value.args def test_frozen(self): """ Works on frozen classes. """ @attr.s(frozen=True) class C(object): x = attr.ib() y = attr.ib() with pytest.deprecated_call(): assert C(3, 2) == assoc(C(1, 2), x=3) def test_warning(self): """ DeprecationWarning points to the correct file. """ @attr.s class C(object): x = attr.ib() with pytest.warns(DeprecationWarning) as wi: assert C(2) == assoc(C(1), x=2) assert __file__ == wi.list[0].filename class TestEvolve(object): """ Tests for `evolve`. """ @given(slots=st.booleans(), frozen=st.booleans()) def test_empty(self, slots, frozen): """ Empty classes without changes get copied. """ @attr.s(slots=slots, frozen=frozen) class C(object): pass i1 = C() i2 = evolve(i1) assert i1 is not i2 assert i1 == i2 @given(simple_classes()) def test_no_changes(self, C): """ No changes means a verbatim copy. """ i1 = C() i2 = evolve(i1) assert i1 is not i2 assert i1 == i2 @given(simple_classes(), st.data()) def test_change(self, C, data): """ Changes work. """ # Take the first attribute, and change it. assume(fields(C)) # Skip classes with no attributes. field_names = [a.name for a in fields(C)] original = C() chosen_names = data.draw(st.sets(st.sampled_from(field_names))) # We pay special attention to private attributes, they should behave # like in `__init__`. change_dict = {name.replace('_', ''): data.draw(st.integers()) for name in chosen_names} changed = evolve(original, **change_dict) for name in chosen_names: assert getattr(changed, name) == change_dict[name.replace('_', '')] @given(simple_classes()) def test_unknown(self, C): """ Wanting to change an unknown attribute raises an AttrsAttributeNotFoundError. """ # No generated class will have a four letter attribute. with pytest.raises(TypeError) as e: evolve(C(), aaaa=2) expected = "__init__() got an unexpected keyword argument 'aaaa'" assert (expected,) == e.value.args def test_validator_failure(self): """ TypeError isn't swallowed when validation fails within evolve. """ @attr.s class C(object): a = attr.ib(validator=instance_of(int)) with pytest.raises(TypeError) as e: evolve(C(a=1), a="some string") m = e.value.args[0] assert m.startswith("'a' must be <{type} 'int'>".format(type=TYPE)) def test_private(self): """ evolve() acts as `__init__` with regards to private attributes. """ @attr.s class C(object): _a = attr.ib() assert evolve(C(1), a=2)._a == 2 with pytest.raises(TypeError): evolve(C(1), _a=2) with pytest.raises(TypeError): evolve(C(1), a=3, _a=2) def test_non_init_attrs(self): """ evolve() handles `init=False` attributes. """ @attr.s class C(object): a = attr.ib() b = attr.ib(init=False, default=0) assert evolve(C(1), a=2).a == 2