Fix number of references to LightningModule (#12897)

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
otaj authored on 2022-05-13 17:23:25 +02:00; committed by lexierule
parent f0aca4c312
commit b8e7d6d614
4 changed files with 18 additions and 1 deletion


@@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `Trainer(precision=64)` during evaluation which now uses the wrapped precision module ([#12983](https://github.com/PyTorchLightning/pytorch-lightning/pull/12983))
 - Fixed an issue to use wrapped `LightningModule` for evaluation during `trainer.fit` for `BaguaStrategy` ([#12983](https://github.com/PyTorchLightning/pytorch-lightning/pull/12983))
 - Fixed an issue wrt unnecessary usage of habana mixed precision package for fp32 types ([#13028](https://github.com/PyTorchLightning/pytorch-lightning/pull/13028))
+- Fixed the number of references of `LightningModule` so it can be deleted ([#12897](https://github.com/PyTorchLightning/pytorch-lightning/pull/12897))

 ## [1.6.3] - 2022-05-03
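
For context on the new entry: before this fix a `LightningModule` held an extra internal reference to itself, so it was not freed as soon as user code dropped its last reference. A minimal sketch of the behaviour the entry describes, assuming a bare `LightningModule` run as a plain script with no other references held (not part of this commit):

import weakref

from pytorch_lightning import LightningModule

model = LightningModule()
alive = weakref.ref(model)

# Without a stray internal self-reference, plain reference counting frees the
# module as soon as the last user reference goes away; no gc pass is needed.
del model
assert alive() is None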


@@ -19,6 +19,7 @@ import logging
 import numbers
 import os
 import tempfile
+import weakref
 from contextlib import contextmanager
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Mapping, Optional, overload, Sequence, Tuple, Union
@@ -45,6 +46,7 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.distributed import distributed_available, sync_ddp
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12
 from pytorch_lightning.utilities.memory import get_model_size_mb
 from pytorch_lightning.utilities.model_summary import ModelSummary, summarize
 from pytorch_lightning.utilities.parsing import collect_init_args
@@ -2065,4 +2067,9 @@ class LightningModule(
         from torch.distributed._sharded_tensor import pre_load_state_dict_hook, state_dict_hook

         self._register_state_dict_hook(state_dict_hook)
-        self._register_load_state_dict_pre_hook(pre_load_state_dict_hook, True)
+
+        if _TORCH_GREATER_EQUAL_1_12:
+            self._register_load_state_dict_pre_hook(pre_load_state_dict_hook, True)
+        else:
+            # We need to make sure the self inside the method is a weakref proxy
+            self.__class__._register_load_state_dict_pre_hook(weakref.proxy(self), pre_load_state_dict_hook, True)
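
As the in-line comment says, the point of the `else` branch is that the `self` seen inside the base-class method is a `weakref.proxy`, so whatever the pre-1.12 implementation stores for the hook no longer owns a strong reference back to the module. A rough pure-Python illustration of why that changes the reference count, with made-up `StrongHolder`/`WeakHolder` classes standing in for the hook registration (a sketch, not the PyTorch implementation):

import functools
import sys
import weakref


def _noop_hook(module, *args):
    """Stand-in for a load_state_dict pre-hook; does nothing."""


class StrongHolder:
    """Stores a callable that keeps a strong reference back to ``self``."""

    def __init__(self):
        self._hooks = {0: functools.partial(_noop_hook, self)}


class WeakHolder:
    """Stores the same callable, built around ``weakref.proxy(self)`` instead."""

    def __init__(self):
        self._hooks = {0: functools.partial(_noop_hook, weakref.proxy(self))}


strong, weak = StrongHolder(), WeakHolder()
# The proxy-based variant has exactly one fewer reference to the instance,
# so dropping the last user reference is enough to free it.
assert sys.getrefcount(weak) == sys.getrefcount(strong) - 1

This is also why the new test below compares a `LightningModule` against a plain `nn.Module` rather than asserting an absolute count.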


@@ -94,6 +94,7 @@ _TORCH_GREATER_EQUAL_1_9_1 = _compare_version("torch", operator.ge, "1.9.1")
 _TORCH_GREATER_EQUAL_1_10 = _compare_version("torch", operator.ge, "1.10.0")
 _TORCH_LESSER_EQUAL_1_10_2 = _compare_version("torch", operator.le, "1.10.2")
 _TORCH_GREATER_EQUAL_1_11 = _compare_version("torch", operator.ge, "1.11.0")
+_TORCH_GREATER_EQUAL_1_12 = _compare_version("torch", operator.ge, "1.12.0", use_base_version=True)
 _APEX_AVAILABLE = _module_available("apex.amp")
 _BAGUA_AVAILABLE = _package_available("bagua")
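
The `use_base_version=True` flag matters because pre-release builds (for example a `1.12.0.dev...` nightly) compare as older than the final `1.12.0` under normal version ordering; comparing base versions lets such builds take the 1.12 branch of the check above. A small sketch of that distinction using `packaging.version` directly (the version string is made up):

from packaging.version import Version

nightly = Version("1.12.0.dev20220503")  # hypothetical nightly torch build
target = Version("1.12.0")

# A .dev pre-release sorts *before* the final release under normal ordering...
assert not nightly >= target
# ...but its base version is already 1.12.0, which is what the guard cares about.
assert Version(nightly.base_version) >= target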


@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import sys
 from unittest.mock import Mock

 import pytest
@@ -399,3 +400,10 @@ def test_lightning_module_configure_gradient_clipping_different_argument_values(
         match=r"gradient_clip_algorithm='norm'\)` and have passed `clip_gradients\(gradient_clip_algorithm='foo'",
     ):
         trainer.fit(model)
+
+
+def test_proper_refcount():
+    torch_module = nn.Module()
+    lightning_module = LightningModule()
+
+    assert sys.getrefcount(torch_module) == sys.getrefcount(lightning_module)
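
A note on the assertion style: `sys.getrefcount` includes the temporary reference created for the call itself, so the absolute number is a CPython implementation detail; comparing the `LightningModule` against a plain `nn.Module` measured the same way is what keeps the check stable. A tiny standalone illustration of the same pattern (names are made up):

import sys

plain = object()
held = object()
holder = [held]  # one extra strong reference to ``held``

# Only the comparison is meaningful; the absolute values depend on how
# getrefcount itself is called.
assert sys.getrefcount(held) == sys.getrefcount(plain) + 1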