.. testsetup:: *

    from pytorch_lightning.trainer.trainer import Trainer
    from pytorch_lightning.core.lightning import LightningModule

.. _checkpointing:

##############################
Saving and Loading Checkpoints
##############################
Lightning provides functions to save and load checkpoints.
Checkpointing your training allows you to resume a training process in case it was interrupted, fine-tune a model or use a pre-trained model for inference without having to retrain the model.
*******************
Checkpoint Contents
*******************
A Lightning checkpoint has everything needed to restore a training session, including:
- 16-bit scaling factor (if using 16-bit precision training)
- Current epoch
- Global step
- LightningModule's state_dict
- State of all optimizers
- State of all learning rate schedulers
- State of all callbacks
- The hyperparameters used for that model, if passed in as hparams (``argparse.Namespace``)
- State of Loops (if using Fault-Tolerant training)
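
As a rough illustration of the list above, a loaded checkpoint behaves like a plain dictionary. The sketch below uses a hand-built dictionary in place of a real checkpoint file so that it runs without Lightning installed. The key names are the ones commonly found in Lightning checkpoints, but they can vary between versions, so treat them as assumptions and inspect your own file to confirm. ``missing_keys`` is a hypothetical helper, not part of Lightning.

```python
# Stand-in for the dictionary you would get from loading "example.ckpt".
# Key names are assumptions; they can differ between Lightning versions.
checkpoint = {
    "epoch": 3,
    "global_step": 1200,
    "state_dict": {"layer.weight": [0.1, 0.2]},
    "optimizer_states": [{}],
    "lr_schedulers": [{}],
    "callbacks": {},
    "hyper_parameters": {"learning_rate": 1e-3},
}


def missing_keys(ckpt, required=("epoch", "global_step", "state_dict")):
    """Return the required keys that are absent from a checkpoint dictionary."""
    return [key for key in required if key not in ckpt]


print(sorted(checkpoint))        # every top-level entry in the checkpoint
print(missing_keys(checkpoint))  # an empty list means nothing is missing
```

A check like this can be useful before resuming training from a file produced by a different tool or an older run.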
*****************
Checkpoint Saving
*****************
Automatic Saving
================
Lightning automatically saves a checkpoint for you in your current working directory, with the state of your last training epoch. This makes sure you can resume training in case it was interrupted.
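To change the checkpoint path, pass ``default_root_dir`` to the ``Trainer``:

.. code-block:: python

    # saves checkpoints to '/your/path/to/save/checkpoints' at every epoch end
    trainer = Trainer(default_root_dir="/your/path/to/save/checkpoints")
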
You can manually save checkpoints and restore your model from the checkpointed state using :meth:`~pytorch_lightning.trainer.trainer.Trainer.save_checkpoint`
and :meth:`~pytorch_lightning.core.saving.ModelIO.load_from_checkpoint`.
Lightning also handles strategies where multiple processes are running, such as DDP. For example, when using the DDP strategy our training script is running across multiple devices at the same time.
Lightning automatically ensures that the model is saved only on the main process, whilst other processes do not interfere with saving checkpoints. This requires no code changes as seen below:

.. code-block:: python

    trainer = Trainer(strategy="ddp")
    model = MyLightningModule(hparams)
    trainer.fit(model)
    # Saves only on the main process
    trainer.save_checkpoint("example.ckpt")

Not using :meth:`~pytorch_lightning.trainer.trainer.Trainer.save_checkpoint` can lead to unexpected behavior and potential deadlock. Using other saving functions will result in all devices attempting to save the checkpoint. As a result, we highly recommend using the Trainer's save functionality.
If using custom saving functions cannot be avoided, we recommend using the :func:`~pytorch_lightning.utilities.distributed.rank_zero_only` decorator to ensure saving occurs only on the main process. Note that this only works if all ranks hold exactly the same state; it does not work with model-parallel distributed strategies such as DeepSpeed or sharded training.
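
To make the idea of saving only on the main process concrete, here is a minimal standard-library sketch of such a decorator. It is not Lightning's implementation: ``save_on_rank_zero`` and ``custom_save`` are hypothetical names, and the sketch assumes the launcher exports the process rank in a ``RANK`` environment variable.

```python
import functools
import os

written = []  # stands in for files written to disk


def save_on_rank_zero(fn):
    """Run ``fn`` only in the process whose RANK environment variable is 0."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # Assumption: the launcher exports RANK; default to 0 for single-process runs.
        if int(os.environ.get("RANK", "0")) == 0:
            return fn(*args, **kwargs)
        return None  # non-zero ranks skip the call

    return wrapper


@save_on_rank_zero
def custom_save(state, path):
    # Hypothetical custom saving function; records the path instead of writing a file.
    written.append(path)
    return path
```

Because only one process runs the function, the result is complete only when every rank holds the same state, which is why this pattern does not fit sharded or model-parallel training.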
Modifying Checkpoint When Saving and Loading
============================================
You can add, delete, or modify custom states in your checkpoints before they are saved or loaded. To do so, override :meth:`~pytorch_lightning.core.hooks.CheckpointHooks.on_save_checkpoint`
and :meth:`~pytorch_lightning.core.hooks.CheckpointHooks.on_load_checkpoint` in your ``LightningModule``, or the :meth:`~pytorch_lightning.callbacks.base.Callback.on_save_checkpoint` and
:meth:`~pytorch_lightning.callbacks.base.Callback.on_load_checkpoint` methods in your ``Callback``.
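
The following sketch shows the shape of the two ``LightningModule`` hooks. To keep it runnable without Lightning, a plain class stands in for the ``LightningModule`` and the hooks are called by hand; in real use the ``Trainer`` calls them for you. The class name and the ``vocab`` attribute are hypothetical.

```python
class CheckpointHooksDemo:
    """Plain class standing in for a LightningModule (assumption, for illustration)."""

    def __init__(self):
        self.vocab = {"hello": 0, "world": 1}  # hypothetical extra state

    def on_save_checkpoint(self, checkpoint):
        # Add a custom entry before the checkpoint is written.
        checkpoint["vocab"] = dict(self.vocab)

    def on_load_checkpoint(self, checkpoint):
        # Restore the custom entry; fall back to an empty vocab if it is absent.
        self.vocab = checkpoint.get("vocab", {})


# Simulate saving: the hook receives the checkpoint dictionary and mutates it.
checkpoint = {"state_dict": {}}
model = CheckpointHooksDemo()
model.on_save_checkpoint(checkpoint)

# Simulate loading into a fresh object that starts without the custom state.
restored = CheckpointHooksDemo()
restored.vocab = {}
restored.on_load_checkpoint(checkpoint)
```

Both hooks mutate or read the checkpoint dictionary in place, so anything you store must be serializable along with the rest of the checkpoint.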
Checkpointing Hyperparameters
=============================
The Lightning checkpoint also saves the arguments passed into the LightningModule init
under the ``"hyper_parameters"`` key in the checkpoint.
The :class:`~pytorch_lightning.callbacks.ModelCheckpoint` callback lets you configure when, which, what, and where to checkpoint. It follows the normal Callback hook structure, so you can
also override its methods for your use case. The following sections cover common use cases and the arguments that configure them:
How does it work?
=================
``ModelCheckpoint`` covers the following cases, grouped by the question they answer:
When
----
- When using iterative training that has no epochs, you can checkpoint every ``N`` training steps by specifying ``every_n_train_steps=N``.
- You can control the number of epochs between checkpoints using ``every_n_epochs``, to avoid slowdowns.
- You can checkpoint at a regular time interval, independent of steps or epochs, using the ``train_time_interval`` argument.
- If you are monitoring a training metric, we suggest using ``save_on_train_epoch_end=True`` to ensure the metric has been accumulated correctly before the checkpoint is created.
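
The three triggers above can be pictured with a few small predicates. This is a simplified sketch of the scheduling idea only, not Lightning's code; the function names are hypothetical and the exact counting conventions inside ``ModelCheckpoint`` may differ.

```python
from datetime import timedelta


def due_by_step(global_step, every_n_train_steps):
    """True when the number of completed steps is a positive multiple of N."""
    return (
        every_n_train_steps > 0
        and global_step > 0
        and global_step % every_n_train_steps == 0
    )


def due_by_epoch(current_epoch, every_n_epochs):
    """True after every N-th epoch; epochs are counted from zero."""
    return every_n_epochs > 0 and (current_epoch + 1) % every_n_epochs == 0


def due_by_time(elapsed_since_last, train_time_interval):
    """True once at least one full interval has passed since the last checkpoint."""
    return elapsed_since_last >= train_time_interval
```

For example, ``due_by_epoch(4, 5)`` is true because epoch index 4 is the fifth epoch, while ``due_by_step(750, 500)`` is false.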
Which
-----
- You can save the last checkpoint when training ends using the ``save_last`` argument.
- You can save top-K and last-K checkpoints by configuring the ``monitor`` and ``save_top_k`` arguments.

  |

  .. testcode::

      from pytorch_lightning.callbacks import ModelCheckpoint

      # saves top-K checkpoints based on "val_loss" metric (K=3 is an illustrative value)
      checkpoint_callback = ModelCheckpoint(monitor="val_loss", save_top_k=3, mode="min")

- You can customize the checkpointing behavior to monitor any quantity logged in your training or validation steps. For example, to update your checkpoints based on your validation loss:

  |

  .. testcode::

      from pytorch_lightning.callbacks import ModelCheckpoint

      checkpoint_callback = ModelCheckpoint(monitor="val_loss")

- By default, the ``ModelCheckpoint`` callback saves model weights, optimizer states, and so on. If you have limited disk space or only need the model weights, you can specify ``save_weights_only=True``.
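
The top-K behavior described above amounts to keeping the K best scores seen so far and dropping the worst one when a better checkpoint arrives. The sketch below illustrates that bookkeeping in plain Python; ``update_top_k`` is a hypothetical helper and not how ``ModelCheckpoint`` is implemented internally.

```python
def update_top_k(best, path, score, k, mode="min"):
    """Offer one checkpoint and return (new_best, removed_path).

    ``best`` maps checkpoint path -> monitored score.
    """
    best = dict(best)
    best[path] = score
    if len(best) <= k:
        return best, None
    # The worst entry has the largest score for mode="min", the smallest for "max".
    worst = max(best, key=best.get) if mode == "min" else min(best, key=best.get)
    del best[worst]
    return best, worst


best = {}
for epoch, val_loss in enumerate([0.9, 0.5, 0.7, 0.4]):
    best, removed = update_top_k(best, f"epoch={epoch}.ckpt", val_loss, k=2)
```

With ``mode="min"`` a lower monitored value is better, which suits losses; for metrics such as accuracy, ``mode="max"`` keeps the highest values instead.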
Where
-----
- It gives you the ability to specify the ``dirpath`` and ``filename`` for your checkpoints. The filename can also be dynamic, so you can inject the metrics that are being logged using :meth:`~pytorch_lightning.core.lightning.LightningModule.log`.

  |

  .. testcode::

      from pytorch_lightning.callbacks import ModelCheckpoint

      # saves a file like: my/path/sample-mnist-epoch=02-val_loss=0.32.ckpt
      checkpoint_callback = ModelCheckpoint(dirpath="my/path/", filename="sample-mnist-{epoch:02d}-{val_loss:.2f}")

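
To see how a dynamic filename can turn into a name like the one in the comment above, here is a small standard-library sketch. It assumes a template of the form ``sample-mnist-{epoch:02d}-{val_loss:.2f}`` and prefixes each value with its metric name; ``format_checkpoint_name`` is a hypothetical helper that only approximates the callback's behavior.

```python
import re


def format_checkpoint_name(template, metrics):
    """Expand a filename template, prefixing each value with its metric name."""
    # "{val_loss:.2f}" becomes "val_loss={val_loss:.2f}" before formatting.
    with_names = re.sub(r"\{(\w+)", r"\1={\1", template)
    return with_names.format(**metrics) + ".ckpt"


name = format_checkpoint_name(
    "sample-mnist-{epoch:02d}-{val_loss:.2f}", {"epoch": 2, "val_loss": 0.32}
)
print(name)
```

Only metrics that have actually been logged can be substituted, so a template that names an unknown metric would fail in this sketch with a ``KeyError``.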
The :class:`~pytorch_lightning.callbacks.ModelCheckpoint` callback is designed to cover most use cases. If you find a use case that is not supported yet, feel free to open an issue with a feature request on GitHub,
and the Lightning Team will be happy to integrate it or help you integrate it.
-----------
***********************
Customize Checkpointing
***********************
.. warning::

    The Checkpoint IO API is experimental and subject to change.

Lightning supports modifying the checkpointing save/load functionality through the ``CheckpointIO``. This encapsulates the save/load logic. It differs from the :meth:`~pytorch_lightning.core.hooks.CheckpointHooks.on_save_checkpoint`
and :meth:`~pytorch_lightning.core.hooks.CheckpointHooks.on_load_checkpoint` methods in that it determines how the checkpoint is saved to and loaded from storage rather than what the checkpoint contains.
``CheckpointIO`` can be extended to include your custom save/load functionality to and from a path. The ``CheckpointIO`` object can be passed to either a ``Trainer`` directly or a ``Strategy`` as shown below: