Rename `_SupportsStateDict` --> `_Stateful` Protocol (#11469)

Authored by jjenniferdai on 2022-02-02 14:45:59 -08:00; committed by GitHub
parent b8e360dafa
commit ec1379da2c
6 changed files with 49 additions and 43 deletions

CHANGELOG.md

@@ -16,7 +16,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fault Tolerant Manual
-    * Add `_SupportsStateDict` protocol to detect if classes are stateful ([#10646](https://github.com/PyTorchLightning/pytorch-lightning/pull/10646))
+    * Add `_Stateful` protocol to detect if classes are stateful ([#10646](https://github.com/PyTorchLightning/pytorch-lightning/pull/10646))
     * Add `_FaultTolerantMode` enum used to track different supported fault tolerant modes ([#10645](https://github.com/PyTorchLightning/pytorch-lightning/pull/10645))
     * Add a `_rotate_worker_indices` utility to reload the state according the latest worker ([#10647](https://github.com/PyTorchLightning/pytorch-lightning/pull/10647))
     * Add stateful workers ([#10674](https://github.com/PyTorchLightning/pytorch-lightning/pull/10674))

pytorch_lightning/core/optimizer.py

@@ -24,12 +24,7 @@ import pytorch_lightning as pl
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden
-from pytorch_lightning.utilities.types import (
-    _SupportsStateDict,
-    LRSchedulerConfig,
-    LRSchedulerTypeTuple,
-    ReduceLROnPlateau,
-)
+from pytorch_lightning.utilities.types import _Stateful, LRSchedulerConfig, LRSchedulerTypeTuple, ReduceLROnPlateau


 def do_nothing_closure() -> None:
@@ -338,7 +333,7 @@ def _configure_schedulers_manual_opt(schedulers: list) -> List[LRSchedulerConfig
 def _validate_scheduler_api(lr_scheduler_configs: List[LRSchedulerConfig], model: "pl.LightningModule") -> None:
     for config in lr_scheduler_configs:
         scheduler = config.scheduler
-        if not isinstance(scheduler, _SupportsStateDict):
+        if not isinstance(scheduler, _Stateful):
             raise TypeError(
                 f"The provided lr scheduler `{scheduler.__class__.__name__}` is invalid."
                 " It should have `state_dict` and `load_state_dict` methods defined."

pytorch_lightning/utilities/auto_restart.py

@@ -42,7 +42,7 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.distributed import _collect_states_on_rank_zero
 from pytorch_lightning.utilities.enums import _FaultTolerantMode, AutoRestartBatchKeys
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.types import _SupportsStateDict
+from pytorch_lightning.utilities.types import _Stateful


 class FastForwardSampler(Sampler):
@@ -581,17 +581,17 @@ def _reload_dataloader_state_dict_manual(dataloader: DataLoader, state_dict: Dic
     sampler_state = state_dict["state"][latest_worker_id].get("sampler_state", None)
     if sampler_state:
         # `sampler_state` keys contain all the DataLoader attribute names
-        # which matched `_SupportsStateDict` API interface while collecting the `state_dict`.
+        # which matched `_Stateful` API interface while collecting the `state_dict`.
         for dataloader_attr_name in sampler_state:
             obj = getattr(dataloader, dataloader_attr_name)
-            if not isinstance(obj, _SupportsStateDict):
+            if not isinstance(obj, _Stateful):
                 raise MisconfigurationException(
                     f"The DataLoader attribute {dataloader_attr_name}:{obj} should have a `load_state_dict` method."
                 )

             obj.load_state_dict(sampler_state[dataloader_attr_name])

-    if not isinstance(dataloader.dataset, _SupportsStateDict):
+    if not isinstance(dataloader.dataset, _Stateful):
         return

     dataset_state = {
@@ -645,9 +645,7 @@ class _StatefulDataLoaderIter:
     def _store_sampler_state(self) -> None:
         """This function is used to extract the sampler states if any."""
         sampler_state = {
-            k: v.state_dict()
-            for k, v in self._loader.__dict__.items()
-            if isinstance(v, _SupportsStateDict) and k != "dataset"
+            k: v.state_dict() for k, v in self._loader.__dict__.items() if isinstance(v, _Stateful) and k != "dataset"
        }

         self.__accumulate_state(sampler_state)
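
Note: the dict comprehension in `_store_sampler_state` above scans the DataLoader's attributes and snapshots every one that structurally matches `_Stateful`, excluding the dataset, which is captured separately. A sketch of the kind of sampler that check picks up (the `ResumableSampler` class is illustrative, not from this PR):

from typing import Any, Dict, Iterator

from torch.utils.data import Sampler


class ResumableSampler(Sampler):
    """Illustrative sampler whose position can be snapshotted and restored."""

    def __init__(self, length: int) -> None:
        self.length = length
        self.current_index = 0

    def __iter__(self) -> Iterator[int]:
        while self.current_index < self.length:
            yield self.current_index
            self.current_index += 1

    def __len__(self) -> int:
        return self.length

    def state_dict(self) -> Dict[str, Any]:
        # matched by `isinstance(v, _Stateful)` in `_store_sampler_state`
        return {"current_index": self.current_index}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # resume mid-epoch from the saved position
        self.current_index = state_dict["current_index"]

Passed as `DataLoader(dataset, sampler=ResumableSampler(len(dataset)))`, such a sampler's state would be collected into `sampler_state` and restored by `_reload_dataloader_state_dict_manual`.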

pytorch_lightning/utilities/types.py

@@ -49,8 +49,8 @@ _DEVICE = Union[torch.device, str, int]

 @runtime_checkable
-class _SupportsStateDict(Protocol):
-    """This class is used to detect if an object is stateful using `isinstance(obj, _SupportsStateDict)`."""
+class _Stateful(Protocol):
+    """This class is used to detect if an object is stateful using `isinstance(obj, _Stateful)`."""

     def state_dict(self) -> Dict[str, Any]:
         ...
@@ -62,7 +62,7 @@ class _SupportsStateDict(Protocol):
 # Inferred from `torch.optim.lr_scheduler.pyi`
 # Missing attributes were added to improve typing
 @runtime_checkable
-class _LRScheduler(_SupportsStateDict, Protocol):
+class _LRScheduler(_Stateful, Protocol):
     optimizer: Optimizer

     def __init__(self, optimizer: Optimizer, *args: Any, **kwargs: Any) -> None:
@@ -72,7 +72,7 @@ class _LRScheduler(_SupportsStateDict, Protocol):
 # Inferred from `torch.optim.lr_scheduler.pyi`
 # Missing attributes were added to improve typing
 @runtime_checkable
-class ReduceLROnPlateau(_SupportsStateDict, Protocol):
+class ReduceLROnPlateau(_Stateful, Protocol):
     in_cooldown: bool
     optimizer: Optimizer
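
Note: the protocol being renamed here relies on PEP 544 structural subtyping. With `@runtime_checkable`, `isinstance` checks only that both method names exist on the object; no inheritance is involved. A self-contained sketch of the same mechanism (the `Stateful` and `Counter` names are stand-ins, not Lightning code):

from typing import Any, Dict, Protocol, runtime_checkable


@runtime_checkable
class Stateful(Protocol):
    """Stand-in for `_Stateful`: matches any object with both methods."""

    def state_dict(self) -> Dict[str, Any]:
        ...

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        ...


class Counter:
    """Never inherits from `Stateful`, yet matches it structurally."""

    def __init__(self) -> None:
        self.count = 0

    def state_dict(self) -> Dict[str, Any]:
        return {"count": self.count}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.count = state_dict["count"]


assert isinstance(Counter(), Stateful)  # True: purely structural check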

tests/utilities/test_auto_restart.py

@@ -59,7 +59,6 @@ from pytorch_lightning.utilities.enums import _FaultTolerantMode, AutoRestartBat
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.fetching import DataFetcher
 from pytorch_lightning.utilities.imports import _fault_tolerant_training
-from pytorch_lightning.utilities.types import _SupportsStateDict
 from tests.helpers.boring_model import BoringModel, RandomDataset
 from tests.helpers.runif import RunIf
@@ -1296,29 +1295,6 @@ def test_rotate_worker_indices():
         _rotate_worker_indices(state_dict, 2, 3)


-def test_supports_state_dict_protocol():
-    class StatefulClass:
-        def state_dict(self):
-            pass
-
-        def load_state_dict(self, state_dict):
-            pass
-
-    assert isinstance(StatefulClass(), _SupportsStateDict)
-
-    class NotStatefulClass:
-        def state_dict(self):
-            pass
-
-    assert not isinstance(NotStatefulClass(), _SupportsStateDict)
-
-    class NotStateful2Class:
-        def load_state_dict(self, state_dict):
-            pass
-
-    assert not isinstance(NotStateful2Class(), _SupportsStateDict)
-
-
 def test_fault_tolerant_mode_enum():
     with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "0"}):
         assert _FaultTolerantMode.DISABLED == _FaultTolerantMode.detect_current_mode()

tests/utilities/test_types.py (new file; path inferred from content)

@@ -0,0 +1,37 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from pytorch_lightning.utilities.types import _Stateful
+
+
+def test_stateful_protocol():
+    class StatefulClass:
+        def state_dict(self):
+            pass
+
+        def load_state_dict(self, state_dict):
+            pass
+
+    assert isinstance(StatefulClass(), _Stateful)
+
+    class NotStatefulClass:
+        def state_dict(self):
+            pass
+
+    assert not isinstance(NotStatefulClass(), _Stateful)
+
+    class NotStateful2Class:
+        def load_state_dict(self, state_dict):
+            pass
+
+    assert not isinstance(NotStateful2Class(), _Stateful)
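
Note: one caveat the tests above implicitly rely on: `runtime_checkable` protocols verify only the presence of the method names, not their signatures. A sketch of that limitation (assumes only the post-rename import path from this PR):

from pytorch_lightning.utilities.types import _Stateful


class WrongSignature:
    def state_dict(self):
        pass

    def load_state_dict(self):  # wrong arity: missing the `state_dict` argument
        pass


# Still passes, because only method *names* are checked at runtime.
assert isinstance(WrongSignature(), _Stateful)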