Fix number of references to LightningModule (#12897)
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
parent
f0aca4c312
commit
b8e7d6d614
|
@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Fixed `Trainer(precision=64)` during evaluation which now uses the wrapped precision module ([#12983](https://github.com/PyTorchLightning/pytorch-lightning/pull/12983))
|
||||
- Fixed an issue to use wrapped `LightningModule` for evaluation during `trainer.fit` for `BaguaStrategy` ([#12983](https://github.com/PyTorchLightning/pytorch-lightning/pull/12983))
|
||||
- Fixed an issue wrt unnecessary usage of habana mixed precision package for fp32 types ([#13028](https://github.com/PyTorchLightning/pytorch-lightning/pull/13028))
|
||||
- Fixed the number of references of `LightningModule` so it can be deleted ([#12897](https://github.com/PyTorchLightning/pytorch-lightning/pull/12897))
|
||||
|
||||
|
||||
## [1.6.3] - 2022-05-03
|
||||
|
|
|
@ -19,6 +19,7 @@ import logging
|
|||
import numbers
|
||||
import os
|
||||
import tempfile
|
||||
import weakref
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Mapping, Optional, overload, Sequence, Tuple, Union
|
||||
|
@ -45,6 +46,7 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_
|
|||
from pytorch_lightning.utilities.cloud_io import get_filesystem
|
||||
from pytorch_lightning.utilities.distributed import distributed_available, sync_ddp
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12
|
||||
from pytorch_lightning.utilities.memory import get_model_size_mb
|
||||
from pytorch_lightning.utilities.model_summary import ModelSummary, summarize
|
||||
from pytorch_lightning.utilities.parsing import collect_init_args
|
||||
|
@ -2065,4 +2067,9 @@ class LightningModule(
|
|||
from torch.distributed._sharded_tensor import pre_load_state_dict_hook, state_dict_hook
|
||||
|
||||
self._register_state_dict_hook(state_dict_hook)
|
||||
self._register_load_state_dict_pre_hook(pre_load_state_dict_hook, True)
|
||||
|
||||
if _TORCH_GREATER_EQUAL_1_12:
|
||||
self._register_load_state_dict_pre_hook(pre_load_state_dict_hook, True)
|
||||
else:
|
||||
# We need to make sure the self inside the method is a weakref proxy
|
||||
self.__class__._register_load_state_dict_pre_hook(weakref.proxy(self), pre_load_state_dict_hook, True)
|
||||
|
|
|
@ -94,6 +94,7 @@ _TORCH_GREATER_EQUAL_1_9_1 = _compare_version("torch", operator.ge, "1.9.1")
|
|||
_TORCH_GREATER_EQUAL_1_10 = _compare_version("torch", operator.ge, "1.10.0")
|
||||
_TORCH_LESSER_EQUAL_1_10_2 = _compare_version("torch", operator.le, "1.10.2")
|
||||
_TORCH_GREATER_EQUAL_1_11 = _compare_version("torch", operator.ge, "1.11.0")
|
||||
_TORCH_GREATER_EQUAL_1_12 = _compare_version("torch", operator.ge, "1.12.0", use_base_version=True)
|
||||
|
||||
_APEX_AVAILABLE = _module_available("apex.amp")
|
||||
_BAGUA_AVAILABLE = _package_available("bagua")
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import sys
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
@ -399,3 +400,10 @@ def test_lightning_module_configure_gradient_clipping_different_argument_values(
|
|||
match=r"gradient_clip_algorithm='norm'\)` and have passed `clip_gradients\(gradient_clip_algorithm='foo'",
|
||||
):
|
||||
trainer.fit(model)
|
||||
|
||||
|
||||
def test_proper_refcount():
|
||||
torch_module = nn.Module()
|
||||
lightning_module = LightningModule()
|
||||
|
||||
assert sys.getrefcount(torch_module) == sys.getrefcount(lightning_module)
|
||||
|
|
Loading…
Reference in New Issue