Introduce `Stateful` DataModule (#11637)

jjenniferdai 2022-02-07 12:13:24 -08:00 committed by GitHub
parent 43a89eb132
commit 1203094a20
5 changed files with 48 additions and 12 deletions


@@ -95,6 +95,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `rank_zero` module to centralize utilities ([#11747](https://github.com/PyTorchLightning/pytorch-lightning/pull/11747))
- Added a `_Stateful` support for `LightningDataModule` ([#11637](https://github.com/PyTorchLightning/pytorch-lightning/pull/11637))
### Changed
- Implemented a new native and rich format in `_print_results` method of the `EvaluationLoop` ([#11332](https://github.com/PyTorchLightning/pytorch-lightning/pull/11332))


@@ -13,7 +13,7 @@
# limitations under the License.
"""LightningDataModule for loading DataLoaders with ease."""
from argparse import ArgumentParser, Namespace
from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union
from torch.utils.data import DataLoader, Dataset, IterableDataset
@@ -246,3 +246,19 @@ class LightningDataModule(CheckpointHooks, DataHooks, HyperparametersMixin):
        if test_dataset is not None:
            datamodule.test_dataloader = test_dataloader
        return datamodule

    def state_dict(self) -> Dict[str, Any]:
        """Called when saving a checkpoint, implement to generate and save datamodule state.

        Returns:
            A dictionary containing datamodule state.
        """
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Called when loading a checkpoint, implement to reload datamodule state given datamodule state_dict.

        Args:
            state_dict: the datamodule state returned by ``state_dict``.
        """
        pass
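
For reference, a minimal sketch of how a user-defined DataModule might implement the two new hooks; the class name and the `current_fold` field are illustrative, not part of this commit:

```python
from typing import Any, Dict

from pytorch_lightning import LightningDataModule


class MyDataModule(LightningDataModule):
    def __init__(self) -> None:
        super().__init__()
        self.current_fold = 0  # example of state worth carrying across restarts

    def state_dict(self) -> Dict[str, Any]:
        # Whatever is returned here is written into the checkpoint,
        # keyed by this class' qualified name.
        return {"current_fold": self.current_fold}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Called with that same dict when the checkpoint is restored.
        self.current_fold = state_dict["current_fold"]
```

Because `state_dict` returns a non-empty dict, the checkpoint connector changes below save it automatically and feed it back through `load_state_dict` on restore.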


@@ -157,6 +157,8 @@ class CheckpointConnector:
        datamodule = self.trainer.datamodule
        if datamodule is not None:
            datamodule.on_load_checkpoint(self._loaded_checkpoint)
            if datamodule.__class__.__qualname__ in self._loaded_checkpoint:
                datamodule.load_state_dict(self._loaded_checkpoint[datamodule.__class__.__qualname__])

    def restore_model(self) -> None:
        """Restores a model's weights from a PyTorch Lightning checkpoint.
@@ -324,7 +326,7 @@ class CheckpointConnector:
                CHECKPOINT_HYPER_PARAMS_KEY:
                CHECKPOINT_HYPER_PARAMS_TYPE:
                something_cool_i_want_to_save: anything you define through model.on_save_checkpoint
                LightningDataModule.__class__.__name__: pl DataModule's state
                LightningDataModule.__class__.__qualname__: pl DataModule's state
            }
        """
@@ -378,10 +380,17 @@ class CheckpointConnector:
            else:
                checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = dict(model.hparams)

        # give the model a chance to dump a few things
        # dump stateful datamodule
        datamodule = self.trainer.datamodule
        if datamodule is not None:
            datamodule_state_dict = datamodule.state_dict()
            if datamodule_state_dict:
                checkpoint[datamodule.__class__.__qualname__] = datamodule_state_dict

        # on_save_checkpoint hooks
        model.on_save_checkpoint(checkpoint)
        if self.trainer.datamodule is not None:
            self.trainer.datamodule.on_save_checkpoint(checkpoint)
        if datamodule is not None:
            datamodule.on_save_checkpoint(checkpoint)

        # TODO: remove this in v1.8.
        environment = self.trainer._accelerator_connector.cluster_environment
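
To make the save/restore round trip concrete, here is a hedged sketch that mirrors the connector logic above using only the public hooks; it reuses the hypothetical `MyDataModule` from the earlier sketch rather than calling the private connector:

```python
from typing import Any, Dict

dm = MyDataModule()
dm.current_fold = 3

# Save side: dump_checkpoint stores a non-empty datamodule state_dict
# under the datamodule's qualified class name.
checkpoint: Dict[str, Any] = {}
datamodule_state = dm.state_dict()
if datamodule_state:
    checkpoint[dm.__class__.__qualname__] = datamodule_state

# Restore side: restore_datamodule only calls load_state_dict when that
# key is present, so datamodules with an empty state_dict are untouched.
restored = MyDataModule()
if restored.__class__.__qualname__ in checkpoint:
    restored.load_state_dict(checkpoint[restored.__class__.__qualname__])

assert restored.current_fold == 3
```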


@@ -196,11 +196,18 @@ def test_dm_checkpoint_save_and_load(tmpdir):
            return out

    class CustomBoringDataModule(BoringDataModule):
        def state_dict(self) -> Dict[str, Any]:
            return {"my": "state_dict"}

        def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
            self.my_state_dict = state_dict

        def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
            checkpoint[self.__class__.__name__] = self.__class__.__name__
            checkpoint[self.__class__.__qualname__].update({"on_save": "update"})

        def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
            self.checkpoint_state = checkpoint.get(self.__class__.__name__)
            self.checkpoint_state = checkpoint.get(self.__class__.__qualname__).copy()
            checkpoint[self.__class__.__qualname__].pop("on_save")

    reset_seed()
    dm = CustomBoringDataModule()
@@ -220,14 +227,14 @@ def test_dm_checkpoint_save_and_load(tmpdir):
    assert trainer.state.finished, f"Training failed with {trainer.state}"
    checkpoint_path = list(trainer.checkpoint_callback.best_k_models.keys())[0]
    checkpoint = torch.load(checkpoint_path)
    assert dm.__class__.__name__ in checkpoint
    assert checkpoint[dm.__class__.__name__] == dm.__class__.__name__
    assert dm.__class__.__qualname__ in checkpoint
    assert checkpoint[dm.__class__.__qualname__] == {"my": "state_dict", "on_save": "update"}

    for trainer_fn in TrainerFn:
        trainer.state.fn = trainer_fn
        with mock.patch.object(dm, "on_load_checkpoint") as dm_mock:
            trainer._restore_modules_and_callbacks(checkpoint_path)
            dm_mock.assert_called_once()
        trainer._restore_modules_and_callbacks(checkpoint_path)
        assert dm.checkpoint_state == {"my": "state_dict", "on_save": "update"}
        assert dm.my_state_dict == {"my": "state_dict"}


def test_full_loop(tmpdir):
def test_full_loop(tmpdir):


@@ -865,6 +865,7 @@ def test_trainer_datamodule_hook_system(tmpdir):
        dict(name="setup", kwargs=dict(stage="fit")),
        dict(name="val_dataloader"),
        dict(name="train_dataloader"),
        dict(name="state_dict"),
        dict(name="on_save_checkpoint", args=(ANY,)),
        dict(name="teardown", kwargs=dict(stage="fit")),
    ]