diff --git a/changelog.d/750.change.rst b/changelog.d/750.change.rst new file mode 100644 index 00000000..e79d7cf1 --- /dev/null +++ b/changelog.d/750.change.rst @@ -0,0 +1 @@ +Allow for a ``__attrs_pre_init__()`` method that -- if defined -- will get called at the beginning of the ``attrs``-generated ``__init__()`` method. diff --git a/src/attr/_make.py b/src/attr/_make.py index f817b951..8bc8634d 100644 --- a/src/attr/_make.py +++ b/src/attr/_make.py @@ -588,6 +588,7 @@ class _ClassBuilder(object): "_cls_dict", "_delete_attribs", "_frozen", + "_has_pre_init", "_has_post_init", "_is_exc", "_on_setattr", @@ -633,6 +634,7 @@ class _ClassBuilder(object): self._frozen = frozen self._weakref_slot = weakref_slot self._cache_hash = cache_hash + self._has_pre_init = bool(getattr(cls, "__attrs_pre_init__", False)) self._has_post_init = bool(getattr(cls, "__attrs_post_init__", False)) self._delete_attribs = not bool(these) self._is_exc = is_exc @@ -889,6 +891,7 @@ class _ClassBuilder(object): _make_init( self._cls, self._attrs, + self._has_pre_init, self._has_post_init, self._frozen, self._slots, @@ -908,6 +911,7 @@ class _ClassBuilder(object): _make_init( self._cls, self._attrs, + self._has_pre_init, self._has_post_init, self._frozen, self._slots, @@ -1177,9 +1181,11 @@ def attrs( behavior `_ for more details. :param bool init: Create a ``__init__`` method that initializes the - ``attrs`` attributes. Leading underscores are stripped for the - argument name. If a ``__attrs_post_init__`` method exists on the - class, it will be called after the class is fully initialized. + ``attrs`` attributes. Leading underscores are stripped for the argument + name. If a ``__attrs_pre_init__`` method exists on the class, it will + be called before the class is initialized. If a ``__attrs_post_init__`` + method exists on the class, it will be called after the class is fully + initialized. If ``init`` is ``False``, an ``__attrs_init__`` method will be injected instead. This allows you to define a custom ``__init__`` @@ -1326,6 +1332,8 @@ def attrs( .. versionadded:: 20.3.0 *field_transformer* .. versionchanged:: 21.1.0 ``init=False`` injects ``__attrs_init__`` + .. versionchanged:: 21.1.0 Support for ``__attrs_pre_init__`` + """ if auto_detect and PY2: raise PythonTooOldError( @@ -1893,6 +1901,7 @@ def _is_slot_attr(a_name, base_attr_map): def _make_init( cls, attrs, + pre_init, post_init, frozen, slots, @@ -1931,6 +1940,7 @@ def _make_init( filtered_attrs, frozen, slots, + pre_init, post_init, cache_hash, base_attr_map, @@ -2071,6 +2081,7 @@ def _attrs_to_init_script( attrs, frozen, slots, + pre_init, post_init, cache_hash, base_attr_map, @@ -2088,6 +2099,9 @@ def _attrs_to_init_script( a cached ``object.__setattr__``. """ lines = [] + if pre_init: + lines.append("self.__attrs_pre_init__()") + if needs_cached_setattr: lines.append( # Circumvent the __setattr__ descriptor to save one lookup per @@ -2789,10 +2803,13 @@ def make_class(name, attrs, bases=(object,), **attributes_arguments): else: raise TypeError("attrs argument must be a dict or a list.") + pre_init = cls_dict.pop("__attrs_pre_init__", None) post_init = cls_dict.pop("__attrs_post_init__", None) user_init = cls_dict.pop("__init__", None) body = {} + if pre_init is not None: + body["__attrs_pre_init__"] = pre_init if post_init is not None: body["__attrs_post_init__"] = post_init if user_init is not None: diff --git a/tests/strategies.py b/tests/strategies.py index fab9716d..70d424af 100644 --- a/tests/strategies.py +++ b/tests/strategies.py @@ -153,9 +153,17 @@ def simple_classes( attr_names = gen_attr_names() cls_dict = dict(zip(attr_names, attrs)) + pre_init_flag = draw(st.booleans()) post_init_flag = draw(st.booleans()) init_flag = draw(st.booleans()) + if pre_init_flag: + + def pre_init(self): + pass + + cls_dict["__attrs_pre_init__"] = pre_init + if post_init_flag: def post_init(self): diff --git a/tests/test_dunders.py b/tests/test_dunders.py index 87aa36f4..a34f8f48 100644 --- a/tests/test_dunders.py +++ b/tests/test_dunders.py @@ -62,6 +62,7 @@ def _add_init(cls, frozen): cls.__init__ = _make_init( cls, cls.__attrs_attrs__, + getattr(cls, "__attrs_pre_init__", False), getattr(cls, "__attrs_post_init__", False), frozen, _is_slot_cls(cls), diff --git a/tests/test_make.py b/tests/test_make.py index 232e770e..4ba413ad 100644 --- a/tests/test_make.py +++ b/tests/test_make.py @@ -628,6 +628,22 @@ class TestAttributes(object): assert C.D.__name__ == "D" assert C.D.__qualname__ == C.__qualname__ + ".D" + @pytest.mark.parametrize("with_validation", [True, False]) + def test_pre_init(self, with_validation, monkeypatch): + """ + Verify that __attrs_pre_init__ gets called if defined. + """ + monkeypatch.setattr(_config, "_run_validators", with_validation) + + @attr.s + class C(object): + def __attrs_pre_init__(self2): + self2.z = 30 + + c = C() + + assert 30 == getattr(c, "z", None) + @pytest.mark.parametrize("with_validation", [True, False]) def test_post_init(self, with_validation, monkeypatch): """ @@ -647,6 +663,27 @@ class TestAttributes(object): assert 30 == getattr(c, "z", None) + @pytest.mark.parametrize("with_validation", [True, False]) + def test_pre_post_init_order(self, with_validation, monkeypatch): + """ + Verify that __attrs_post_init__ gets called if defined. + """ + monkeypatch.setattr(_config, "_run_validators", with_validation) + + @attr.s + class C(object): + x = attr.ib() + + def __attrs_pre_init__(self2): + self2.z = 30 + + def __attrs_post_init__(self2): + self2.z += self2.x + + c = C(x=10) + + assert 40 == getattr(c, "z", None) + def test_types(self): """ Sets the `Attribute.type` attr from type argument.