From 1203094a201bd38f0b8b77d93bc39fc95f06d8ae Mon Sep 17 00:00:00 2001
From: jjenniferdai <89552168+jjenniferdai@users.noreply.github.com>
Date: Mon, 7 Feb 2022 12:13:24 -0800
Subject: [PATCH] Introduce `Stateful` DataModule (#11637)

---
 CHANGELOG.md                                  |  3 +++
 pytorch_lightning/core/datamodule.py          | 18 +++++++++++++++-
 .../connectors/checkpoint_connector.py        | 17 +++++++++++----
 tests/core/test_datamodules.py                | 21 ++++++++++++-------
 tests/models/test_hooks.py                    |  1 +
 5 files changed, 48 insertions(+), 12 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index e2e23e9ace..38d1c7c3f2 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -95,6 +95,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added `rank_zero` module to centralize utilities ([#11747](https://github.com/PyTorchLightning/pytorch-lightning/pull/11747))
 
+- Added `_Stateful` support for `LightningDataModule` ([#11637](https://github.com/PyTorchLightning/pytorch-lightning/pull/11637))
+
+
 ### Changed
 
 - Implemented a new native and rich format in `_print_results` method of the `EvaluationLoop` ([#11332](https://github.com/PyTorchLightning/pytorch-lightning/pull/11332))
diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py
index 15e93d44b1..02011fd7e9 100644
--- a/pytorch_lightning/core/datamodule.py
+++ b/pytorch_lightning/core/datamodule.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 """LightningDataModule for loading DataLoaders with ease."""
 from argparse import ArgumentParser, Namespace
-from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union
+from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union
 
 from torch.utils.data import DataLoader, Dataset, IterableDataset
 
@@ -246,3 +246,19 @@ class LightningDataModule(CheckpointHooks, DataHooks, HyperparametersMixin):
         if test_dataset is not None:
             datamodule.test_dataloader = test_dataloader
         return datamodule
+
+    def state_dict(self) -> Dict[str, Any]:
+        """Called when saving a checkpoint. Implement to generate and save the datamodule state.
+
+        Returns:
+            A dictionary containing datamodule state.
+        """
+        return {}
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        """Called when loading a checkpoint. Implement to reload the datamodule state given the datamodule ``state_dict``.
+
+        Args:
+            state_dict: the datamodule state returned by ``state_dict``.
+        """
+        pass
diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index 4e835a415f..3560678baa 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -157,6 +157,8 @@ class CheckpointConnector:
         datamodule = self.trainer.datamodule
         if datamodule is not None:
             datamodule.on_load_checkpoint(self._loaded_checkpoint)
+            if datamodule.__class__.__qualname__ in self._loaded_checkpoint:
+                datamodule.load_state_dict(self._loaded_checkpoint[datamodule.__class__.__qualname__])
 
     def restore_model(self) -> None:
         """Restores a model's weights from a PyTorch Lightning checkpoint.
@@ -324,7 +326,7 @@ class CheckpointConnector:
             CHECKPOINT_HYPER_PARAMS_KEY:
             CHECKPOINT_HYPER_PARAMS_TYPE:
             something_cool_i_want_to_save: anything you define through model.on_save_checkpoint
-            LightningDataModule.__class__.__name__: pl DataModule's state
+            LightningDataModule.__class__.__qualname__: pl DataModule's state
         }
     """
 
@@ -378,10 +380,17 @@ class CheckpointConnector:
         else:
             checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = dict(model.hparams)
 
-        # give the model a chance to dump a few things
+        # dump stateful datamodule
+        datamodule = self.trainer.datamodule
+        if datamodule is not None:
+            datamodule_state_dict = datamodule.state_dict()
+            if datamodule_state_dict:
+                checkpoint[datamodule.__class__.__qualname__] = datamodule_state_dict
+
+        # on_save_checkpoint hooks
         model.on_save_checkpoint(checkpoint)
-        if self.trainer.datamodule is not None:
-            self.trainer.datamodule.on_save_checkpoint(checkpoint)
+        if datamodule is not None:
+            datamodule.on_save_checkpoint(checkpoint)
 
         # TODO: remove this in v1.8.
         environment = self.trainer._accelerator_connector.cluster_environment
diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py
index 738083091d..22878d8340 100644
--- a/tests/core/test_datamodules.py
+++ b/tests/core/test_datamodules.py
@@ -196,11 +196,18 @@ def test_dm_checkpoint_save_and_load(tmpdir):
             return out
 
     class CustomBoringDataModule(BoringDataModule):
+        def state_dict(self) -> Dict[str, Any]:
+            return {"my": "state_dict"}
+
+        def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+            self.my_state_dict = state_dict
+
         def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-            checkpoint[self.__class__.__name__] = self.__class__.__name__
+            checkpoint[self.__class__.__qualname__].update({"on_save": "update"})
 
         def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-            self.checkpoint_state = checkpoint.get(self.__class__.__name__)
+            self.checkpoint_state = checkpoint.get(self.__class__.__qualname__).copy()
+            checkpoint[self.__class__.__qualname__].pop("on_save")
 
     reset_seed()
     dm = CustomBoringDataModule()
@@ -220,14 +227,14 @@ def test_dm_checkpoint_save_and_load(tmpdir):
     assert trainer.state.finished, f"Training failed with {trainer.state}"
     checkpoint_path = list(trainer.checkpoint_callback.best_k_models.keys())[0]
     checkpoint = torch.load(checkpoint_path)
-    assert dm.__class__.__name__ in checkpoint
-    assert checkpoint[dm.__class__.__name__] == dm.__class__.__name__
+    assert dm.__class__.__qualname__ in checkpoint
+    assert checkpoint[dm.__class__.__qualname__] == {"my": "state_dict", "on_save": "update"}
 
     for trainer_fn in TrainerFn:
         trainer.state.fn = trainer_fn
-        with mock.patch.object(dm, "on_load_checkpoint") as dm_mock:
-            trainer._restore_modules_and_callbacks(checkpoint_path)
-            dm_mock.assert_called_once()
+        trainer._restore_modules_and_callbacks(checkpoint_path)
+        assert dm.checkpoint_state == {"my": "state_dict", "on_save": "update"}
+        assert dm.my_state_dict == {"my": "state_dict"}
 
 
 def test_full_loop(tmpdir):
diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py
index e9ea468c4a..3556a18f86 100644
--- a/tests/models/test_hooks.py
+++ b/tests/models/test_hooks.py
@@ -865,6 +865,7 @@ def test_trainer_datamodule_hook_system(tmpdir):
         dict(name="setup", kwargs=dict(stage="fit")),
         dict(name="val_dataloader"),
         dict(name="train_dataloader"),
+        dict(name="state_dict"),
         dict(name="on_save_checkpoint", args=(ANY,)),
         dict(name="teardown", kwargs=dict(stage="fit")),
     ]
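
Usage sketch (editor's note, not part of the upstream patch): the new hooks give `LightningDataModule` the same `state_dict`/`load_state_dict` contract as `LightningModule`. Whatever dictionary `state_dict()` returns is written into the checkpoint under the DataModule's `__class__.__qualname__` (an empty dict adds no key at all), and that same dictionary is handed back to `load_state_dict()` when the checkpoint is restored. The `StreamingDataModule` below, its `current_index` attribute, and the `RangeDataset` helper are hypothetical names used only to illustrate the contract:

    from typing import Any, Dict

    import torch
    from pytorch_lightning import LightningDataModule
    from torch.utils.data import DataLoader, Dataset


    class RangeDataset(Dataset):
        """Toy dataset standing in for a real (e.g. streaming) data source."""

        def __len__(self) -> int:
            return 64

        def __getitem__(self, idx: int) -> torch.Tensor:
            return torch.tensor([idx], dtype=torch.float32)


    class StreamingDataModule(LightningDataModule):
        """Hypothetical DataModule that checkpoints how far it has read."""

        def __init__(self) -> None:
            super().__init__()
            self.current_index = 0  # advanced elsewhere as data is consumed

        def train_dataloader(self) -> DataLoader:
            return DataLoader(RangeDataset(), batch_size=8)

        def state_dict(self) -> Dict[str, Any]:
            # Saved by dump_checkpoint under self.__class__.__qualname__;
            # returning {} would add no checkpoint key for this DataModule.
            return {"current_index": self.current_index}

        def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
            # Receives exactly the dict produced by state_dict() above.
            self.current_index = state_dict["current_index"]

Because the connector only stores a non-empty `state_dict()` and only calls `load_state_dict()` when the matching key is present in the loaded checkpoint, DataModules that keep the default implementations add and require no extra checkpoint entries, which keeps the change backward compatible.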