diff --git a/CHANGELOG.md b/CHANGELOG.md index 27761dd11a..7c5722e6eb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Disabled optimizers setup during testing ([#3059](https://github.com/PyTorchLightning/pytorch-lightning/pull/3059)) +- Added support for datamodules to save and load checkpoints when training ([#3563]https://github.com/PyTorchLightning/pytorch-lightning/pull/3563) + ### Changed - Changed `LearningRateLogger` to `LearningRateMonitor` ([#3251](https://github.com/PyTorchLightning/pytorch-lightning/pull/3251)) diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index 2234fa256e..99a95f6598 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -19,7 +19,7 @@ from argparse import ArgumentParser, Namespace from typing import Any, List, Optional, Tuple, Union import torch -from pytorch_lightning.core.hooks import DataHooks +from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks from pytorch_lightning.utilities import parsing, rank_zero_only from torch.utils.data import DataLoader @@ -92,7 +92,7 @@ def track_data_hook_calls(fn): return wrapped_fn -class LightningDataModule(DataHooks, metaclass=_DataModuleWrapper): +class LightningDataModule(DataHooks, CheckpointHooks, metaclass=_DataModuleWrapper): """ A DataModule standardizes the training, val, test splits, data preparation and transforms. The main advantage is consistent data splits, data preparation and transforms across models. diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index b4cfd50819..5245d90606 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Union +from typing import Any, Dict, List, Union import torch from pytorch_lightning.utilities import AMPType, move_data_to_device, rank_zero_warn @@ -596,3 +596,48 @@ class DataHooks: - :func:`~pytorch_lightning.utilities.apply_func.apply_to_collection` """ return move_data_to_device(batch, device) + + +class CheckpointHooks: + def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + r""" + Called by Lightning to restore your model. + If you saved something with :meth:`on_save_checkpoint` this is your chance to restore this. + + Args: + checkpoint: Loaded checkpoint + + + Example: + .. code-block:: python + + def on_load_checkpoint(self, checkpoint): + # 99% of the time you don't need to implement this method + self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save'] + + Note: + Lightning auto-restores global step, epoch, and train state including amp scaling. + There is no need for you to restore anything regarding training. + """ + + def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + r""" + Called by Lightning when saving a checkpoint to give you a chance to store anything + else you might want to save. + + Args: + checkpoint: Checkpoint to be saved + + Example: + .. code-block:: python + + def on_save_checkpoint(self, checkpoint): + # 99% of use cases you don't need to implement this method + checkpoint['something_cool_i_want_to_save'] = my_cool_pickable_object + + Note: + Lightning saves all aspects of training (epoch, global step, etc...) + including amp scaling. + There is no need for you to store anything about training. + + """ diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 8966c3677e..a62a051da9 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -25,7 +25,7 @@ import torch import torch.distributed as torch_distrib from pytorch_lightning import _logger as log from pytorch_lightning.core.grads import GradInformation -from pytorch_lightning.core.hooks import DataHooks, ModelHooks +from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES, ModelIO from pytorch_lightning.core.step_result import EvalResult, TrainResult @@ -52,12 +52,25 @@ else: XLA_AVAILABLE = True -class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, ModelHooks, DataHooks, Module): +class LightningModule( + ABC, + DeviceDtypeModuleMixin, + GradInformation, + ModelIO, + ModelHooks, + DataHooks, + CheckpointHooks, + Module, +): # Below is for property support of JIT in PyTorch 1.7 # since none of them is important when using JIT, we are going to ignore them. # https://github.com/pytorch/pytorch/commit/e7d782e724c76bb0572023d52ee7438a40a7a262#diff-ff4f8670281cd1eb4e09329cc1dcb43b - __ignored_properties__ = ['datamodule', 'example_input_array', - 'hparams', 'on_gpu'] + DeviceDtypeModuleMixin.__ignored_properties__ + __ignored_properties__ = [ + "datamodule", + "example_input_array", + "hparams", + "on_gpu", + ] + DeviceDtypeModuleMixin.__ignored_properties__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -1400,49 +1413,6 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod self.train() - def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - r""" - Called by Lightning to restore your model. - If you saved something with :meth:`on_save_checkpoint` this is your chance to restore this. - - Args: - checkpoint: Loaded checkpoint - - - Example: - .. code-block:: python - - def on_load_checkpoint(self, checkpoint): - # 99% of the time you don't need to implement this method - self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save'] - - Note: - Lightning auto-restores global step, epoch, and train state including amp scaling. - There is no need for you to restore anything regarding training. - """ - - def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - r""" - Called by Lightning when saving a checkpoint to give you a chance to store anything - else you might want to save. - - Args: - checkpoint: Checkpoint to be saved - - Example: - .. code-block:: python - - def on_save_checkpoint(self, checkpoint): - # 99% of use cases you don't need to implement this method - checkpoint['something_cool_i_want_to_save'] = my_cool_pickable_object - - Note: - Lightning saves all aspects of training (epoch, global step, etc...) - including amp scaling. - There is no need for you to store anything about training. - - """ - def get_progress_bar_dict(self) -> Dict[str, Union[int, str]]: r""" Implement this to override the default items displayed in the progress bar. diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 7ebd84a428..12ab5eeb23 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -104,6 +104,9 @@ class CheckpointConnector: # load the state_dict on the model automatically model.load_state_dict(checkpoint['state_dict']) + # give the datamodule a chance to load something + if self.trainer.datamodule is not None: + self.trainer.datamodule.on_load_checkpoint(checkpoint) # give model a chance to load something model.on_load_checkpoint(checkpoint) @@ -294,6 +297,8 @@ class CheckpointConnector: # give the model a chance to add a few things model.on_save_checkpoint(checkpoint) + if self.trainer.datamodule is not None: + self.trainer.datamodule.on_save_checkpoint(checkpoint) return checkpoint diff --git a/tests/base/datamodules.py b/tests/base/datamodules.py index 234d27b721..3ea0615db6 100644 --- a/tests/base/datamodules.py +++ b/tests/base/datamodules.py @@ -1,34 +1,39 @@ import os -from torch.utils.data import random_split, DataLoader +from typing import Any, Dict, Optional from pytorch_lightning.core.datamodule import LightningDataModule -from tests.base.datasets import TrialMNIST, MNIST +from tests.base.datasets import MNIST, TrialMNIST +from torch.utils.data import DataLoader, random_split from torch.utils.data.distributed import DistributedSampler class TrialMNISTDataModule(LightningDataModule): - - def __init__(self, data_dir: str = './'): + def __init__(self, data_dir: str = "./"): super().__init__() self.data_dir = data_dir self.non_picklable = None + self.checkpoint_state: Optional[str] = None def prepare_data(self): TrialMNIST(self.data_dir, train=True, download=True) TrialMNIST(self.data_dir, train=False, download=True) - def setup(self, stage: str = None): + def setup(self, stage: Optional[str] = None): - if stage == 'fit' or stage is None: - mnist_full = TrialMNIST(root=self.data_dir, train=True, num_samples=64, download=True) + if stage == "fit" or stage is None: + mnist_full = TrialMNIST( + root=self.data_dir, train=True, num_samples=64, download=True + ) self.mnist_train, self.mnist_val = random_split(mnist_full, [128, 64]) self.dims = self.mnist_train[0][0].shape - if stage == 'test' or stage is None: - self.mnist_test = TrialMNIST(root=self.data_dir, train=False, num_samples=64, download=True) - self.dims = getattr(self, 'dims', self.mnist_test[0][0].shape) + if stage == "test" or stage is None: + self.mnist_test = TrialMNIST( + root=self.data_dir, train=False, num_samples=64, download=True + ) + self.dims = getattr(self, "dims", self.mnist_test[0][0].shape) - self.non_picklable = lambda x: x**2 + self.non_picklable = lambda x: x ** 2 def train_dataloader(self): return DataLoader(self.mnist_train, batch_size=32) @@ -39,10 +44,16 @@ class TrialMNISTDataModule(LightningDataModule): def test_dataloader(self): return DataLoader(self.mnist_test, batch_size=32) + def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + checkpoint[self.__class__.__name__] = self.__class__.__name__ + + def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + self.checkpoint_state = checkpoint.get(self.__class__.__name__) + class MNISTDataModule(LightningDataModule): def __init__( - self, data_dir: str = './', batch_size: int = 32, dist_sampler: bool = False + self, data_dir: str = "./", batch_size: int = 32, dist_sampler: bool = False ) -> None: super().__init__() @@ -60,16 +71,20 @@ class MNISTDataModule(LightningDataModule): MNIST(self.data_dir, train=True, download=True, normalize=(0.1307, 0.3081)) MNIST(self.data_dir, train=False, download=True, normalize=(0.1307, 0.3081)) - def setup(self, stage: str = None): + def setup(self, stage: Optional[str] = None): # Assign train/val datasets for use in dataloaders # TODO: need to split using random_split once updated to torch >= 1.6 - if stage == 'fit' or stage is None: - self.mnist_train = MNIST(self.data_dir, train=True, normalize=(0.1307, 0.3081)) + if stage == "fit" or stage is None: + self.mnist_train = MNIST( + self.data_dir, train=True, normalize=(0.1307, 0.3081) + ) # Assign test dataset for use in dataloader(s) - if stage == 'test' or stage is None: - self.mnist_test = MNIST(self.data_dir, train=False, normalize=(0.1307, 0.3081)) + if stage == "test" or stage is None: + self.mnist_test = MNIST( + self.data_dir, train=False, normalize=(0.1307, 0.3081) + ) def train_dataloader(self): dist_sampler = None @@ -77,7 +92,10 @@ class MNISTDataModule(LightningDataModule): dist_sampler = DistributedSampler(self.mnist_train, shuffle=False) return DataLoader( - self.mnist_train, batch_size=self.batch_size, sampler=dist_sampler, shuffle=False + self.mnist_train, + batch_size=self.batch_size, + sampler=dist_sampler, + shuffle=False, ) def test_dataloader(self): diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 0b08cd6718..4e4476955a 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -216,6 +216,25 @@ def test_train_val_loop_only(tmpdir): assert trainer.logger_connector.callback_metrics['loss'] < 0.6 +def test_dm_checkpoint_save(tmpdir): + reset_seed() + + dm = TrialMNISTDataModule(tmpdir) + + model = EvalModelTemplate() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=3, + weights_summary=None, + ) + + # fit model + result = trainer.fit(model, dm) + checkpoint_path = list(trainer.checkpoint_callback.best_k_models.keys())[0] + checkpoint = torch.load(checkpoint_path) + assert dm.__class__.__name__ in checkpoint + assert checkpoint[dm.__class__.__name__] == dm.__class__.__name__ + def test_test_loop_only(tmpdir): reset_seed()