Support checkpoint hooks on data module (#3563)

* Split out changes from #3563 to make that PR easier to review. This formats the file according to the Black formatter

* Store a reference to the trainer on the datamodule

Fixes #3682

* Update data_connector.py

* Update data_connector.py

* Update test_datamodules.py

* Split out changes from #3563 to make that PR easier to review. This formats the file according to the Black formatter

* support checkpoint hooks for datamodule

Refactor on_{save/load}_checkpoint into a separate CheckpointHooks class that both the LightningModule and the LightningDataModule inherit.
Add calls in the checkpoint connector to invoke the new datamodule hooks when a datamodule is attached (a usage sketch follows this commit list).

* hooks formatting

* Update hooks.py

* Update checkpoint_connector.py

* Update lightning.py

* update based on upstream/master

checkout upstream/master

* Update checkpoint_connector.py

* add tests

* undo format revert

* Updated CHANGELOG.md

* add checkpoint hooks

* add Dict type

* import CheckpointHooks
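
For illustration, a minimal sketch of what this change enables on the user side: a LightningDataModule can now carry its own state across a save/restore cycle through the same checkpoint hooks a LightningModule has, provided the datamodule is passed to the Trainer so the checkpoint connector holds a reference to it. The MyDataModule class, the split_seed attribute, and the checkpoint key below are hypothetical examples, not part of this commit.

from typing import Any, Dict

from pytorch_lightning.core.datamodule import LightningDataModule


class MyDataModule(LightningDataModule):  # hypothetical example datamodule
    def __init__(self, split_seed: int = 42):
        super().__init__()
        # state we want to survive a save/restore cycle
        self.split_seed = split_seed

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # called by the checkpoint connector right after the model's own hook
        checkpoint["my_datamodule_split_seed"] = self.split_seed

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # called on restore when a datamodule is attached to the trainer
        self.split_seed = checkpoint.get("my_datamodule_split_seed", self.split_seed)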
Author: ananthsub
Date: 2020-09-29 10:51:44 -07:00 (committed by GitHub)
Parent: c14928a72a
Commit: 3dcf7130c5
7 changed files with 127 additions and 68 deletions

CHANGELOG.md

@@ -25,6 +25,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Disabled optimizers setup during testing ([#3059](https://github.com/PyTorchLightning/pytorch-lightning/pull/3059))
- Added support for datamodules to save and load checkpoints when training ([#3563](https://github.com/PyTorchLightning/pytorch-lightning/pull/3563))
### Changed
- Changed `LearningRateLogger` to `LearningRateMonitor` ([#3251](https://github.com/PyTorchLightning/pytorch-lightning/pull/3251))

pytorch_lightning/core/datamodule.py

@@ -19,7 +19,7 @@ from argparse import ArgumentParser, Namespace
from typing import Any, List, Optional, Tuple, Union
import torch
from pytorch_lightning.core.hooks import DataHooks
from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks
from pytorch_lightning.utilities import parsing, rank_zero_only
from torch.utils.data import DataLoader
@@ -92,7 +92,7 @@ def track_data_hook_calls(fn):
return wrapped_fn
class LightningDataModule(DataHooks, metaclass=_DataModuleWrapper):
class LightningDataModule(DataHooks, CheckpointHooks, metaclass=_DataModuleWrapper):
"""
A DataModule standardizes the training, val, test splits, data preparation and transforms.
The main advantage is consistent data splits, data preparation and transforms across models.

pytorch_lightning/core/hooks.py

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Union
from typing import Any, Dict, List, Union
import torch
from pytorch_lightning.utilities import AMPType, move_data_to_device, rank_zero_warn
@@ -596,3 +596,48 @@ class DataHooks:
- :func:`~pytorch_lightning.utilities.apply_func.apply_to_collection`
"""
return move_data_to_device(batch, device)
class CheckpointHooks:
def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
r"""
Called by Lightning to restore your model.
If you saved something with :meth:`on_save_checkpoint` this is your chance to restore this.
Args:
checkpoint: Loaded checkpoint
Example:
.. code-block:: python
def on_load_checkpoint(self, checkpoint):
# 99% of the time you don't need to implement this method
self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save']
Note:
Lightning auto-restores global step, epoch, and train state including amp scaling.
There is no need for you to restore anything regarding training.
"""
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
r"""
Called by Lightning when saving a checkpoint to give you a chance to store anything
else you might want to save.
Args:
checkpoint: Checkpoint to be saved
Example:
.. code-block:: python
def on_save_checkpoint(self, checkpoint):
# 99% of use cases you don't need to implement this method
checkpoint['something_cool_i_want_to_save'] = my_cool_pickable_object
Note:
Lightning saves all aspects of training (epoch, global step, etc...)
including amp scaling.
There is no need for you to store anything about training.
"""

pytorch_lightning/core/lightning.py

@@ -25,7 +25,7 @@ import torch
import torch.distributed as torch_distrib
from pytorch_lightning import _logger as log
from pytorch_lightning.core.grads import GradInformation
from pytorch_lightning.core.hooks import DataHooks, ModelHooks
from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES, ModelIO
from pytorch_lightning.core.step_result import EvalResult, TrainResult
@@ -52,12 +52,25 @@ else:
XLA_AVAILABLE = True
class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, ModelHooks, DataHooks, Module):
class LightningModule(
ABC,
DeviceDtypeModuleMixin,
GradInformation,
ModelIO,
ModelHooks,
DataHooks,
CheckpointHooks,
Module,
):
# Below is for property support of JIT in PyTorch 1.7
# since none of them is important when using JIT, we are going to ignore them.
# https://github.com/pytorch/pytorch/commit/e7d782e724c76bb0572023d52ee7438a40a7a262#diff-ff4f8670281cd1eb4e09329cc1dcb43b
__ignored_properties__ = ['datamodule', 'example_input_array',
'hparams', 'on_gpu'] + DeviceDtypeModuleMixin.__ignored_properties__
__ignored_properties__ = [
"datamodule",
"example_input_array",
"hparams",
"on_gpu",
] + DeviceDtypeModuleMixin.__ignored_properties__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -1400,49 +1413,6 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
self.train()
def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
r"""
Called by Lightning to restore your model.
If you saved something with :meth:`on_save_checkpoint` this is your chance to restore this.
Args:
checkpoint: Loaded checkpoint
Example:
.. code-block:: python
def on_load_checkpoint(self, checkpoint):
# 99% of the time you don't need to implement this method
self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save']
Note:
Lightning auto-restores global step, epoch, and train state including amp scaling.
There is no need for you to restore anything regarding training.
"""
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
r"""
Called by Lightning when saving a checkpoint to give you a chance to store anything
else you might want to save.
Args:
checkpoint: Checkpoint to be saved
Example:
.. code-block:: python
def on_save_checkpoint(self, checkpoint):
# 99% of use cases you don't need to implement this method
checkpoint['something_cool_i_want_to_save'] = my_cool_pickable_object
Note:
Lightning saves all aspects of training (epoch, global step, etc...)
including amp scaling.
There is no need for you to store anything about training.
"""
def get_progress_bar_dict(self) -> Dict[str, Union[int, str]]:
r"""
Implement this to override the default items displayed in the progress bar.

pytorch_lightning/trainer/connectors/checkpoint_connector.py

@@ -104,6 +104,9 @@ class CheckpointConnector:
# load the state_dict on the model automatically
model.load_state_dict(checkpoint['state_dict'])
# give the datamodule a chance to load something
if self.trainer.datamodule is not None:
self.trainer.datamodule.on_load_checkpoint(checkpoint)
# give model a chance to load something
model.on_load_checkpoint(checkpoint)
@@ -294,6 +297,8 @@ class CheckpointConnector:
# give the model a chance to add a few things
model.on_save_checkpoint(checkpoint)
if self.trainer.datamodule is not None:
self.trainer.datamodule.on_save_checkpoint(checkpoint)
return checkpoint
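
Condensed, the call order established by the two hunks above can be summarized as follows. run_save_hooks and run_load_hooks are hypothetical stand-ins for the connector's save and restore paths, written only to make the ordering explicit; the rest of the checkpoint handling is omitted.

from typing import Any, Dict, Optional

from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.core.lightning import LightningModule


def run_save_hooks(model: LightningModule, datamodule: Optional[LightningDataModule],
                   checkpoint: Dict[str, Any]) -> Dict[str, Any]:
    # same order as the save hunk above: the model mutates the checkpoint dict first,
    # then the attached datamodule (if any) gets the same chance
    model.on_save_checkpoint(checkpoint)
    if datamodule is not None:
        datamodule.on_save_checkpoint(checkpoint)
    return checkpoint


def run_load_hooks(model: LightningModule, datamodule: Optional[LightningDataModule],
                   checkpoint: Dict[str, Any]) -> None:
    # same order as the restore hunk above: weights are loaded, then the datamodule
    # hook runs, then the model hook, each receiving the full checkpoint dict
    model.load_state_dict(checkpoint["state_dict"])
    if datamodule is not None:
        datamodule.on_load_checkpoint(checkpoint)
    model.on_load_checkpoint(checkpoint)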

tests/base/datamodules.py

@@ -1,34 +1,39 @@
import os
from torch.utils.data import random_split, DataLoader
from typing import Any, Dict, Optional
from pytorch_lightning.core.datamodule import LightningDataModule
from tests.base.datasets import TrialMNIST, MNIST
from tests.base.datasets import MNIST, TrialMNIST
from torch.utils.data import DataLoader, random_split
from torch.utils.data.distributed import DistributedSampler
class TrialMNISTDataModule(LightningDataModule):
def __init__(self, data_dir: str = './'):
def __init__(self, data_dir: str = "./"):
super().__init__()
self.data_dir = data_dir
self.non_picklable = None
self.checkpoint_state: Optional[str] = None
def prepare_data(self):
TrialMNIST(self.data_dir, train=True, download=True)
TrialMNIST(self.data_dir, train=False, download=True)
def setup(self, stage: str = None):
def setup(self, stage: Optional[str] = None):
if stage == 'fit' or stage is None:
mnist_full = TrialMNIST(root=self.data_dir, train=True, num_samples=64, download=True)
if stage == "fit" or stage is None:
mnist_full = TrialMNIST(
root=self.data_dir, train=True, num_samples=64, download=True
)
self.mnist_train, self.mnist_val = random_split(mnist_full, [128, 64])
self.dims = self.mnist_train[0][0].shape
if stage == 'test' or stage is None:
self.mnist_test = TrialMNIST(root=self.data_dir, train=False, num_samples=64, download=True)
self.dims = getattr(self, 'dims', self.mnist_test[0][0].shape)
if stage == "test" or stage is None:
self.mnist_test = TrialMNIST(
root=self.data_dir, train=False, num_samples=64, download=True
)
self.dims = getattr(self, "dims", self.mnist_test[0][0].shape)
self.non_picklable = lambda x: x**2
self.non_picklable = lambda x: x ** 2
def train_dataloader(self):
return DataLoader(self.mnist_train, batch_size=32)
@@ -39,10 +44,16 @@ class TrialMNISTDataModule(LightningDataModule):
def test_dataloader(self):
return DataLoader(self.mnist_test, batch_size=32)
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
checkpoint[self.__class__.__name__] = self.__class__.__name__
def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
self.checkpoint_state = checkpoint.get(self.__class__.__name__)
class MNISTDataModule(LightningDataModule):
def __init__(
self, data_dir: str = './', batch_size: int = 32, dist_sampler: bool = False
self, data_dir: str = "./", batch_size: int = 32, dist_sampler: bool = False
) -> None:
super().__init__()
@@ -60,16 +71,20 @@ class MNISTDataModule(LightningDataModule):
MNIST(self.data_dir, train=True, download=True, normalize=(0.1307, 0.3081))
MNIST(self.data_dir, train=False, download=True, normalize=(0.1307, 0.3081))
def setup(self, stage: str = None):
def setup(self, stage: Optional[str] = None):
# Assign train/val datasets for use in dataloaders
# TODO: need to split using random_split once updated to torch >= 1.6
if stage == 'fit' or stage is None:
self.mnist_train = MNIST(self.data_dir, train=True, normalize=(0.1307, 0.3081))
if stage == "fit" or stage is None:
self.mnist_train = MNIST(
self.data_dir, train=True, normalize=(0.1307, 0.3081)
)
# Assign test dataset for use in dataloader(s)
if stage == 'test' or stage is None:
self.mnist_test = MNIST(self.data_dir, train=False, normalize=(0.1307, 0.3081))
if stage == "test" or stage is None:
self.mnist_test = MNIST(
self.data_dir, train=False, normalize=(0.1307, 0.3081)
)
def train_dataloader(self):
dist_sampler = None
@@ -77,7 +92,10 @@ class MNISTDataModule(LightningDataModule):
dist_sampler = DistributedSampler(self.mnist_train, shuffle=False)
return DataLoader(
self.mnist_train, batch_size=self.batch_size, sampler=dist_sampler, shuffle=False
self.mnist_train,
batch_size=self.batch_size,
sampler=dist_sampler,
shuffle=False,
)
def test_dataloader(self):

tests/core/test_datamodules.py

@@ -216,6 +216,25 @@ def test_train_val_loop_only(tmpdir):
assert trainer.logger_connector.callback_metrics['loss'] < 0.6
def test_dm_checkpoint_save(tmpdir):
reset_seed()
dm = TrialMNISTDataModule(tmpdir)
model = EvalModelTemplate()
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=3,
weights_summary=None,
)
# fit model
result = trainer.fit(model, dm)
checkpoint_path = list(trainer.checkpoint_callback.best_k_models.keys())[0]
checkpoint = torch.load(checkpoint_path)
assert dm.__class__.__name__ in checkpoint
assert checkpoint[dm.__class__.__name__] == dm.__class__.__name__
def test_test_loop_only(tmpdir):
reset_seed()