Fault Tolerant Manual: Add is_obj_stateful utility (#10646)

This commit is contained in:
thomas chaton 2021-11-22 18:48:32 +00:00 committed by GitHub
parent cd7b4342f6
commit 6acfef680f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 37 additions and 0 deletions

View File

@ -13,6 +13,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fault Tolerant Manual
* Add `_SupportsStateDict` protocol to detect if classes are stateful ([#10646](https://github.com/PyTorchLightning/pytorch-lightning/issues/10646))
* Add `_FaultTolerantMode` enum used to track different supported fault tolerant modes ([#10645](https://github.com/PyTorchLightning/pytorch-lightning/issues/10645))

View File

@ -22,6 +22,7 @@ import numpy as np
import torch
from torch.utils.data import Dataset, get_worker_info, Sampler
from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader, IterableDataset
from typing_extensions import Protocol, runtime_checkable
import pytorch_lightning as pl
from pytorch_lightning.utilities.enums import AutoRestartBatchKeys
@ -570,3 +571,14 @@ def reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, A
else:
raise MisconfigurationException("This shouldn't happen. Please, open an issue on PyTorch Lightning Github.")
@runtime_checkable
class _SupportsStateDict(Protocol):
"""This class is used to detect if an object is stateful using `isinstance(obj, _SupportsStateDict)`."""
def state_dict(self) -> Dict[str, Any]:
...
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
...

View File

@ -40,6 +40,7 @@ from pytorch_lightning.utilities.auto_restart import (
_add_capture_metadata_collate,
_dataloader_load_state_dict,
_dataloader_to_state_dict,
_SupportsStateDict,
CaptureIterableDataset,
CaptureMapDataset,
FastForwardSampler,
@ -1195,6 +1196,29 @@ def test_auto_restart_under_signal(on_last_batch, val_check_interval, failure_on
assert "dataloader_state_dict" in state_dict
def test_supports_state_dict_protocol():
class StatefulClass:
def state_dict(self):
pass
def load_state_dict(self, state_dict):
pass
assert isinstance(StatefulClass(), _SupportsStateDict)
class NotStatefulClass:
def state_dict(self):
pass
assert not isinstance(NotStatefulClass(), _SupportsStateDict)
class NotStateful2Class:
def load_state_dict(self, state_dict):
pass
assert not isinstance(NotStateful2Class(), _SupportsStateDict)
def test_fault_tolerant_mode_enum():
with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "0"}):
assert _FaultTolerantMode.DISABLED == _FaultTolerantMode.detect_current_mode()