diff --git a/CHANGELOG.md b/CHANGELOG.md index cbbdd0ae74..83345d41fd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed `Trainer(precision=64)` during evaluation which now uses the wrapped precision module ([#12983](https://github.com/PyTorchLightning/pytorch-lightning/pull/12983)) - Fixed an issue to use wrapped `LightningModule` for evaluation during `trainer.fit` for `BaguaStrategy` ([#12983](https://github.com/PyTorchLightning/pytorch-lightning/pull/12983)) - Fixed an issue wrt unnecessary usage of habana mixed precision package for fp32 types ([#13028](https://github.com/PyTorchLightning/pytorch-lightning/pull/13028)) +- Fixed the number of references of `LightningModule` so it can be deleted ([#12897](https://github.com/PyTorchLightning/pytorch-lightning/pull/12897)) ## [1.6.3] - 2022-05-03 diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index b3d2adec57..bf86471fe9 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -19,6 +19,7 @@ import logging import numbers import os import tempfile +import weakref from contextlib import contextmanager from pathlib import Path from typing import Any, Callable, Dict, List, Mapping, Optional, overload, Sequence, Tuple, Union @@ -45,6 +46,7 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_ from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.distributed import distributed_available, sync_ddp from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12 from pytorch_lightning.utilities.memory import get_model_size_mb from pytorch_lightning.utilities.model_summary import ModelSummary, summarize from pytorch_lightning.utilities.parsing import collect_init_args @@ -2065,4 +2067,9 @@ class LightningModule( from torch.distributed._sharded_tensor import pre_load_state_dict_hook, state_dict_hook self._register_state_dict_hook(state_dict_hook) - self._register_load_state_dict_pre_hook(pre_load_state_dict_hook, True) + + if _TORCH_GREATER_EQUAL_1_12: + self._register_load_state_dict_pre_hook(pre_load_state_dict_hook, True) + else: + # We need to make sure the self inside the method is a weakref proxy + self.__class__._register_load_state_dict_pre_hook(weakref.proxy(self), pre_load_state_dict_hook, True) diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index 835e56f181..3647dbedd1 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -94,6 +94,7 @@ _TORCH_GREATER_EQUAL_1_9_1 = _compare_version("torch", operator.ge, "1.9.1") _TORCH_GREATER_EQUAL_1_10 = _compare_version("torch", operator.ge, "1.10.0") _TORCH_LESSER_EQUAL_1_10_2 = _compare_version("torch", operator.le, "1.10.2") _TORCH_GREATER_EQUAL_1_11 = _compare_version("torch", operator.ge, "1.11.0") +_TORCH_GREATER_EQUAL_1_12 = _compare_version("torch", operator.ge, "1.12.0", use_base_version=True) _APEX_AVAILABLE = _module_available("apex.amp") _BAGUA_AVAILABLE = _package_available("bagua") diff --git a/tests/core/test_lightning_module.py b/tests/core/test_lightning_module.py index 07fcf8dadc..180e7c46fe 100644 --- a/tests/core/test_lightning_module.py +++ b/tests/core/test_lightning_module.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import sys from unittest.mock import Mock import pytest @@ -399,3 +400,10 @@ def test_lightning_module_configure_gradient_clipping_different_argument_values( match=r"gradient_clip_algorithm='norm'\)` and have passed `clip_gradients\(gradient_clip_algorithm='foo'", ): trainer.fit(model) + + +def test_proper_refcount(): + torch_module = nn.Module() + lightning_module = LightningModule() + + assert sys.getrefcount(torch_module) == sys.getrefcount(lightning_module)