Fault Tolerant Manual: Add is_obj_stateful utility (#10646)
This commit is contained in:
parent
cd7b4342f6
commit
6acfef680f
|
@ -13,6 +13,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
|
||||
- Fault Tolerant Manual
|
||||
* Add `_SupportsStateDict` protocol to detect if classes are stateful ([#10646](https://github.com/PyTorchLightning/pytorch-lightning/issues/10646))
|
||||
* Add `_FaultTolerantMode` enum used to track different supported fault tolerant modes ([#10645](https://github.com/PyTorchLightning/pytorch-lightning/issues/10645))
|
||||
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@ import numpy as np
|
|||
import torch
|
||||
from torch.utils.data import Dataset, get_worker_info, Sampler
|
||||
from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader, IterableDataset
|
||||
from typing_extensions import Protocol, runtime_checkable
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.utilities.enums import AutoRestartBatchKeys
|
||||
|
@ -570,3 +571,14 @@ def reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, A
|
|||
|
||||
else:
|
||||
raise MisconfigurationException("This shouldn't happen. Please, open an issue on PyTorch Lightning Github.")
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class _SupportsStateDict(Protocol):
|
||||
"""This class is used to detect if an object is stateful using `isinstance(obj, _SupportsStateDict)`."""
|
||||
|
||||
def state_dict(self) -> Dict[str, Any]:
|
||||
...
|
||||
|
||||
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
|
||||
...
|
||||
|
|
|
@ -40,6 +40,7 @@ from pytorch_lightning.utilities.auto_restart import (
|
|||
_add_capture_metadata_collate,
|
||||
_dataloader_load_state_dict,
|
||||
_dataloader_to_state_dict,
|
||||
_SupportsStateDict,
|
||||
CaptureIterableDataset,
|
||||
CaptureMapDataset,
|
||||
FastForwardSampler,
|
||||
|
@ -1195,6 +1196,29 @@ def test_auto_restart_under_signal(on_last_batch, val_check_interval, failure_on
|
|||
assert "dataloader_state_dict" in state_dict
|
||||
|
||||
|
||||
def test_supports_state_dict_protocol():
|
||||
class StatefulClass:
|
||||
def state_dict(self):
|
||||
pass
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
pass
|
||||
|
||||
assert isinstance(StatefulClass(), _SupportsStateDict)
|
||||
|
||||
class NotStatefulClass:
|
||||
def state_dict(self):
|
||||
pass
|
||||
|
||||
assert not isinstance(NotStatefulClass(), _SupportsStateDict)
|
||||
|
||||
class NotStateful2Class:
|
||||
def load_state_dict(self, state_dict):
|
||||
pass
|
||||
|
||||
assert not isinstance(NotStateful2Class(), _SupportsStateDict)
|
||||
|
||||
|
||||
def test_fault_tolerant_mode_enum():
|
||||
with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "0"}):
|
||||
assert _FaultTolerantMode.DISABLED == _FaultTolerantMode.detect_current_mode()
|
||||
|
|
Loading…
Reference in New Issue