Slotted cached property reference (#1221)

* Remove undesirable original class reference when using slotted cached_property

* Update closure cells for slotted cached property functions

* Changelog

* fixup py 3.7 test
This commit is contained in:
diabolo-dan 2024-01-04 10:52:49 +00:00 committed by GitHub
parent 5b0a4e6ab8
commit 26d8dd7957
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 54 additions and 8 deletions

View File

@ -0,0 +1,2 @@
Allow original slotted cached_property classes to be cleaned by GC.
Allow super calls in slotted cached properties.

View File

@ -318,11 +318,11 @@ def _compile_and_eval(script, globs, locs=None, filename=""):
eval(bytecode, globs, locs) eval(bytecode, globs, locs)
def _make_method(name, script, filename, globs): def _make_method(name, script, filename, globs, locals=None):
""" """
Create the method with the script given and return the method object. Create the method with the script given and return the method object.
""" """
locs = {} locs = {} if locals is None else locals
# In order of debuggers like PDB being able to step through the code, # In order of debuggers like PDB being able to step through the code,
# we add a fake linecache entry. # we add a fake linecache entry.
@ -608,7 +608,7 @@ def _make_cached_property_getattr(
lines = [ lines = [
# Wrapped to get `__class__` into closure cell for super() # Wrapped to get `__class__` into closure cell for super()
# (It will be replaced with the newly constructed class after construction). # (It will be replaced with the newly constructed class after construction).
"def wrapper():", "def wrapper(_cls):",
" __class__ = _cls", " __class__ = _cls",
" def __getattr__(self, item, cached_properties=cached_properties, original_getattr=original_getattr, _cached_setattr_get=_cached_setattr_get):", " def __getattr__(self, item, cached_properties=cached_properties, original_getattr=original_getattr, _cached_setattr_get=_cached_setattr_get):",
" func = cached_properties.get(item)", " func = cached_properties.get(item)",
@ -635,7 +635,7 @@ def _make_cached_property_getattr(
lines.extend( lines.extend(
[ [
" return __getattr__", " return __getattr__",
"__getattr__ = wrapper()", "__getattr__ = wrapper(_cls)",
] ]
) )
@ -644,7 +644,6 @@ def _make_cached_property_getattr(
glob = { glob = {
"cached_properties": cached_properties, "cached_properties": cached_properties,
"_cached_setattr_get": _obj_setattr.__get__, "_cached_setattr_get": _obj_setattr.__get__,
"_cls": cls,
"original_getattr": original_getattr, "original_getattr": original_getattr,
} }
@ -653,6 +652,9 @@ def _make_cached_property_getattr(
"\n".join(lines), "\n".join(lines),
unique_filename, unique_filename,
glob, glob,
locals={
"_cls": cls,
},
) )
@ -938,6 +940,10 @@ class _ClassBuilder:
# Clear out function from class to avoid clashing. # Clear out function from class to avoid clashing.
del cd[name] del cd[name]
additional_closure_functions_to_update.extend(
cached_properties.values()
)
class_annotations = _get_annotations(self._cls) class_annotations = _get_annotations(self._cls)
for name, func in cached_properties.items(): for name, func in cached_properties.items():
annotation = inspect.signature(func).return_annotation annotation = inspect.signature(func).return_annotation

View File

@ -3,8 +3,6 @@
""" """
Tests for `attr._make`. Tests for `attr._make`.
""" """
import copy import copy
import functools import functools
import gc import gc
@ -23,7 +21,7 @@ from hypothesis.strategies import booleans, integers, lists, sampled_from, text
import attr import attr
from attr import _config from attr import _config
from attr._compat import PY310 from attr._compat import PY310, PY_3_8_PLUS
from attr._make import ( from attr._make import (
Attribute, Attribute,
Factory, Factory,
@ -1773,6 +1771,27 @@ class TestClassBuilder:
assert [C2] == C.__subclasses__() assert [C2] == C.__subclasses__()
@pytest.mark.skipif(not PY_3_8_PLUS, reason="cached_property is 3.8+")
def test_no_references_to_original_when_using_cached_property(self):
"""
When subclassing a slotted class and using cached property, there are no stray references to the original class.
"""
@attr.s(slots=True)
class C:
pass
@attr.s(slots=True)
class C2(C):
@functools.cached_property
def value(self) -> int:
return 0
# The original C2 is in a reference cycle, so force a collect:
gc.collect()
assert [C2] == C.__subclasses__()
def _get_copy_kwargs(include_slots=True): def _get_copy_kwargs(include_slots=True):
""" """
Generate a list of compatible attr.s arguments for the `copy` tests. Generate a list of compatible attr.s arguments for the `copy` tests.

View File

@ -943,6 +943,25 @@ def test_slots_with_multiple_cached_property_subclasses_works():
assert ab.h == "h" assert ab.h == "h"
@pytest.mark.skipif(not PY_3_8_PLUS, reason="cached_property is 3.8+")
def test_slotted_cached_property_can_access_super():
"""
Multiple sub-classes shouldn't break cached properties.
"""
@attr.s(slots=True)
class A:
x = attr.ib(kw_only=True)
@attr.s(slots=True)
class B(A):
@functools.cached_property
def f(self):
return super().x * 2
assert B(x=1).f == 2
@pytest.mark.skipif(not PY_3_8_PLUS, reason="cached_property is 3.8+") @pytest.mark.skipif(not PY_3_8_PLUS, reason="cached_property is 3.8+")
def test_slots_sub_class_avoids_duplicated_slots(): def test_slots_sub_class_avoids_duplicated_slots():
""" """