# SPDX-License-Identifier: MIT """ Tests for `attr._funcs`. """ import re from collections import OrderedDict from typing import Generic, NamedTuple, TypeVar import pytest from hypothesis import assume, given from hypothesis import strategies as st import attr from attr import asdict, assoc, astuple, evolve, fields, has from attr._compat import Mapping, Sequence from attr.exceptions import AttrsAttributeNotFoundError from attr.validators import instance_of from .strategies import nested_classes, simple_classes MAPPING_TYPES = (dict, OrderedDict) SEQUENCE_TYPES = (list, tuple) @pytest.fixture(scope="session", name="C") def _C(): """ Return a simple but fully featured attrs class with an x and a y attribute. """ import attr @attr.s class C: x = attr.ib() y = attr.ib() return C class TestAsDict: """ Tests for `asdict`. """ @given(st.sampled_from(MAPPING_TYPES)) def test_shallow(self, C, dict_factory): """ Shallow asdict returns correct dict. """ assert {"x": 1, "y": 2} == asdict( C(x=1, y=2), False, dict_factory=dict_factory ) @given(st.sampled_from(MAPPING_TYPES)) def test_recurse(self, C, dict_class): """ Deep asdict returns correct dict. """ assert {"x": {"x": 1, "y": 2}, "y": {"x": 3, "y": 4}} == asdict( C(C(1, 2), C(3, 4)), dict_factory=dict_class ) def test_nested_lists(self, C): """ Test unstructuring deeply nested lists. """ inner = C(1, 2) outer = C([[inner]], None) assert {"x": [[{"x": 1, "y": 2}]], "y": None} == asdict(outer) def test_nested_dicts(self, C): """ Test unstructuring deeply nested dictionaries. """ inner = C(1, 2) outer = C({1: {2: inner}}, None) assert {"x": {1: {2: {"x": 1, "y": 2}}}, "y": None} == asdict(outer) @given(nested_classes, st.sampled_from(MAPPING_TYPES)) def test_recurse_property(self, cls, dict_class): """ Property tests for recursive asdict. """ obj = cls() obj_dict = asdict(obj, dict_factory=dict_class) def assert_proper_dict_class(obj, obj_dict): assert isinstance(obj_dict, dict_class) for field in fields(obj.__class__): field_val = getattr(obj, field.name) if has(field_val.__class__): # This field holds a class, recurse the assertions. assert_proper_dict_class(field_val, obj_dict[field.name]) elif isinstance(field_val, Sequence): dict_val = obj_dict[field.name] for item, item_dict in zip(field_val, dict_val): if has(item.__class__): assert_proper_dict_class(item, item_dict) elif isinstance(field_val, Mapping): # This field holds a dictionary. assert isinstance(obj_dict[field.name], dict_class) for key, val in field_val.items(): if has(val.__class__): assert_proper_dict_class( val, obj_dict[field.name][key] ) assert_proper_dict_class(obj, obj_dict) @given(st.sampled_from(MAPPING_TYPES)) def test_filter(self, C, dict_factory): """ Attributes that are supposed to be skipped are skipped. """ assert {"x": {"x": 1}} == asdict( C(C(1, 2), C(3, 4)), filter=lambda a, v: a.name != "y", dict_factory=dict_factory, ) @given(container=st.sampled_from(SEQUENCE_TYPES)) def test_lists_tuples(self, container, C): """ If recurse is True, also recurse into lists. """ assert { "x": 1, "y": [{"x": 2, "y": 3}, {"x": 4, "y": 5}, "a"], } == asdict(C(1, container([C(2, 3), C(4, 5), "a"]))) @given(container=st.sampled_from(SEQUENCE_TYPES)) def test_lists_tuples_retain_type(self, container, C): """ If recurse and retain_collection_types are True, also recurse into lists and do not convert them into list. """ assert { "x": 1, "y": container([{"x": 2, "y": 3}, {"x": 4, "y": 5}, "a"]), } == asdict( C(1, container([C(2, 3), C(4, 5), "a"])), retain_collection_types=True, ) @given(set_type=st.sampled_from((set, frozenset))) def test_sets_no_retain(self, C, set_type): """ Set types are converted to lists if retain_collection_types=False. """ d = asdict( C(1, set_type((1, 2, 3))), retain_collection_types=False, recurse=True, ) assert {"x": 1, "y": [1, 2, 3]} == d @given(st.sampled_from(MAPPING_TYPES)) def test_dicts(self, C, dict_factory): """ If recurse is True, also recurse into dicts. """ res = asdict(C(1, {"a": C(4, 5)}), dict_factory=dict_factory) assert {"x": 1, "y": {"a": {"x": 4, "y": 5}}} == res assert isinstance(res, dict_factory) @given(simple_classes(private_attrs=False), st.sampled_from(MAPPING_TYPES)) def test_roundtrip(self, cls, dict_class): """ Test dumping to dicts and back for Hypothesis-generated classes. Private attributes don't round-trip (the attribute name is different than the initializer argument). """ instance = cls() dict_instance = asdict(instance, dict_factory=dict_class) assert isinstance(dict_instance, dict_class) roundtrip_instance = cls(**dict_instance) assert instance == roundtrip_instance @given(simple_classes()) def test_asdict_preserve_order(self, cls): """ Field order should be preserved when dumping to an ordered_dict. """ instance = cls() dict_instance = asdict(instance, dict_factory=dict) assert [a.name for a in fields(cls)] == list(dict_instance.keys()) def test_retain_keys_are_tuples(self): """ retain_collect_types also retains keys. """ @attr.s class A: a = attr.ib() instance = A({(1,): 1}) assert {"a": {(1,): 1}} == attr.asdict( instance, retain_collection_types=True ) def test_tuple_keys(self): """ If a key is collection type, retain_collection_types is False, the key is serialized as a tuple. See #646 """ @attr.s class A: a = attr.ib() instance = A({(1,): 1}) assert {"a": {(1,): 1}} == attr.asdict(instance) def test_named_tuple_retain_type(self): """ Namedtuples can be serialized if retain_collection_types is True. See #1164 """ class Coordinates(NamedTuple): lat: float lon: float @attr.s class A: coords: Coordinates = attr.ib() instance = A(Coordinates(50.419019, 30.516225)) assert {"coords": Coordinates(50.419019, 30.516225)} == attr.asdict( instance, retain_collection_types=True ) def test_type_error_with_retain_type(self): """ Serialization that fails with TypeError leaves the error through if they're not tuples. See #1164 """ message = "__new__() missing 1 required positional argument (asdict)" class Coordinates(list): def __init__(self, first, *rest): if isinstance(first, list): raise TypeError(message) super().__init__([first, *rest]) @attr.s class A: coords: Coordinates = attr.ib() instance = A(Coordinates(50.419019, 30.516225)) with pytest.raises(TypeError, match=re.escape(message)): attr.asdict(instance, retain_collection_types=True) class TestAsTuple: """ Tests for `astuple`. """ @given(st.sampled_from(SEQUENCE_TYPES)) def test_shallow(self, C, tuple_factory): """ Shallow astuple returns correct dict. """ assert tuple_factory([1, 2]) == astuple( C(x=1, y=2), False, tuple_factory=tuple_factory ) @given(st.sampled_from(SEQUENCE_TYPES)) def test_recurse(self, C, tuple_factory): """ Deep astuple returns correct tuple. """ assert tuple_factory( [tuple_factory([1, 2]), tuple_factory([3, 4])] ) == astuple(C(C(1, 2), C(3, 4)), tuple_factory=tuple_factory) @given(nested_classes, st.sampled_from(SEQUENCE_TYPES)) def test_recurse_property(self, cls, tuple_class): """ Property tests for recursive astuple. """ obj = cls() obj_tuple = astuple(obj, tuple_factory=tuple_class) def assert_proper_tuple_class(obj, obj_tuple): assert isinstance(obj_tuple, tuple_class) for index, field in enumerate(fields(obj.__class__)): field_val = getattr(obj, field.name) if has(field_val.__class__): # This field holds a class, recurse the assertions. assert_proper_tuple_class(field_val, obj_tuple[index]) assert_proper_tuple_class(obj, obj_tuple) @given(nested_classes, st.sampled_from(SEQUENCE_TYPES)) def test_recurse_retain(self, cls, tuple_class): """ Property tests for asserting collection types are retained. """ obj = cls() obj_tuple = astuple( obj, tuple_factory=tuple_class, retain_collection_types=True ) def assert_proper_col_class(obj, obj_tuple): # Iterate over all attributes, and if they are lists or mappings # in the original, assert they are the same class in the dumped. for index, field in enumerate(fields(obj.__class__)): field_val = getattr(obj, field.name) if has(field_val.__class__): # This field holds a class, recurse the assertions. assert_proper_col_class(field_val, obj_tuple[index]) elif isinstance(field_val, (list, tuple)): # This field holds a sequence of something. expected_type = type(obj_tuple[index]) assert type(field_val) is expected_type for obj_e, obj_tuple_e in zip(field_val, obj_tuple[index]): if has(obj_e.__class__): assert_proper_col_class(obj_e, obj_tuple_e) elif isinstance(field_val, dict): orig = field_val tupled = obj_tuple[index] assert type(orig) is type(tupled) for obj_e, obj_tuple_e in zip( orig.items(), tupled.items() ): if has(obj_e[0].__class__): # Dict key assert_proper_col_class(obj_e[0], obj_tuple_e[0]) if has(obj_e[1].__class__): # Dict value assert_proper_col_class(obj_e[1], obj_tuple_e[1]) assert_proper_col_class(obj, obj_tuple) @given(st.sampled_from(SEQUENCE_TYPES)) def test_filter(self, C, tuple_factory): """ Attributes that are supposed to be skipped are skipped. """ assert tuple_factory([tuple_factory([1])]) == astuple( C(C(1, 2), C(3, 4)), filter=lambda a, v: a.name != "y", tuple_factory=tuple_factory, ) @given(container=st.sampled_from(SEQUENCE_TYPES)) def test_lists_tuples(self, container, C): """ If recurse is True, also recurse into lists. """ assert (1, [(2, 3), (4, 5), "a"]) == astuple( C(1, container([C(2, 3), C(4, 5), "a"])) ) @given(st.sampled_from(SEQUENCE_TYPES)) def test_dicts(self, C, tuple_factory): """ If recurse is True, also recurse into dicts. """ res = astuple(C(1, {"a": C(4, 5)}), tuple_factory=tuple_factory) assert tuple_factory([1, {"a": tuple_factory([4, 5])}]) == res assert isinstance(res, tuple_factory) @given(container=st.sampled_from(SEQUENCE_TYPES)) def test_lists_tuples_retain_type(self, container, C): """ If recurse and retain_collection_types are True, also recurse into lists and do not convert them into list. """ assert (1, container([(2, 3), (4, 5), "a"])) == astuple( C(1, container([C(2, 3), C(4, 5), "a"])), retain_collection_types=True, ) @given(container=st.sampled_from(MAPPING_TYPES)) def test_dicts_retain_type(self, container, C): """ If recurse and retain_collection_types are True, also recurse into lists and do not convert them into list. """ assert (1, container({"a": (4, 5)})) == astuple( C(1, container({"a": C(4, 5)})), retain_collection_types=True ) @given(simple_classes(), st.sampled_from(SEQUENCE_TYPES)) def test_roundtrip(self, cls, tuple_class): """ Test dumping to tuple and back for Hypothesis-generated classes. """ instance = cls() tuple_instance = astuple(instance, tuple_factory=tuple_class) assert isinstance(tuple_instance, tuple_class) roundtrip_instance = cls(*tuple_instance) assert instance == roundtrip_instance @given(set_type=st.sampled_from((set, frozenset))) def test_sets_no_retain(self, C, set_type): """ Set types are converted to lists if retain_collection_types=False. """ d = astuple( C(1, set_type((1, 2, 3))), retain_collection_types=False, recurse=True, ) assert (1, [1, 2, 3]) == d def test_named_tuple_retain_type(self): """ Namedtuples can be serialized if retain_collection_types is True. See #1164 """ class Coordinates(NamedTuple): lat: float lon: float @attr.s class A: coords: Coordinates = attr.ib() instance = A(Coordinates(50.419019, 30.516225)) assert (Coordinates(50.419019, 30.516225),) == attr.astuple( instance, retain_collection_types=True ) def test_type_error_with_retain_type(self): """ Serialization that fails with TypeError leaves the error through if they're not tuples. See #1164 """ message = "__new__() missing 1 required positional argument (astuple)" class Coordinates(list): def __init__(self, first, *rest): if isinstance(first, list): raise TypeError(message) super().__init__([first, *rest]) @attr.s class A: coords: Coordinates = attr.ib() instance = A(Coordinates(50.419019, 30.516225)) with pytest.raises(TypeError, match=re.escape(message)): attr.astuple(instance, retain_collection_types=True) class TestHas: """ Tests for `has`. """ def test_positive(self, C): """ Returns `True` on decorated classes. """ assert has(C) def test_positive_empty(self): """ Returns `True` on decorated classes even if there are no attributes. """ @attr.s class D: pass assert has(D) def test_negative(self): """ Returns `False` on non-decorated classes. """ assert not has(object) def test_generics(self): """ Works with generic classes. """ T = TypeVar("T") @attr.define class A(Generic[T]): a: T assert has(A) assert has(A[str]) # Verify twice, since there's caching going on. assert has(A[str]) def test_generics_negative(self): """ Returns `False` on non-decorated generic classes. """ T = TypeVar("T") class A(Generic[T]): a: T assert not has(A) assert not has(A[str]) # Verify twice, since there's caching going on. assert not has(A[str]) class TestAssoc: """ Tests for `assoc`. """ @given(slots=st.booleans(), frozen=st.booleans()) def test_empty(self, slots, frozen): """ Empty classes without changes get copied. """ @attr.s(slots=slots, frozen=frozen) class C: pass i1 = C() i2 = assoc(i1) assert i1 is not i2 assert i1 == i2 @given(simple_classes()) def test_no_changes(self, C): """ No changes means a verbatim copy. """ i1 = C() i2 = assoc(i1) assert i1 is not i2 assert i1 == i2 @given(simple_classes(), st.data()) def test_change(self, C, data): """ Changes work. """ # Take the first attribute, and change it. assume(fields(C)) # Skip classes with no attributes. field_names = [a.name for a in fields(C)] original = C() chosen_names = data.draw(st.sets(st.sampled_from(field_names))) change_dict = {name: data.draw(st.integers()) for name in chosen_names} changed = assoc(original, **change_dict) for k, v in change_dict.items(): assert getattr(changed, k) == v @given(simple_classes()) def test_unknown(self, C): """ Wanting to change an unknown attribute raises an AttrsAttributeNotFoundError. """ # No generated class will have a four letter attribute. with pytest.raises(AttrsAttributeNotFoundError) as e: assoc(C(), aaaa=2) assert (f"aaaa is not an attrs attribute on {C!r}.",) == e.value.args def test_frozen(self): """ Works on frozen classes. """ @attr.s(frozen=True) class C: x = attr.ib() y = attr.ib() assert C(3, 2) == assoc(C(1, 2), x=3) class TestEvolve: """ Tests for `evolve`. """ @given(slots=st.booleans(), frozen=st.booleans()) def test_empty(self, slots, frozen): """ Empty classes without changes get copied. """ @attr.s(slots=slots, frozen=frozen) class C: pass i1 = C() i2 = evolve(i1) assert i1 is not i2 assert i1 == i2 @given(simple_classes()) def test_no_changes(self, C): """ No changes means a verbatim copy. """ i1 = C() i2 = evolve(i1) assert i1 is not i2 assert i1 == i2 @given(simple_classes(), st.data()) def test_change(self, C, data): """ Changes work. """ # Take the first attribute, and change it. assume(fields(C)) # Skip classes with no attributes. field_names = [a.name for a in fields(C)] original = C() chosen_names = data.draw(st.sets(st.sampled_from(field_names))) # We pay special attention to private attributes, they should behave # like in `__init__`. change_dict = { name.replace("_", ""): data.draw(st.integers()) for name in chosen_names } changed = evolve(original, **change_dict) for name in chosen_names: assert getattr(changed, name) == change_dict[name.replace("_", "")] @given(simple_classes()) def test_unknown(self, C): """ Wanting to change an unknown attribute raises an AttrsAttributeNotFoundError. """ # No generated class will have a four letter attribute. with pytest.raises(TypeError) as e: evolve(C(), aaaa=2) if hasattr(C, "__attrs_init__"): expected = ( "__attrs_init__() got an unexpected keyword argument 'aaaa'" ) else: expected = "__init__() got an unexpected keyword argument 'aaaa'" assert e.value.args[0].endswith(expected) def test_validator_failure(self): """ TypeError isn't swallowed when validation fails within evolve. """ @attr.s class C: a = attr.ib(validator=instance_of(int)) with pytest.raises(TypeError) as e: evolve(C(a=1), a="some string") m = e.value.args[0] assert m.startswith("'a' must be ") def test_private(self): """ evolve() acts as `__init__` with regards to private attributes. """ @attr.s class C: _a = attr.ib() assert evolve(C(1), a=2)._a == 2 with pytest.raises(TypeError): evolve(C(1), _a=2) with pytest.raises(TypeError): evolve(C(1), a=3, _a=2) def test_non_init_attrs(self): """ evolve() handles `init=False` attributes. """ @attr.s class C: a = attr.ib() b = attr.ib(init=False, default=0) assert evolve(C(1), a=2).a == 2 def test_regression_attrs_classes(self): """ evolve() can evolve fields that are instances of attrs classes. Regression test for #804 """ @attr.s class Cls1: param1 = attr.ib() @attr.s class Cls2: param2 = attr.ib() obj2a = Cls2(param2="a") obj2b = Cls2(param2="b") obj1a = Cls1(param1=obj2a) assert Cls1(param1=Cls2(param2="b")) == attr.evolve( obj1a, param1=obj2b ) def test_dicts(self): """ evolve() can replace an attrs class instance with a dict. See #806 """ @attr.s class Cls1: param1 = attr.ib() @attr.s class Cls2: param2 = attr.ib() obj2a = Cls2(param2="a") obj2b = {"foo": 42, "param2": 42} obj1a = Cls1(param1=obj2a) assert Cls1({"foo": 42, "param2": 42}) == attr.evolve( obj1a, param1=obj2b ) def test_no_inst(self): """ Missing inst argument raises a TypeError like Python would. """ with pytest.raises( TypeError, match=r"evolve\(\) takes 1 positional argument" ): evolve(x=1) def test_too_many_pos_args(self): """ More than one positional argument raises a TypeError like Python would. """ with pytest.raises( TypeError, match=r"evolve\(\) takes 1 positional argument, but 2 were given", ): evolve(1, 2) def test_can_change_inst(self): """ If the instance is passed by positional argument, a field named `inst` can be changed. """ @attr.define class C: inst: int assert C(42) == evolve(C(23), inst=42)