Introduce `Stateful` DataModule (#11637)
parent 43a89eb132
commit 1203094a20
@@ -95,6 +95,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `rank_zero` module to centralize utilities ([#11747](https://github.com/PyTorchLightning/pytorch-lightning/pull/11747))
 
+- Added a `_Stateful` support for `LightningDataModule` ([#11637](https://github.com/PyTorchLightning/pytorch-lightning/pull/11637))
+
 ### Changed
 
 - Implemented a new native and rich format in `_print_results` method of the `EvaluationLoop` ([#11332](https://github.com/PyTorchLightning/pytorch-lightning/pull/11332))
@@ -13,7 +13,7 @@
 # limitations under the License.
 """LightningDataModule for loading DataLoaders with ease."""
 from argparse import ArgumentParser, Namespace
-from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union
+from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union
 
 from torch.utils.data import DataLoader, Dataset, IterableDataset
@@ -246,3 +246,19 @@ class LightningDataModule(CheckpointHooks, DataHooks, HyperparametersMixin):
         if test_dataset is not None:
             datamodule.test_dataloader = test_dataloader
         return datamodule
+
+    def state_dict(self) -> Dict[str, Any]:
+        """Called when saving a checkpoint, implement to generate and save datamodule state.
+
+        Returns:
+            A dictionary containing datamodule state.
+        """
+        return {}
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        """Called when loading a checkpoint, implement to reload datamodule state given datamodule state_dict.
+
+        Args:
+            state_dict: the datamodule state returned by ``state_dict``.
+        """
+        pass
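As a usage sketch (not part of this diff), a user-defined DataModule could override the two new hooks as below; `MyDataModule` and its `split_seed` attribute are hypothetical names chosen for illustration:

```python
from typing import Any, Dict

from pytorch_lightning import LightningDataModule


class MyDataModule(LightningDataModule):
    """Hypothetical DataModule whose split seed should survive a checkpoint round trip."""

    def __init__(self, split_seed: int = 42):
        super().__init__()
        self.split_seed = split_seed

    def state_dict(self) -> Dict[str, Any]:
        # Whatever is returned here is written into the checkpoint by the Trainer.
        return {"split_seed": self.split_seed}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Receives exactly the dictionary produced by ``state_dict`` when the checkpoint is restored.
        self.split_seed = state_dict["split_seed"]
```

Returning an empty dict (the default) keeps the datamodule out of the checkpoint entirely, which matches the `if datamodule_state_dict:` guard added in the connector below.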
@@ -157,6 +157,8 @@ class CheckpointConnector:
         datamodule = self.trainer.datamodule
         if datamodule is not None:
             datamodule.on_load_checkpoint(self._loaded_checkpoint)
+            if datamodule.__class__.__qualname__ in self._loaded_checkpoint:
+                datamodule.load_state_dict(self._loaded_checkpoint[datamodule.__class__.__qualname__])
 
     def restore_model(self) -> None:
         """Restores a model's weights from a PyTorch Lightning checkpoint.
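For readers skimming the connector change, the restore path above is roughly equivalent to this standalone sketch; the helper name `_restore_datamodule_state` is invented here and is not part of the connector's API:

```python
from typing import Any, Dict

import pytorch_lightning as pl


def _restore_datamodule_state(datamodule: "pl.LightningDataModule", loaded_checkpoint: Dict[str, Any]) -> None:
    # The legacy hook still fires first, preserving backwards compatibility.
    datamodule.on_load_checkpoint(loaded_checkpoint)
    # The new hook only fires when the datamodule actually stored state under its qualified class name.
    key = type(datamodule).__qualname__
    if key in loaded_checkpoint:
        datamodule.load_state_dict(loaded_checkpoint[key])
```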
@@ -324,7 +326,7 @@ class CheckpointConnector:
                 CHECKPOINT_HYPER_PARAMS_KEY:
                 CHECKPOINT_HYPER_PARAMS_TYPE:
                 something_cool_i_want_to_save: anything you define through model.on_save_checkpoint
-                LightningDataModule.__class__.__name__: pl DataModule's state
+                LightningDataModule.__class__.__qualname__: pl DataModule's state
             }
         """
@@ -378,10 +380,17 @@ class CheckpointConnector:
                 else:
                     checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = dict(model.hparams)
 
-        # give the model a chance to dump a few things
+        # dump stateful datamodule
+        datamodule = self.trainer.datamodule
+        if datamodule is not None:
+            datamodule_state_dict = datamodule.state_dict()
+            if datamodule_state_dict:
+                checkpoint[datamodule.__class__.__qualname__] = datamodule_state_dict
+
+        # on_save_checkpoint hooks
         model.on_save_checkpoint(checkpoint)
-        if self.trainer.datamodule is not None:
-            self.trainer.datamodule.on_save_checkpoint(checkpoint)
+        if datamodule is not None:
+            datamodule.on_save_checkpoint(checkpoint)
 
         # TODO: remove this in v1.8.
         environment = self.trainer._accelerator_connector.cluster_environment
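Put together with the save path above, the written checkpoint might look roughly like the dictionary below; every key except the datamodule entry is illustrative and depends on the rest of the Trainer state:

```python
# Sketch of the checkpoint layout after this change, assuming the hypothetical
# MyDataModule from the earlier example; surrounding keys are shown for orientation only.
checkpoint = {
    "epoch": 1,
    "global_step": 64,
    "state_dict": {},  # model weights, omitted here
    "MyDataModule": {"split_seed": 42},  # datamodule.state_dict(), stored under __qualname__
}

# On restore, the connector hands the nested dict straight back to load_state_dict:
assert checkpoint["MyDataModule"] == {"split_seed": 42}
```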
@@ -196,11 +196,18 @@ def test_dm_checkpoint_save_and_load(tmpdir):
             return out
 
     class CustomBoringDataModule(BoringDataModule):
+        def state_dict(self) -> Dict[str, Any]:
+            return {"my": "state_dict"}
+
+        def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+            self.my_state_dict = state_dict
+
         def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-            checkpoint[self.__class__.__name__] = self.__class__.__name__
+            checkpoint[self.__class__.__qualname__].update({"on_save": "update"})
 
         def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-            self.checkpoint_state = checkpoint.get(self.__class__.__name__)
+            self.checkpoint_state = checkpoint.get(self.__class__.__qualname__).copy()
+            checkpoint[self.__class__.__qualname__].pop("on_save")
 
     reset_seed()
     dm = CustomBoringDataModule()
@@ -220,14 +227,14 @@
     assert trainer.state.finished, f"Training failed with {trainer.state}"
     checkpoint_path = list(trainer.checkpoint_callback.best_k_models.keys())[0]
     checkpoint = torch.load(checkpoint_path)
-    assert dm.__class__.__name__ in checkpoint
-    assert checkpoint[dm.__class__.__name__] == dm.__class__.__name__
+    assert dm.__class__.__qualname__ in checkpoint
+    assert checkpoint[dm.__class__.__qualname__] == {"my": "state_dict", "on_save": "update"}
 
     for trainer_fn in TrainerFn:
         trainer.state.fn = trainer_fn
-        with mock.patch.object(dm, "on_load_checkpoint") as dm_mock:
-            trainer._restore_modules_and_callbacks(checkpoint_path)
-            dm_mock.assert_called_once()
+        trainer._restore_modules_and_callbacks(checkpoint_path)
+        assert dm.checkpoint_state == {"my": "state_dict", "on_save": "update"}
+        assert dm.my_state_dict == {"my": "state_dict"}
 
 
 def test_full_loop(tmpdir):
@@ -865,6 +865,7 @@ def test_trainer_datamodule_hook_system(tmpdir):
         dict(name="setup", kwargs=dict(stage="fit")),
         dict(name="val_dataloader"),
         dict(name="train_dataloader"),
+        dict(name="state_dict"),
         dict(name="on_save_checkpoint", args=(ANY,)),
         dict(name="teardown", kwargs=dict(stage="fit")),
     ]