diff --git a/changelog.d/1221.change.md b/changelog.d/1221.change.md new file mode 100644 index 00000000..742b2f44 --- /dev/null +++ b/changelog.d/1221.change.md @@ -0,0 +1,2 @@ +Allow original slotted cached_property classes to be cleaned by GC. +Allow super calls in slotted cached properties. diff --git a/src/attr/_make.py b/src/attr/_make.py index 10b4eca7..3e0e203d 100644 --- a/src/attr/_make.py +++ b/src/attr/_make.py @@ -318,11 +318,11 @@ def _compile_and_eval(script, globs, locs=None, filename=""): eval(bytecode, globs, locs) -def _make_method(name, script, filename, globs): +def _make_method(name, script, filename, globs, locals=None): """ Create the method with the script given and return the method object. """ - locs = {} + locs = {} if locals is None else locals # In order of debuggers like PDB being able to step through the code, # we add a fake linecache entry. @@ -608,7 +608,7 @@ def _make_cached_property_getattr( lines = [ # Wrapped to get `__class__` into closure cell for super() # (It will be replaced with the newly constructed class after construction). - "def wrapper():", + "def wrapper(_cls):", " __class__ = _cls", " def __getattr__(self, item, cached_properties=cached_properties, original_getattr=original_getattr, _cached_setattr_get=_cached_setattr_get):", " func = cached_properties.get(item)", @@ -635,7 +635,7 @@ def _make_cached_property_getattr( lines.extend( [ " return __getattr__", - "__getattr__ = wrapper()", + "__getattr__ = wrapper(_cls)", ] ) @@ -644,7 +644,6 @@ def _make_cached_property_getattr( glob = { "cached_properties": cached_properties, "_cached_setattr_get": _obj_setattr.__get__, - "_cls": cls, "original_getattr": original_getattr, } @@ -653,6 +652,9 @@ def _make_cached_property_getattr( "\n".join(lines), unique_filename, glob, + locals={ + "_cls": cls, + }, ) @@ -938,6 +940,10 @@ class _ClassBuilder: # Clear out function from class to avoid clashing. del cd[name] + additional_closure_functions_to_update.extend( + cached_properties.values() + ) + class_annotations = _get_annotations(self._cls) for name, func in cached_properties.items(): annotation = inspect.signature(func).return_annotation diff --git a/tests/test_make.py b/tests/test_make.py index 19f7a4cd..b67453d5 100644 --- a/tests/test_make.py +++ b/tests/test_make.py @@ -3,8 +3,6 @@ """ Tests for `attr._make`. """ - - import copy import functools import gc @@ -23,7 +21,7 @@ from hypothesis.strategies import booleans, integers, lists, sampled_from, text import attr from attr import _config -from attr._compat import PY310 +from attr._compat import PY310, PY_3_8_PLUS from attr._make import ( Attribute, Factory, @@ -1773,6 +1771,27 @@ class TestClassBuilder: assert [C2] == C.__subclasses__() + @pytest.mark.skipif(not PY_3_8_PLUS, reason="cached_property is 3.8+") + def test_no_references_to_original_when_using_cached_property(self): + """ + When subclassing a slotted class and using cached property, there are no stray references to the original class. + """ + + @attr.s(slots=True) + class C: + pass + + @attr.s(slots=True) + class C2(C): + @functools.cached_property + def value(self) -> int: + return 0 + + # The original C2 is in a reference cycle, so force a collect: + gc.collect() + + assert [C2] == C.__subclasses__() + def _get_copy_kwargs(include_slots=True): """ Generate a list of compatible attr.s arguments for the `copy` tests. diff --git a/tests/test_slots.py b/tests/test_slots.py index 26365ab0..c1332f2d 100644 --- a/tests/test_slots.py +++ b/tests/test_slots.py @@ -943,6 +943,25 @@ def test_slots_with_multiple_cached_property_subclasses_works(): assert ab.h == "h" +@pytest.mark.skipif(not PY_3_8_PLUS, reason="cached_property is 3.8+") +def test_slotted_cached_property_can_access_super(): + """ + Multiple sub-classes shouldn't break cached properties. + """ + + @attr.s(slots=True) + class A: + x = attr.ib(kw_only=True) + + @attr.s(slots=True) + class B(A): + @functools.cached_property + def f(self): + return super().x * 2 + + assert B(x=1).f == 2 + + @pytest.mark.skipif(not PY_3_8_PLUS, reason="cached_property is 3.8+") def test_slots_sub_class_avoids_duplicated_slots(): """