Support checkpoint hooks on data module (#3563)
* Split out changes from #3563 to make that PR easier to review. This formats the file according to the Black formatter
* Store a reference to the trainer on the datamodule. Fixes #3682
* Update data_connector.py
* Update data_connector.py
* Update test_datamodules.py
* Split out changes from #3563 to make that PR easier to review. This formats the file according to the Black formatter
* Support checkpoint hooks for the datamodule: refactor on_{save/load}_checkpoint into a separate hook class that both the LightningModule and LightningDataModule inherit; add spots in the checkpoint connector to call the new datamodule hooks if available
* hooks formatting
* Update hooks.py
* Update checkpoint_connector.py
* Update lightning.py
* Update based on upstream/master; check out upstream/master
* Update checkpoint_connector.py
* Add tests
* Undo format revert
* Updated CHANGELOG.md
* Add checkpoint hooks
* Add Dict type
* Import CheckpointHooks
This commit is contained in:
parent c14928a72a
commit 3dcf7130c5
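With this change a datamodule can participate in checkpointing the same way a LightningModule does: it inherits the new CheckpointHooks mixin and its hooks are invoked by the checkpoint connector whenever a datamodule is attached to the trainer. A minimal sketch of the user-facing API, based on the hooks added in this commit (the vocab attribute and the "vocab" checkpoint key are hypothetical examples of state worth persisting):

    from typing import Any, Dict

    from pytorch_lightning.core.datamodule import LightningDataModule


    class MyDataModule(LightningDataModule):
        def __init__(self):
            super().__init__()
            # hypothetical piece of state that should survive a restart
            self.vocab: Dict[str, int] = {}

        def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
            # called right after the model's own on_save_checkpoint
            checkpoint["vocab"] = self.vocab

        def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
            # called during restore, before the model's on_load_checkpoint
            self.vocab = checkpoint.get("vocab", {})

The hooks only fire when the datamodule is attached to the trainer, e.g. via trainer.fit(model, dm), because the connector guards each call with self.trainer.datamodule is not None.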
@@ -25,6 +25,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Disabled optimizers setup during testing ([#3059](https://github.com/PyTorchLightning/pytorch-lightning/pull/3059))

+- Added support for datamodules to save and load checkpoints when training ([#3563](https://github.com/PyTorchLightning/pytorch-lightning/pull/3563))
+
 ### Changed

 - Changed `LearningRateLogger` to `LearningRateMonitor` ([#3251](https://github.com/PyTorchLightning/pytorch-lightning/pull/3251))
@@ -19,7 +19,7 @@ from argparse import ArgumentParser, Namespace
 from typing import Any, List, Optional, Tuple, Union

 import torch
-from pytorch_lightning.core.hooks import DataHooks
+from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks
 from pytorch_lightning.utilities import parsing, rank_zero_only
 from torch.utils.data import DataLoader

@@ -92,7 +92,7 @@ def track_data_hook_calls(fn):
     return wrapped_fn


-class LightningDataModule(DataHooks, metaclass=_DataModuleWrapper):
+class LightningDataModule(DataHooks, CheckpointHooks, metaclass=_DataModuleWrapper):
     """
     A DataModule standardizes the training, val, test splits, data preparation and transforms.
     The main advantage is consistent data splits, data preparation and transforms across models.
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any, List, Union
+from typing import Any, Dict, List, Union

 import torch
 from pytorch_lightning.utilities import AMPType, move_data_to_device, rank_zero_warn

@@ -596,3 +596,48 @@ class DataHooks:
         - :func:`~pytorch_lightning.utilities.apply_func.apply_to_collection`
         """
         return move_data_to_device(batch, device)
+
+
+class CheckpointHooks:
+    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
+        r"""
+        Called by Lightning to restore your model.
+        If you saved something with :meth:`on_save_checkpoint` this is your chance to restore this.
+
+        Args:
+            checkpoint: Loaded checkpoint
+
+        Example:
+            .. code-block:: python
+
+                def on_load_checkpoint(self, checkpoint):
+                    # 99% of the time you don't need to implement this method
+                    self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save']
+
+        Note:
+            Lightning auto-restores global step, epoch, and train state including amp scaling.
+            There is no need for you to restore anything regarding training.
+        """
+
+    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
+        r"""
+        Called by Lightning when saving a checkpoint to give you a chance to store anything
+        else you might want to save.
+
+        Args:
+            checkpoint: Checkpoint to be saved
+
+        Example:
+            .. code-block:: python
+
+                def on_save_checkpoint(self, checkpoint):
+                    # 99% of use cases you don't need to implement this method
+                    checkpoint['something_cool_i_want_to_save'] = my_cool_pickable_object
+
+        Note:
+            Lightning saves all aspects of training (epoch, global step, etc...)
+            including amp scaling.
+            There is no need for you to store anything about training.
+
+        """
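The new mixin is deliberately thin: both hooks receive the checkpoint as a plain dictionary and may read or write arbitrary picklable keys; neither returns anything. A small sketch of that contract outside of a Trainer run (the Tokenizer class and the "tokenizer_merges" key are hypothetical, chosen only to illustrate round-tripping state through the dict):

    from typing import Any, Dict

    from pytorch_lightning.core.hooks import CheckpointHooks


    class Tokenizer(CheckpointHooks):
        """Hypothetical component whose state should ride along in the checkpoint."""

        def __init__(self):
            self.merges: Dict[str, int] = {}

        def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
            checkpoint["tokenizer_merges"] = self.merges

        def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
            self.merges = checkpoint.get("tokenizer_merges", {})


    # Simulate the round trip the checkpoint connector performs with the real
    # checkpoint dict: the hooks only ever see a plain dictionary.
    saver, loader = Tokenizer(), Tokenizer()
    saver.merges = {"th": 0, "he": 1}

    checkpoint: Dict[str, Any] = {"state_dict": {}}  # stand-in for a real checkpoint
    saver.on_save_checkpoint(checkpoint)
    loader.on_load_checkpoint(checkpoint)
    assert loader.merges == saver.merges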
@@ -25,7 +25,7 @@ import torch
 import torch.distributed as torch_distrib
 from pytorch_lightning import _logger as log
 from pytorch_lightning.core.grads import GradInformation
-from pytorch_lightning.core.hooks import DataHooks, ModelHooks
+from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks
 from pytorch_lightning.core.memory import ModelSummary
 from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES, ModelIO
 from pytorch_lightning.core.step_result import EvalResult, TrainResult

@@ -52,12 +52,25 @@ else:
     XLA_AVAILABLE = True


-class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, ModelHooks, DataHooks, Module):
+class LightningModule(
+    ABC,
+    DeviceDtypeModuleMixin,
+    GradInformation,
+    ModelIO,
+    ModelHooks,
+    DataHooks,
+    CheckpointHooks,
+    Module,
+):
     # Below is for property support of JIT in PyTorch 1.7
     # since none of them is important when using JIT, we are going to ignore them.
     # https://github.com/pytorch/pytorch/commit/e7d782e724c76bb0572023d52ee7438a40a7a262#diff-ff4f8670281cd1eb4e09329cc1dcb43b
-    __ignored_properties__ = ['datamodule', 'example_input_array',
-                              'hparams', 'on_gpu'] + DeviceDtypeModuleMixin.__ignored_properties__
+    __ignored_properties__ = [
+        "datamodule",
+        "example_input_array",
+        "hparams",
+        "on_gpu",
+    ] + DeviceDtypeModuleMixin.__ignored_properties__

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -1400,49 +1413,6 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod

         self.train()

-    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-        r"""
-        Called by Lightning to restore your model.
-        If you saved something with :meth:`on_save_checkpoint` this is your chance to restore this.
-
-        Args:
-            checkpoint: Loaded checkpoint
-
-
-        Example:
-            .. code-block:: python
-
-                def on_load_checkpoint(self, checkpoint):
-                    # 99% of the time you don't need to implement this method
-                    self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save']
-
-        Note:
-            Lightning auto-restores global step, epoch, and train state including amp scaling.
-            There is no need for you to restore anything regarding training.
-        """
-
-    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-        r"""
-        Called by Lightning when saving a checkpoint to give you a chance to store anything
-        else you might want to save.
-
-        Args:
-            checkpoint: Checkpoint to be saved
-
-        Example:
-            .. code-block:: python
-
-                def on_save_checkpoint(self, checkpoint):
-                    # 99% of use cases you don't need to implement this method
-                    checkpoint['something_cool_i_want_to_save'] = my_cool_pickable_object
-
-        Note:
-            Lightning saves all aspects of training (epoch, global step, etc...)
-            including amp scaling.
-            There is no need for you to store anything about training.
-
-        """
-
     def get_progress_bar_dict(self) -> Dict[str, Union[int, str]]:
         r"""
         Implement this to override the default items displayed in the progress bar.
@@ -104,6 +104,9 @@ class CheckpointConnector:
         # load the state_dict on the model automatically
         model.load_state_dict(checkpoint['state_dict'])

+        # give the datamodule a chance to load something
+        if self.trainer.datamodule is not None:
+            self.trainer.datamodule.on_load_checkpoint(checkpoint)
         # give model a chance to load something
         model.on_load_checkpoint(checkpoint)

@@ -294,6 +297,8 @@ class CheckpointConnector:

         # give the model a chance to add a few things
         model.on_save_checkpoint(checkpoint)
+        if self.trainer.datamodule is not None:
+            self.trainer.datamodule.on_save_checkpoint(checkpoint)

         return checkpoint
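Taken together, the connector changes pin the datamodule hooks to a fixed position relative to the model hooks: on save the model hook runs first, on load the datamodule hook runs before the model's. A rough standalone sketch of the resulting sequence, not the connector's actual code (the function names here are illustrative, and model/datamodule stand for the objects attached to the running Trainer):

    from typing import Any, Dict


    def save_checkpoint_dict(model, datamodule) -> Dict[str, Any]:
        # mirrors the save path: model hook first, then the optional datamodule hook
        checkpoint: Dict[str, Any] = {"state_dict": model.state_dict()}
        model.on_save_checkpoint(checkpoint)
        if datamodule is not None:
            datamodule.on_save_checkpoint(checkpoint)
        return checkpoint


    def restore_from_checkpoint_dict(model, datamodule, checkpoint: Dict[str, Any]) -> None:
        # mirrors the load path: weights, then the optional datamodule hook, then the model hook
        model.load_state_dict(checkpoint["state_dict"])
        if datamodule is not None:
            datamodule.on_load_checkpoint(checkpoint)
        model.on_load_checkpoint(checkpoint)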
@@ -1,34 +1,39 @@
 import os
-from torch.utils.data import random_split, DataLoader
+from typing import Any, Dict, Optional

 from pytorch_lightning.core.datamodule import LightningDataModule
-from tests.base.datasets import TrialMNIST, MNIST
+from tests.base.datasets import MNIST, TrialMNIST
+from torch.utils.data import DataLoader, random_split
 from torch.utils.data.distributed import DistributedSampler


 class TrialMNISTDataModule(LightningDataModule):

-    def __init__(self, data_dir: str = './'):
+    def __init__(self, data_dir: str = "./"):
         super().__init__()
         self.data_dir = data_dir
         self.non_picklable = None
+        self.checkpoint_state: Optional[str] = None

     def prepare_data(self):
         TrialMNIST(self.data_dir, train=True, download=True)
         TrialMNIST(self.data_dir, train=False, download=True)

-    def setup(self, stage: str = None):
+    def setup(self, stage: Optional[str] = None):

-        if stage == 'fit' or stage is None:
-            mnist_full = TrialMNIST(root=self.data_dir, train=True, num_samples=64, download=True)
+        if stage == "fit" or stage is None:
+            mnist_full = TrialMNIST(
+                root=self.data_dir, train=True, num_samples=64, download=True
+            )
             self.mnist_train, self.mnist_val = random_split(mnist_full, [128, 64])
             self.dims = self.mnist_train[0][0].shape

-        if stage == 'test' or stage is None:
-            self.mnist_test = TrialMNIST(root=self.data_dir, train=False, num_samples=64, download=True)
-            self.dims = getattr(self, 'dims', self.mnist_test[0][0].shape)
+        if stage == "test" or stage is None:
+            self.mnist_test = TrialMNIST(
+                root=self.data_dir, train=False, num_samples=64, download=True
+            )
+            self.dims = getattr(self, "dims", self.mnist_test[0][0].shape)

-        self.non_picklable = lambda x: x**2
+        self.non_picklable = lambda x: x ** 2

     def train_dataloader(self):
         return DataLoader(self.mnist_train, batch_size=32)

@@ -39,10 +44,16 @@ class TrialMNISTDataModule(LightningDataModule):
     def test_dataloader(self):
         return DataLoader(self.mnist_test, batch_size=32)

+    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
+        checkpoint[self.__class__.__name__] = self.__class__.__name__
+
+    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
+        self.checkpoint_state = checkpoint.get(self.__class__.__name__)
+

 class MNISTDataModule(LightningDataModule):
     def __init__(
-        self, data_dir: str = './', batch_size: int = 32, dist_sampler: bool = False
+        self, data_dir: str = "./", batch_size: int = 32, dist_sampler: bool = False
     ) -> None:
         super().__init__()

@@ -60,16 +71,20 @@ class MNISTDataModule(LightningDataModule):
         MNIST(self.data_dir, train=True, download=True, normalize=(0.1307, 0.3081))
         MNIST(self.data_dir, train=False, download=True, normalize=(0.1307, 0.3081))

-    def setup(self, stage: str = None):
+    def setup(self, stage: Optional[str] = None):

         # Assign train/val datasets for use in dataloaders
         # TODO: need to split using random_split once updated to torch >= 1.6
-        if stage == 'fit' or stage is None:
-            self.mnist_train = MNIST(self.data_dir, train=True, normalize=(0.1307, 0.3081))
+        if stage == "fit" or stage is None:
+            self.mnist_train = MNIST(
+                self.data_dir, train=True, normalize=(0.1307, 0.3081)
+            )

         # Assign test dataset for use in dataloader(s)
-        if stage == 'test' or stage is None:
-            self.mnist_test = MNIST(self.data_dir, train=False, normalize=(0.1307, 0.3081))
+        if stage == "test" or stage is None:
+            self.mnist_test = MNIST(
+                self.data_dir, train=False, normalize=(0.1307, 0.3081)
+            )

     def train_dataloader(self):
         dist_sampler = None

@@ -77,7 +92,10 @@ class MNISTDataModule(LightningDataModule):
             dist_sampler = DistributedSampler(self.mnist_train, shuffle=False)

         return DataLoader(
-            self.mnist_train, batch_size=self.batch_size, sampler=dist_sampler, shuffle=False
+            self.mnist_train,
+            batch_size=self.batch_size,
+            sampler=dist_sampler,
+            shuffle=False,
         )

     def test_dataloader(self):
@@ -216,6 +216,25 @@ def test_train_val_loop_only(tmpdir):
     assert trainer.logger_connector.callback_metrics['loss'] < 0.6


+def test_dm_checkpoint_save(tmpdir):
+    reset_seed()
+
+    dm = TrialMNISTDataModule(tmpdir)
+
+    model = EvalModelTemplate()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=3,
+        weights_summary=None,
+    )
+
+    # fit model
+    result = trainer.fit(model, dm)
+    checkpoint_path = list(trainer.checkpoint_callback.best_k_models.keys())[0]
+    checkpoint = torch.load(checkpoint_path)
+    assert dm.__class__.__name__ in checkpoint
+    assert checkpoint[dm.__class__.__name__] == dm.__class__.__name__
+
+
 def test_test_loop_only(tmpdir):
     reset_seed()
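The test above covers the save side; the load side is what fills in TrialMNISTDataModule.checkpoint_state via on_load_checkpoint. A hedged sketch of exercising it from user code, assuming the resume_from_checkpoint Trainer argument and the test-suite import paths shown here (both are assumptions, not taken from this diff; the directory paths are placeholders):

    from pytorch_lightning import Trainer
    from tests.base import EvalModelTemplate
    from tests.base.datamodules import TrialMNISTDataModule

    # First run: fit briefly so the checkpoint callback writes a checkpoint.
    dm = TrialMNISTDataModule("./data")
    trainer = Trainer(default_root_dir="./runs", max_epochs=1, weights_summary=None)
    trainer.fit(EvalModelTemplate(), dm)
    checkpoint_path = list(trainer.checkpoint_callback.best_k_models.keys())[0]

    # Second run: resuming triggers the datamodule's on_load_checkpoint,
    # which stores the key that on_save_checkpoint wrote earlier.
    dm2 = TrialMNISTDataModule("./data")
    trainer2 = Trainer(resume_from_checkpoint=checkpoint_path, max_epochs=2)
    trainer2.fit(EvalModelTemplate(), dm2)
    assert dm2.checkpoint_state == dm2.__class__.__name__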