From ec1379da2c9388b144adcb4b4cc751839f085f65 Mon Sep 17 00:00:00 2001
From: jjenniferdai <89552168+jjenniferdai@users.noreply.github.com>
Date: Wed, 2 Feb 2022 14:45:59 -0800
Subject: [PATCH] Rename `_SupportsStateDict` --> `_Stateful` Protocol (#11469)

---
 CHANGELOG.md                                |  2 +-
 pytorch_lightning/core/optimizer.py         |  9 ++---
 pytorch_lightning/utilities/auto_restart.py | 12 +++----
 pytorch_lightning/utilities/types.py        |  8 ++---
 tests/utilities/test_auto_restart.py        | 24 -------------
 tests/utilities/test_types.py               | 37 +++++++++++++++++++++
 6 files changed, 49 insertions(+), 43 deletions(-)
 create mode 100644 tests/utilities/test_types.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 521361b6b3..7a2b32ac44 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -16,7 +16,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fault Tolerant Manual
-    * Add `_SupportsStateDict` protocol to detect if classes are stateful ([#10646](https://github.com/PyTorchLightning/pytorch-lightning/pull/10646))
+    * Add `_Stateful` protocol to detect if classes are stateful ([#10646](https://github.com/PyTorchLightning/pytorch-lightning/pull/10646))
     * Add `_FaultTolerantMode` enum used to track different supported fault tolerant modes ([#10645](https://github.com/PyTorchLightning/pytorch-lightning/pull/10645))
     * Add a `_rotate_worker_indices` utility to reload the state according the latest worker ([#10647](https://github.com/PyTorchLightning/pytorch-lightning/pull/10647))
     * Add stateful workers ([#10674](https://github.com/PyTorchLightning/pytorch-lightning/pull/10674))
 
diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py
index 8014b21637..3825382b2f 100644
--- a/pytorch_lightning/core/optimizer.py
+++ b/pytorch_lightning/core/optimizer.py
@@ -24,12 +24,7 @@ import pytorch_lightning as pl
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden
-from pytorch_lightning.utilities.types import (
-    _SupportsStateDict,
-    LRSchedulerConfig,
-    LRSchedulerTypeTuple,
-    ReduceLROnPlateau,
-)
+from pytorch_lightning.utilities.types import _Stateful, LRSchedulerConfig, LRSchedulerTypeTuple, ReduceLROnPlateau
 
 
 def do_nothing_closure() -> None:
@@ -338,7 +333,7 @@ def _configure_schedulers_manual_opt(schedulers: list) -> List[LRSchedulerConfig
 def _validate_scheduler_api(lr_scheduler_configs: List[LRSchedulerConfig], model: "pl.LightningModule") -> None:
     for config in lr_scheduler_configs:
         scheduler = config.scheduler
-        if not isinstance(scheduler, _SupportsStateDict):
+        if not isinstance(scheduler, _Stateful):
             raise TypeError(
                 f"The provided lr scheduler `{scheduler.__class__.__name__}` is invalid."
                 " It should have `state_dict` and `load_state_dict` methods defined."
diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py
index 7e62ebc5d3..f8ce7177da 100644
--- a/pytorch_lightning/utilities/auto_restart.py
+++ b/pytorch_lightning/utilities/auto_restart.py
@@ -42,7 +42,7 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.distributed import _collect_states_on_rank_zero
 from pytorch_lightning.utilities.enums import _FaultTolerantMode, AutoRestartBatchKeys
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.types import _SupportsStateDict
+from pytorch_lightning.utilities.types import _Stateful
 
 
 class FastForwardSampler(Sampler):
@@ -581,17 +581,17 @@ def _reload_dataloader_state_dict_manual(dataloader: DataLoader, state_dict: Dic
     sampler_state = state_dict["state"][latest_worker_id].get("sampler_state", None)
     if sampler_state:
         # `sampler_state` keys contain all the DataLoader attribute names
-        # which matched `_SupportsStateDict` API interface while collecting the `state_dict`.
+        # which matched `_Stateful` API interface while collecting the `state_dict`.
         for dataloader_attr_name in sampler_state:
             obj = getattr(dataloader, dataloader_attr_name)
-            if not isinstance(obj, _SupportsStateDict):
+            if not isinstance(obj, _Stateful):
                 raise MisconfigurationException(
                     f"The DataLoader attribute {dataloader_attr_name}:{obj} should have a `load_state_dict` method."
                 )
 
             obj.load_state_dict(sampler_state[dataloader_attr_name])
 
-    if not isinstance(dataloader.dataset, _SupportsStateDict):
+    if not isinstance(dataloader.dataset, _Stateful):
         return
 
     dataset_state = {
@@ -645,9 +645,7 @@ class _StatefulDataLoaderIter:
     def _store_sampler_state(self) -> None:
         """This function is used to extract the sampler states if any."""
         sampler_state = {
-            k: v.state_dict()
-            for k, v in self._loader.__dict__.items()
-            if isinstance(v, _SupportsStateDict) and k != "dataset"
+            k: v.state_dict() for k, v in self._loader.__dict__.items() if isinstance(v, _Stateful) and k != "dataset"
         }
 
         self.__accumulate_state(sampler_state)
diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py
index 43cce711ec..c5e384117f 100644
--- a/pytorch_lightning/utilities/types.py
+++ b/pytorch_lightning/utilities/types.py
@@ -49,8 +49,8 @@ _DEVICE = Union[torch.device, str, int]
 
 
 @runtime_checkable
-class _SupportsStateDict(Protocol):
-    """This class is used to detect if an object is stateful using `isinstance(obj, _SupportsStateDict)`."""
+class _Stateful(Protocol):
+    """This class is used to detect if an object is stateful using `isinstance(obj, _Stateful)`."""
 
     def state_dict(self) -> Dict[str, Any]:
         ...
@@ -62,7 +62,7 @@ class _SupportsStateDict(Protocol):
 # Inferred from `torch.optim.lr_scheduler.pyi`
 # Missing attributes were added to improve typing
 @runtime_checkable
-class _LRScheduler(_SupportsStateDict, Protocol):
+class _LRScheduler(_Stateful, Protocol):
     optimizer: Optimizer
 
     def __init__(self, optimizer: Optimizer, *args: Any, **kwargs: Any) -> None:
@@ -72,7 +72,7 @@ class _LRScheduler(_SupportsStateDict, Protocol):
 # Inferred from `torch.optim.lr_scheduler.pyi`
 # Missing attributes were added to improve typing
 @runtime_checkable
-class ReduceLROnPlateau(_SupportsStateDict, Protocol):
+class ReduceLROnPlateau(_Stateful, Protocol):
     in_cooldown: bool
     optimizer: Optimizer
 
diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py
index e467436238..ff4db0051d 100644
--- a/tests/utilities/test_auto_restart.py
+++ b/tests/utilities/test_auto_restart.py
@@ -59,7 +59,6 @@ from pytorch_lightning.utilities.enums import _FaultTolerantMode, AutoRestartBat
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.fetching import DataFetcher
 from pytorch_lightning.utilities.imports import _fault_tolerant_training
-from pytorch_lightning.utilities.types import _SupportsStateDict
 from tests.helpers.boring_model import BoringModel, RandomDataset
 from tests.helpers.runif import RunIf
 
@@ -1296,29 +1295,6 @@ def test_rotate_worker_indices():
         _rotate_worker_indices(state_dict, 2, 3)
 
 
-def test_supports_state_dict_protocol():
-    class StatefulClass:
-        def state_dict(self):
-            pass
-
-        def load_state_dict(self, state_dict):
-            pass
-
-    assert isinstance(StatefulClass(), _SupportsStateDict)
-
-    class NotStatefulClass:
-        def state_dict(self):
-            pass
-
-    assert not isinstance(NotStatefulClass(), _SupportsStateDict)
-
-    class NotStateful2Class:
-        def load_state_dict(self, state_dict):
-            pass
-
-    assert not isinstance(NotStateful2Class(), _SupportsStateDict)
-
-
 def test_fault_tolerant_mode_enum():
     with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "0"}):
         assert _FaultTolerantMode.DISABLED == _FaultTolerantMode.detect_current_mode()
diff --git a/tests/utilities/test_types.py b/tests/utilities/test_types.py
new file mode 100644
index 0000000000..5b523a43dc
--- /dev/null
+++ b/tests/utilities/test_types.py
@@ -0,0 +1,37 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from pytorch_lightning.utilities.types import _Stateful
+
+
+def test_stateful_protocol():
+    class StatefulClass:
+        def state_dict(self):
+            pass
+
+        def load_state_dict(self, state_dict):
+            pass
+
+    assert isinstance(StatefulClass(), _Stateful)
+
+    class NotStatefulClass:
+        def state_dict(self):
+            pass
+
+    assert not isinstance(NotStatefulClass(), _Stateful)
+
+    class NotStateful2Class:
+        def load_state_dict(self, state_dict):
+            pass
+
+    assert not isinstance(NotStateful2Class(), _Stateful)
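
For readers unfamiliar with the mechanism this patch relies on: `_Stateful` is declared with `@runtime_checkable`, so every `isinstance(obj, _Stateful)` check in the diff is structural. It verifies only that the object defines `state_dict` and `load_state_dict` methods (names only; `runtime_checkable` never compares signatures), with no inheritance from `_Stateful` required. Below is a minimal sketch of the same pattern; the local `Stateful` protocol and the `Counter` class are illustrative stand-ins, not names from the Lightning codebase.

    from typing import Any, Dict, Protocol, runtime_checkable


    @runtime_checkable
    class Stateful(Protocol):
        """Structural stand-in for `_Stateful`: anything with both methods matches."""

        def state_dict(self) -> Dict[str, Any]:
            ...

        def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
            ...


    class Counter:
        """Never subclasses Stateful, yet satisfies it structurally."""

        def __init__(self) -> None:
            self.count = 0

        def state_dict(self) -> Dict[str, Any]:
            return {"count": self.count}

        def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
            self.count = state_dict["count"]


    assert isinstance(Counter(), Stateful)      # both methods present -> stateful
    assert not isinstance(object(), Stateful)   # neither method -> not stateful

This structural check is what lets user-defined samplers, datasets, and LR schedulers participate in fault-tolerant state capture and restore without subclassing any Lightning base class, which is why the rename touches only the protocol's name and not its behavior.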