Rename `_SupportsStateDict` --> `_Stateful` Protocol (#11469)
This commit is contained in:
parent
b8e360dafa
commit
ec1379da2c
|
@ -16,7 +16,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
|
||||
- Fault Tolerant Manual
|
||||
* Add `_SupportsStateDict` protocol to detect if classes are stateful ([#10646](https://github.com/PyTorchLightning/pytorch-lightning/pull/10646))
|
||||
* Add `_Stateful` protocol to detect if classes are stateful ([#10646](https://github.com/PyTorchLightning/pytorch-lightning/pull/10646))
|
||||
* Add `_FaultTolerantMode` enum used to track different supported fault tolerant modes ([#10645](https://github.com/PyTorchLightning/pytorch-lightning/pull/10645))
|
||||
* Add a `_rotate_worker_indices` utility to reload the state according the latest worker ([#10647](https://github.com/PyTorchLightning/pytorch-lightning/pull/10647))
|
||||
* Add stateful workers ([#10674](https://github.com/PyTorchLightning/pytorch-lightning/pull/10674))
|
||||
|
|
|
@ -24,12 +24,7 @@ import pytorch_lightning as pl
|
|||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.model_helpers import is_overridden
|
||||
from pytorch_lightning.utilities.types import (
|
||||
_SupportsStateDict,
|
||||
LRSchedulerConfig,
|
||||
LRSchedulerTypeTuple,
|
||||
ReduceLROnPlateau,
|
||||
)
|
||||
from pytorch_lightning.utilities.types import _Stateful, LRSchedulerConfig, LRSchedulerTypeTuple, ReduceLROnPlateau
|
||||
|
||||
|
||||
def do_nothing_closure() -> None:
|
||||
|
@ -338,7 +333,7 @@ def _configure_schedulers_manual_opt(schedulers: list) -> List[LRSchedulerConfig
|
|||
def _validate_scheduler_api(lr_scheduler_configs: List[LRSchedulerConfig], model: "pl.LightningModule") -> None:
|
||||
for config in lr_scheduler_configs:
|
||||
scheduler = config.scheduler
|
||||
if not isinstance(scheduler, _SupportsStateDict):
|
||||
if not isinstance(scheduler, _Stateful):
|
||||
raise TypeError(
|
||||
f"The provided lr scheduler `{scheduler.__class__.__name__}` is invalid."
|
||||
" It should have `state_dict` and `load_state_dict` methods defined."
|
||||
|
|
|
@ -42,7 +42,7 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection
|
|||
from pytorch_lightning.utilities.distributed import _collect_states_on_rank_zero
|
||||
from pytorch_lightning.utilities.enums import _FaultTolerantMode, AutoRestartBatchKeys
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.types import _SupportsStateDict
|
||||
from pytorch_lightning.utilities.types import _Stateful
|
||||
|
||||
|
||||
class FastForwardSampler(Sampler):
|
||||
|
@ -581,17 +581,17 @@ def _reload_dataloader_state_dict_manual(dataloader: DataLoader, state_dict: Dic
|
|||
sampler_state = state_dict["state"][latest_worker_id].get("sampler_state", None)
|
||||
if sampler_state:
|
||||
# `sampler_state` keys contain all the DataLoader attribute names
|
||||
# which matched `_SupportsStateDict` API interface while collecting the `state_dict`.
|
||||
# which matched `_Stateful` API interface while collecting the `state_dict`.
|
||||
for dataloader_attr_name in sampler_state:
|
||||
obj = getattr(dataloader, dataloader_attr_name)
|
||||
if not isinstance(obj, _SupportsStateDict):
|
||||
if not isinstance(obj, _Stateful):
|
||||
raise MisconfigurationException(
|
||||
f"The DataLoader attribute {dataloader_attr_name}:{obj} should have a `load_state_dict` method."
|
||||
)
|
||||
|
||||
obj.load_state_dict(sampler_state[dataloader_attr_name])
|
||||
|
||||
if not isinstance(dataloader.dataset, _SupportsStateDict):
|
||||
if not isinstance(dataloader.dataset, _Stateful):
|
||||
return
|
||||
|
||||
dataset_state = {
|
||||
|
@ -645,9 +645,7 @@ class _StatefulDataLoaderIter:
|
|||
def _store_sampler_state(self) -> None:
|
||||
"""This function is used to extract the sampler states if any."""
|
||||
sampler_state = {
|
||||
k: v.state_dict()
|
||||
for k, v in self._loader.__dict__.items()
|
||||
if isinstance(v, _SupportsStateDict) and k != "dataset"
|
||||
k: v.state_dict() for k, v in self._loader.__dict__.items() if isinstance(v, _Stateful) and k != "dataset"
|
||||
}
|
||||
self.__accumulate_state(sampler_state)
|
||||
|
||||
|
|
|
@ -49,8 +49,8 @@ _DEVICE = Union[torch.device, str, int]
|
|||
|
||||
|
||||
@runtime_checkable
|
||||
class _SupportsStateDict(Protocol):
|
||||
"""This class is used to detect if an object is stateful using `isinstance(obj, _SupportsStateDict)`."""
|
||||
class _Stateful(Protocol):
|
||||
"""This class is used to detect if an object is stateful using `isinstance(obj, _Stateful)`."""
|
||||
|
||||
def state_dict(self) -> Dict[str, Any]:
|
||||
...
|
||||
|
@ -62,7 +62,7 @@ class _SupportsStateDict(Protocol):
|
|||
# Inferred from `torch.optim.lr_scheduler.pyi`
|
||||
# Missing attributes were added to improve typing
|
||||
@runtime_checkable
|
||||
class _LRScheduler(_SupportsStateDict, Protocol):
|
||||
class _LRScheduler(_Stateful, Protocol):
|
||||
optimizer: Optimizer
|
||||
|
||||
def __init__(self, optimizer: Optimizer, *args: Any, **kwargs: Any) -> None:
|
||||
|
@ -72,7 +72,7 @@ class _LRScheduler(_SupportsStateDict, Protocol):
|
|||
# Inferred from `torch.optim.lr_scheduler.pyi`
|
||||
# Missing attributes were added to improve typing
|
||||
@runtime_checkable
|
||||
class ReduceLROnPlateau(_SupportsStateDict, Protocol):
|
||||
class ReduceLROnPlateau(_Stateful, Protocol):
|
||||
in_cooldown: bool
|
||||
optimizer: Optimizer
|
||||
|
||||
|
|
|
@ -59,7 +59,6 @@ from pytorch_lightning.utilities.enums import _FaultTolerantMode, AutoRestartBat
|
|||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.fetching import DataFetcher
|
||||
from pytorch_lightning.utilities.imports import _fault_tolerant_training
|
||||
from pytorch_lightning.utilities.types import _SupportsStateDict
|
||||
from tests.helpers.boring_model import BoringModel, RandomDataset
|
||||
from tests.helpers.runif import RunIf
|
||||
|
||||
|
@ -1296,29 +1295,6 @@ def test_rotate_worker_indices():
|
|||
_rotate_worker_indices(state_dict, 2, 3)
|
||||
|
||||
|
||||
def test_supports_state_dict_protocol():
|
||||
class StatefulClass:
|
||||
def state_dict(self):
|
||||
pass
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
pass
|
||||
|
||||
assert isinstance(StatefulClass(), _SupportsStateDict)
|
||||
|
||||
class NotStatefulClass:
|
||||
def state_dict(self):
|
||||
pass
|
||||
|
||||
assert not isinstance(NotStatefulClass(), _SupportsStateDict)
|
||||
|
||||
class NotStateful2Class:
|
||||
def load_state_dict(self, state_dict):
|
||||
pass
|
||||
|
||||
assert not isinstance(NotStateful2Class(), _SupportsStateDict)
|
||||
|
||||
|
||||
def test_fault_tolerant_mode_enum():
|
||||
with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "0"}):
|
||||
assert _FaultTolerantMode.DISABLED == _FaultTolerantMode.detect_current_mode()
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from pytorch_lightning.utilities.types import _Stateful
|
||||
|
||||
|
||||
def test_stateful_protocol():
|
||||
class StatefulClass:
|
||||
def state_dict(self):
|
||||
pass
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
pass
|
||||
|
||||
assert isinstance(StatefulClass(), _Stateful)
|
||||
|
||||
class NotStatefulClass:
|
||||
def state_dict(self):
|
||||
pass
|
||||
|
||||
assert not isinstance(NotStatefulClass(), _Stateful)
|
||||
|
||||
class NotStateful2Class:
|
||||
def load_state_dict(self, state_dict):
|
||||
pass
|
||||
|
||||
assert not isinstance(NotStateful2Class(), _Stateful)
|
Loading…
Reference in New Issue