Improve checkpoint docs (#10916)
parent dcc55631f9, commit cc42aa9401
@ -1,52 +0,0 @@
|
|||
Custom Checkpointing IO
|
||||
=======================
|
||||
|
||||
.. warning:: The Checkpoint IO API is experimental and subject to change.
|
||||
|
||||
Lightning supports modifying the checkpointing save/load functionality through the ``CheckpointIO``. This encapsulates the save/load logic
|
||||
that is managed by the ``TrainingTypePlugin``.
|
||||
|
||||
``CheckpointIO`` can be extended to include your custom save/load functionality to and from a path. The ``CheckpointIO`` object can be passed to either a ``Trainer`` object or a ``TrainingTypePlugin`` as shown below.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
from pytorch_lightning.plugins import CheckpointIO, SingleDevicePlugin
|
||||
|
||||
|
||||
class CustomCheckpointIO(CheckpointIO):
|
||||
def save_checkpoint(
|
||||
self, checkpoint: Dict[str, Any], path: Union[str, Path], storage_options: Optional[Any] = None
|
||||
) -> None:
|
||||
...
|
||||
|
||||
def load_checkpoint(self, path: Union[str, Path], storage_options: Optional[Any] = None) -> Dict[str, Any]:
|
||||
...
|
||||
|
||||
|
||||
custom_checkpoint_io = CustomCheckpointIO()
|
||||
|
||||
# Pass into the Trainer object
|
||||
model = MyModel()
|
||||
trainer = Trainer(
|
||||
plugins=[custom_checkpoint_io],
|
||||
callbacks=ModelCheckpoint(save_last=True),
|
||||
)
|
||||
trainer.fit(model)
|
||||
|
||||
# pass into TrainingTypePlugin
|
||||
model = MyModel()
|
||||
device = torch.device("cpu")
|
||||
trainer = Trainer(
|
||||
plugins=SingleDevicePlugin(device, checkpoint_io=custom_checkpoint_io),
|
||||
callbacks=ModelCheckpoint(save_last=True),
|
||||
)
|
||||
trainer.fit(model)
|
||||
|
||||
.. note::
|
||||
|
||||
Some ``TrainingTypePlugins`` do not support custom ``CheckpointIO`` as as checkpointing logic is not modifiable.
|
|
@ -0,0 +1,386 @@
|
|||
.. testsetup:: *
|
||||
|
||||
import os
|
||||
from pytorch_lightning.trainer.trainer import Trainer
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
|
||||
.. _checkpointing:
|
||||
|
||||
##############################
|
||||
Saving and Loading Checkpoints
|
||||
##############################
|
||||
|
||||
Lightning provides functions to save and load checkpoints.
|
||||
|
||||
Checkpointing your training allows you to resume a training process in case it was interrupted, fine-tune a model or use a pre-trained model for inference without having to retrain the model.
|
||||
|
||||
|
||||
|
||||
*******************
|
||||
Checkpoint Contents
|
||||
*******************
|
||||
|
||||
A Lightning checkpoint has everything needed to restore a training session, including:

- 16-bit scaling factor (if using 16-bit precision training)
- Current epoch
- Global step
- LightningModule's state_dict
- State of all optimizers
- State of all learning rate schedulers
- State of all callbacks
- The hyperparameters used for that model if passed in as hparams (``argparse.Namespace``)
- State of Loops (if using Fault-Tolerant training)
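Since a Lightning checkpoint is a plain dictionary, you can inspect its contents directly. A minimal sketch, assuming ``"example.ckpt"`` is a checkpoint produced by Lightning:

.. code-block:: python

    import torch

    # load the checkpoint dictionary onto the CPU
    checkpoint = torch.load("example.ckpt", map_location="cpu")

    # list the stored entries, e.g. "epoch", "global_step", "state_dict", ...
    print(checkpoint.keys())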
|
||||
|
||||
|
||||
*****************
|
||||
Checkpoint Saving
|
||||
*****************
|
||||
|
||||
Automatic Saving
|
||||
================
|
||||
|
||||
Lightning automatically saves a checkpoint for you in your current working directory, with the state of your last training epoch. This makes sure you can resume training in case it was interrupted.
|
||||
|
||||
To change the checkpoint path pass in:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# saves checkpoints to '/your/path/to/save/checkpoints' at every epoch end
|
||||
trainer = Trainer(default_root_dir="/your/path/to/save/checkpoints")
|
||||
|
||||
You can retrieve the checkpoint after training by calling:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
checkpoint_callback = ModelCheckpoint(dirpath="my/path/", save_top_k=2, monitor="val_loss")
|
||||
trainer = Trainer(callbacks=[checkpoint_callback])
|
||||
trainer.fit(model)
|
||||
checkpoint_callback.best_model_path
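After training, you can also load the best checkpoint directly from the callback; a short sketch, assuming ``MyLightningModule`` is your model class:

.. code-block:: python

    # restore the weights from the best checkpoint tracked by the callback
    model = MyLightningModule.load_from_checkpoint(checkpoint_callback.best_model_path)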
|
||||
|
||||
|
||||
Disabling Checkpoints
|
||||
=====================
|
||||
|
||||
You can disable checkpointing by passing:
|
||||
|
||||
.. testcode::
|
||||
|
||||
trainer = Trainer(enable_checkpointing=False)
|
||||
|
||||
|
||||
Manual Saving
|
||||
=============
|
||||
|
||||
You can manually save checkpoints and restore your model from the checkpointed state using :meth:`~pytorch_lightning.trainer.trainer.Trainer.save_checkpoint`
|
||||
and :meth:`~pytorch_lightning.core.saving.ModelIO.load_from_checkpoint`.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
model = MyLightningModule(hparams)
|
||||
trainer.fit(model)
|
||||
trainer.save_checkpoint("example.ckpt")
|
||||
new_model = MyLightningModule.load_from_checkpoint(checkpoint_path="example.ckpt")
|
||||
|
||||
|
||||
Manual Saving with Distributed Training Strategies
|
||||
==================================================
|
||||
|
||||
Lightning also handles strategies where multiple processes are running, such as DDP. For example, when using the DDP strategy, our training script runs across multiple devices at the same time.
Lightning automatically ensures that the model is saved only on the main process, while other processes do not interfere with saving checkpoints. This requires no code changes, as seen below:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
trainer = Trainer(strategy="ddp")
|
||||
model = MyLightningModule(hparams)
|
||||
trainer.fit(model)
|
||||
# Saves only on the main process
|
||||
trainer.save_checkpoint("example.ckpt")
|
||||
|
||||
Not using :meth:`~pytorch_lightning.trainer.trainer.Trainer.save_checkpoint` can lead to unexpected behavior and potential deadlocks. Using other saving functions will result in all devices attempting to save the checkpoint. As a result, we highly recommend using the Trainer's save functionality.
If using custom saving functions cannot be avoided, we recommend using the :func:`~pytorch_lightning.utilities.distributed.rank_zero_only` decorator to ensure saving occurs only on the main process. Note that this only works if all ranks hold the exact same state, and it won't work with
model-parallel distributed strategies such as DeepSpeed or sharded training.
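If a custom saving function really is needed, here is a minimal sketch of guarding it with the decorator (``save_my_custom_state`` is a hypothetical helper, not a Lightning API):

.. code-block:: python

    from pytorch_lightning.utilities.distributed import rank_zero_only


    @rank_zero_only
    def save_my_custom_state(path):
        # executes on the main process only; all other ranks return immediately
        ...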
|
||||
|
||||
|
||||
Modifying Checkpoint When Saving and Loading
|
||||
============================================
|
||||
|
||||
You can add, delete, or modify custom states in your checkpoints before they are saved or loaded. To do so, override :meth:`~pytorch_lightning.core.hooks.CheckpointHooks.on_save_checkpoint`
and :meth:`~pytorch_lightning.core.hooks.CheckpointHooks.on_load_checkpoint` in your ``LightningModule``, or the :meth:`~pytorch_lightning.callbacks.base.Callback.on_save_checkpoint` and
:meth:`~pytorch_lightning.callbacks.base.Callback.on_load_checkpoint` methods in your ``Callback``.
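A minimal sketch of adding a custom entry from your ``LightningModule`` (the ``"my_custom_state"`` key and ``self.some_state`` attribute are just placeholders):

.. code-block:: python

    class MyLightningModule(LightningModule):
        def on_save_checkpoint(self, checkpoint):
            # add anything picklable to the checkpoint dictionary
            checkpoint["my_custom_state"] = self.some_state

        def on_load_checkpoint(self, checkpoint):
            # restore it when the checkpoint is loaded
            self.some_state = checkpoint["my_custom_state"]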
|
||||
|
||||
|
||||
Checkpointing Hyperparameters
|
||||
=============================
|
||||
|
||||
The Lightning checkpoint also saves the arguments passed into the LightningModule init
|
||||
under the ``"hyper_parameters"`` key in the checkpoint.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
class MyLightningModule(LightningModule):
|
||||
def __init__(self, learning_rate, *args, **kwargs):
|
||||
super().__init__()
|
||||
self.save_hyperparameters()
|
||||
|
||||
|
||||
# all init args were saved to the checkpoint
|
||||
checkpoint = torch.load(CKPT_PATH)
|
||||
print(checkpoint["hyper_parameters"])
|
||||
# {"learning_rate": the_value}
|
||||
|
||||
|
||||
-----------
|
||||
|
||||
|
||||
******************
|
||||
Checkpoint Loading
|
||||
******************
|
||||
|
||||
To load a model along with its weights and hyperparameters, use the following method:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
model = MyLightningModule.load_from_checkpoint(PATH)
|
||||
|
||||
print(model.learning_rate)
|
||||
# prints the learning_rate you used in this checkpoint
|
||||
|
||||
model.eval()
|
||||
y_hat = model(x)
|
||||
|
||||
But if you don't want to use the hyperparameters saved in the checkpoint, pass in your own here:
|
||||
|
||||
.. testcode::
|
||||
|
||||
class LitModel(LightningModule):
|
||||
def __init__(self, in_dim, out_dim):
|
||||
super().__init__()
|
||||
self.save_hyperparameters()
|
||||
self.l1 = nn.Linear(self.hparams.in_dim, self.hparams.out_dim)
|
||||
|
||||
You can restore the model like this:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# if you train and save the model like this it will use these values when loading
|
||||
# the weights. But you can overwrite this
|
||||
LitModel(in_dim=32, out_dim=10)
|
||||
|
||||
# uses in_dim=32, out_dim=10
|
||||
model = LitModel.load_from_checkpoint(PATH)
|
||||
|
||||
# uses in_dim=128, out_dim=10
|
||||
model = LitModel.load_from_checkpoint(PATH, in_dim=128, out_dim=10)
|
||||
|
||||
|
||||
Restoring Training State
|
||||
========================
|
||||
|
||||
If you don't just want to load the weights, but instead restore the full training state,
do the following:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
model = LitModel()
|
||||
trainer = Trainer()
|
||||
|
||||
# automatically restores model, epoch, step, LR schedulers, apex, etc...
|
||||
trainer.fit(model, ckpt_path="some/path/to/my_checkpoint.ckpt")
|
||||
|
||||
|
||||
-----------
|
||||
|
||||
|
||||
*******************************************
|
||||
Conditional Checkpointing (ModelCheckpoint)
|
||||
*******************************************
|
||||
|
||||
The :class:`~pytorch_lightning.callbacks.ModelCheckpoint` callback allows you to configure when/which/what/where checkpointing should happen. It follows the normal Callback hook structure, so you can
also override its methods for your own use cases. The following are some of the common use cases, along with the arguments you need to specify to configure them:
|
||||
|
||||
|
||||
How does it work?
|
||||
=================
|
||||
|
||||
``ModelCheckpoint`` helps you answer the following questions about checkpointing: when, which, what, and where.
|
||||
|
||||
When
|
||||
----
|
||||
|
||||
- When using iterative training which doesn't have an epoch, you can checkpoint every ``N`` training steps by specifying ``every_n_train_steps=N`` (see the sketch after this list).
- You can also control the interval of epochs between checkpoints using ``every_n_epochs``, to avoid slowdowns.
- You can checkpoint at a regular time interval using the ``train_time_interval`` argument, independent of the steps or epochs.
- In case you are monitoring a training metric, we suggest using ``save_on_train_epoch_end=True`` to ensure the required metric is being accumulated correctly for creating a checkpoint.
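A minimal sketch of the step-based and the time-based option (configured on separate callbacks here):

.. code-block:: python

    from datetime import timedelta

    from pytorch_lightning.callbacks import ModelCheckpoint

    # save a checkpoint every 100 training steps
    step_checkpoint = ModelCheckpoint(every_n_train_steps=100)

    # or save a checkpoint every 30 minutes of training time
    time_checkpoint = ModelCheckpoint(train_time_interval=timedelta(minutes=30))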
|
||||
|
||||
|
||||
Which
|
||||
-----
|
||||
|
||||
- You can save the last checkpoint when training ends using the ``save_last`` argument.

- You can save top-K and last-K checkpoints by configuring the ``monitor`` and ``save_top_k`` arguments.
|
||||
|
||||
|
|
||||
|
||||
.. testcode::
|
||||
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
|
||||
|
||||
# saves top-K checkpoints based on "val_loss" metric
|
||||
checkpoint_callback = ModelCheckpoint(
|
||||
save_top_k=10,
|
||||
monitor="val_loss",
|
||||
mode="min",
|
||||
dirpath="my/path/",
|
||||
filename="sample-mnist-{epoch:02d}-{val_loss:.2f}",
|
||||
)
|
||||
|
||||
# saves last-K checkpoints based on "global_step" metric
|
||||
# make sure you log it inside your LightningModule
|
||||
checkpoint_callback = ModelCheckpoint(
|
||||
save_top_k=10,
|
||||
monitor="global_step",
|
||||
mode="max",
|
||||
dirpath="my/path/",
|
||||
filename="sample-mnist-{epoch:02d}-{global_step}",
|
||||
)
|
||||
|
||||
- You can customize the checkpointing behavior to monitor any quantity from your training or validation steps. For example, if you want to update your checkpoints based on your validation loss:
|
||||
|
||||
|
|
||||
|
||||
.. testcode::
|
||||
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
|
||||
|
||||
class LitAutoEncoder(LightningModule):
|
||||
def validation_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
y_hat = self.backbone(x)
|
||||
|
||||
# 1. calculate loss
|
||||
loss = F.cross_entropy(y_hat, y)
|
||||
|
||||
# 2. log val_loss
|
||||
self.log("val_loss", loss)
|
||||
|
||||
|
||||
# 3. Init ModelCheckpoint callback, monitoring "val_loss"
|
||||
checkpoint_callback = ModelCheckpoint(monitor="val_loss")
|
||||
|
||||
# 4. Add your callback to the callbacks list
|
||||
trainer = Trainer(callbacks=[checkpoint_callback])
|
||||
|
||||
|
||||
What
|
||||
----
|
||||
|
||||
- By default, the ``ModelCheckpoint`` callback saves model weights, optimizer states, etc., but in case you have limited disk space or just need the model weights to be saved, you can specify ``save_weights_only=True`` (see the sketch below).
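A minimal sketch:

.. code-block:: python

    from pytorch_lightning.callbacks import ModelCheckpoint

    # keep only the model weights in the saved checkpoints
    checkpoint_callback = ModelCheckpoint(save_weights_only=True)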
|
||||
|
||||
|
||||
Where
|
||||
-----
|
||||
|
||||
- It gives you the ability to specify the ``dirpath`` and ``filename`` for your checkpoints. The ``filename`` can also be dynamic, so you can inject the metrics that are being logged using :meth:`~pytorch_lightning.core.lightning.LightningModule.log`.
|
||||
|
||||
|
|
||||
|
||||
.. testcode::
|
||||
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
|
||||
|
||||
# saves a file like: my/path/sample-mnist-epoch=02-val_loss=0.32.ckpt
|
||||
checkpoint_callback = ModelCheckpoint(
|
||||
dirpath="my/path/",
|
||||
filename="sample-mnist-{epoch:02d}-{val_loss:.2f}",
|
||||
)
|
||||
|
||||
|
|
||||
|
||||
The :class:`~pytorch_lightning.callbacks.ModelCheckpoint` callback is very robust and should cover 99% of use cases. If you find a use case that is not yet configurable, feel free to open a feature request on GitHub,
and the Lightning Team will be happy to integrate it or help you integrate it.
|
||||
|
||||
|
||||
-----------
|
||||
|
||||
|
||||
***********************
|
||||
Customize Checkpointing
|
||||
***********************
|
||||
|
||||
.. warning::
|
||||
|
||||
The Checkpoint IO API is experimental and subject to change.
|
||||
|
||||
|
||||
Lightning supports modifying the checkpointing save/load functionality through the ``CheckpointIO``. This encapsulates the save/load logic
|
||||
that is managed by the ``TrainingTypePlugin``. ``CheckpointIO`` is different from :meth:`~pytorch_lightning.core.hooks.CheckpointHooks.on_save_checkpoint`
|
||||
and :meth:`~pytorch_lightning.core.hooks.CheckpointHooks.on_load_checkpoint` methods as it determines how the checkpoint is saved/loaded to storage rather than
|
||||
what's saved in the checkpoint.
|
||||
|
||||
|
||||
Built-in Checkpoint IO Plugins
|
||||
==============================
|
||||
|
||||
.. list-table:: Built-in Checkpoint IO Plugins
|
||||
:widths: 25 75
|
||||
:header-rows: 1
|
||||
|
||||
* - Plugin
|
||||
- Description
|
||||
* - :class:`~pytorch_lightning.plugins.io.TorchCheckpointIO`
|
||||
- CheckpointIO that utilizes :func:`torch.save` and :func:`torch.load` to save and load checkpoints
|
||||
respectively, common for most use cases.
|
||||
* - :class:`~pytorch_lightning.plugins.io.XLACheckpointIO`
|
||||
- CheckpointIO that utilizes :func:`xm.save` to save checkpoints for TPU training strategies.
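These plugins are used by default depending on your hardware and training type plugin. If you want to pass one explicitly, here is a minimal sketch, mirroring how the custom plugin is passed to the ``Trainer`` below:

.. code-block:: python

    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins.io import TorchCheckpointIO

    # explicitly use torch.save / torch.load based checkpointing
    trainer = Trainer(plugins=[TorchCheckpointIO()])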
|
||||
|
||||
|
||||
Custom Checkpoint IO Plugin
|
||||
===========================
|
||||
|
||||
``CheckpointIO`` can be extended to include your custom save/load functionality to and from a path. The ``CheckpointIO`` object can be passed to either a ``Trainer`` directly or a ``TrainingTypePlugin`` as shown below:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
from pytorch_lightning.plugins import CheckpointIO, SingleDevicePlugin
|
||||
|
||||
|
||||
class CustomCheckpointIO(CheckpointIO):
|
||||
def save_checkpoint(self, checkpoint, path, storage_options=None):
|
||||
...
|
||||
|
||||
def load_checkpoint(self, path, storage_options=None):
|
||||
...
|
||||
|
||||
def remove_checkpoint(self, path):
|
||||
...
|
||||
|
||||
|
||||
custom_checkpoint_io = CustomCheckpointIO()
|
||||
|
||||
# Either pass into the Trainer object
|
||||
model = MyModel()
|
||||
trainer = Trainer(
|
||||
plugins=[custom_checkpoint_io],
|
||||
callbacks=ModelCheckpoint(save_last=True),
|
||||
)
|
||||
trainer.fit(model)
|
||||
|
||||
# or pass into TrainingTypePlugin
|
||||
model = MyModel()
|
||||
device = torch.device("cpu")
|
||||
trainer = Trainer(
|
||||
plugins=SingleDevicePlugin(device, checkpoint_io=custom_checkpoint_io),
|
||||
callbacks=ModelCheckpoint(save_last=True),
|
||||
)
|
||||
trainer.fit(model)
|
||||
|
||||
.. note::
|
||||
|
||||
Some ``TrainingTypePlugins``, like ``DeepSpeedPlugin``, do not support custom ``CheckpointIO`` because their checkpointing logic is not modifiable.
|
|
@ -1193,7 +1193,7 @@ example above, we have set ``batch_first=True``.
|
|||
sub_batch = batch[0, 0:t, ...]
|
||||
|
||||
To modify how the batch is split,
|
||||
override the :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch` method:
|
||||
override the :meth:`pytorch_lightning.core.lightning.LightningModule.tbptt_split_batch` method:
|
||||
|
||||
.. testcode:: python
|
||||
|
||||
|
|
|
@ -51,4 +51,4 @@ You could learn more about the available filesystems with:
|
|||
print(known_implementations)
|
||||
|
||||
|
||||
You could also look into :doc:`CheckpointIO plugin <../advanced/checkpoint_io>` for more details on how to customize saving and loading checkpoints.
|
||||
You could also look into :ref:`CheckpointIO Plugin <common/checkpointing:Customize Checkpointing>` for more details on how to customize saving and loading checkpoints.
|
||||
|
|
|
@ -603,7 +603,7 @@ To disable automatic checkpointing, set this to `False`.
|
|||
|
||||
You can override the default behavior by initializing the :class:`~pytorch_lightning.callbacks.ModelCheckpoint`
|
||||
callback, and adding it to the :paramref:`~pytorch_lightning.trainer.trainer.Trainer.callbacks` list.
|
||||
See :doc:`Saving and Loading Weights <../common/weights_loading>` for how to customize checkpointing.
|
||||
See :doc:`Saving and Loading Checkpoints <../common/checkpointing>` for how to customize checkpointing.
|
||||
|
||||
.. testcode::
|
||||
|
||||
|
|
|
@ -1,218 +0,0 @@
|
|||
.. testsetup:: *
|
||||
|
||||
import os
|
||||
from pytorch_lightning.trainer.trainer import Trainer
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
|
||||
.. _weights_loading:
|
||||
|
||||
##########################
|
||||
Saving and loading weights
|
||||
##########################
|
||||
|
||||
Lightning automates saving and loading checkpoints. Checkpoints capture the exact value of all parameters used by a model.
|
||||
|
||||
Checkpointing your training allows you to resume a training process in case it was interrupted, fine-tune a model or use a pre-trained model for inference without having to retrain the model.
|
||||
|
||||
|
||||
*****************
|
||||
Checkpoint saving
|
||||
*****************
|
||||
A Lightning checkpoint has everything needed to restore a training session including:
|
||||
|
||||
- 16-bit scaling factor (apex)
|
||||
- Current epoch
|
||||
- Global step
|
||||
- Model state_dict
|
||||
- State of all optimizers
|
||||
- State of all learningRate schedulers
|
||||
- State of all callbacks
|
||||
- The hyperparameters used for that model if passed in as hparams (Argparse.Namespace)
|
||||
|
||||
Automatic saving
|
||||
================
|
||||
|
||||
Lightning automatically saves a checkpoint for you in your current working directory, with the state of your last training epoch. This makes sure you can resume training in case it was interrupted.
|
||||
|
||||
To change the checkpoint path pass in:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# saves checkpoints to '/your/path/to/save/checkpoints' at every epoch end
|
||||
trainer = Trainer(default_root_dir="/your/path/to/save/checkpoints")
|
||||
|
||||
You can customize the checkpointing behavior to monitor any quantity of your training or validation steps. For example, if you want to update your checkpoints based on your validation loss:
|
||||
|
||||
1. Calculate any metric or other quantity you wish to monitor, such as validation loss.
|
||||
2. Log the quantity using :func:`~~pytorch_lightning.core.lightning.LightningModule.log` method, with a key such as `val_loss`.
|
||||
3. Initializing the :class:`~pytorch_lightning.callbacks.ModelCheckpoint` callback, and set `monitor` to be the key of your quantity.
|
||||
4. Pass the callback to the `callbacks` :class:`~pytorch_lightning.trainer.Trainer` flag.
|
||||
|
||||
.. testcode::
|
||||
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
|
||||
|
||||
class LitAutoEncoder(LightningModule):
|
||||
def validation_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
y_hat = self.backbone(x)
|
||||
|
||||
# 1. calculate loss
|
||||
loss = F.cross_entropy(y_hat, y)
|
||||
|
||||
# 2. log `val_loss`
|
||||
self.log("val_loss", loss)
|
||||
|
||||
|
||||
# 3. Init ModelCheckpoint callback, monitoring 'val_loss'
|
||||
checkpoint_callback = ModelCheckpoint(monitor="val_loss")
|
||||
|
||||
# 4. Add your callback to the callbacks list
|
||||
trainer = Trainer(callbacks=[checkpoint_callback])
|
||||
|
||||
You can also control more advanced options, like `save_top_k`, to save the best k models and the `mode` of the monitored quantity (min/max), `save_weights_only` or `every_n_epochs` to set the interval of epochs between checkpoints, to avoid slowdowns.
|
||||
|
||||
.. testcode::
|
||||
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
|
||||
|
||||
class LitAutoEncoder(LightningModule):
|
||||
def validation_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
y_hat = self.backbone(x)
|
||||
loss = F.cross_entropy(y_hat, y)
|
||||
self.log("val_loss", loss)
|
||||
|
||||
|
||||
# saves a file like: my/path/sample-mnist-epoch=02-val_loss=0.32.ckpt
|
||||
checkpoint_callback = ModelCheckpoint(
|
||||
monitor="val_loss",
|
||||
dirpath="my/path/",
|
||||
filename="sample-mnist-{epoch:02d}-{val_loss:.2f}",
|
||||
save_top_k=3,
|
||||
mode="min",
|
||||
)
|
||||
|
||||
trainer = Trainer(callbacks=[checkpoint_callback])
|
||||
|
||||
You can retrieve the checkpoint after training by calling
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
checkpoint_callback = ModelCheckpoint(dirpath="my/path/")
|
||||
trainer = Trainer(callbacks=[checkpoint_callback])
|
||||
trainer.fit(model)
|
||||
checkpoint_callback.best_model_path
|
||||
|
||||
Disabling checkpoints
|
||||
---------------------
|
||||
|
||||
You can disable checkpointing by passing
|
||||
|
||||
.. testcode::
|
||||
|
||||
trainer = Trainer(checkpoint_callback=False)
|
||||
|
||||
|
||||
The Lightning checkpoint also saves the arguments passed into the LightningModule init
|
||||
under the `hyper_parameters` key in the checkpoint.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
class MyLightningModule(LightningModule):
|
||||
def __init__(self, learning_rate, *args, **kwargs):
|
||||
super().__init__()
|
||||
self.save_hyperparameters()
|
||||
|
||||
|
||||
# all init args were saved to the checkpoint
|
||||
checkpoint = torch.load(CKPT_PATH)
|
||||
print(checkpoint["hyper_parameters"])
|
||||
# {'learning_rate': the_value}
|
||||
|
||||
Manual saving
|
||||
=============
|
||||
You can manually save checkpoints and restore your model from the checkpointed state.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
model = MyLightningModule(hparams)
|
||||
trainer.fit(model)
|
||||
trainer.save_checkpoint("example.ckpt")
|
||||
new_model = MyModel.load_from_checkpoint(checkpoint_path="example.ckpt")
|
||||
|
||||
Manual saving with strategies
|
||||
=============================
|
||||
|
||||
Lightning also handles strategies where multiple processes are running, such as DDP. For example, when using the DDP strategy our training script is running across multiple devices at the same time.
|
||||
Lightning automatically ensures that the model is saved only on the main process, whilst other processes do not interfere with saving checkpoints. This requires no code changes as seen below.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
trainer = Trainer(strategy="ddp")
|
||||
model = MyLightningModule(hparams)
|
||||
trainer.fit(model)
|
||||
# Saves only on the main process
|
||||
trainer.save_checkpoint("example.ckpt")
|
||||
|
||||
Not using `trainer.save_checkpoint` can lead to unexpected behaviour and potential deadlock. Using other saving functions will result in all devices attempting to save the checkpoint. As a result, we highly recommend using the trainer's save functionality.
|
||||
If using custom saving functions cannot be avoided, we recommend using :func:`~pytorch_lightning.loggers.base.rank_zero_only` to ensure saving occurs only on the main process.
|
||||
|
||||
******************
|
||||
Checkpoint loading
|
||||
******************
|
||||
|
||||
To load a model along with its weights, biases and hyperparameters use the following method:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
model = MyLightningModule.load_from_checkpoint(PATH)
|
||||
|
||||
print(model.learning_rate)
|
||||
# prints the learning_rate you used in this checkpoint
|
||||
|
||||
model.eval()
|
||||
y_hat = model(x)
|
||||
|
||||
But if you don't want to use the values saved in the checkpoint, pass in your own here
|
||||
|
||||
.. testcode::
|
||||
|
||||
class LitModel(LightningModule):
|
||||
def __init__(self, in_dim, out_dim):
|
||||
super().__init__()
|
||||
self.save_hyperparameters()
|
||||
self.l1 = nn.Linear(self.hparams.in_dim, self.hparams.out_dim)
|
||||
|
||||
you can restore the model like this
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# if you train and save the model like this it will use these values when loading
|
||||
# the weights. But you can overwrite this
|
||||
LitModel(in_dim=32, out_dim=10)
|
||||
|
||||
# uses in_dim=32, out_dim=10
|
||||
model = LitModel.load_from_checkpoint(PATH)
|
||||
|
||||
# uses in_dim=128, out_dim=10
|
||||
model = LitModel.load_from_checkpoint(PATH, in_dim=128, out_dim=10)
|
||||
|
||||
.. automethod:: pytorch_lightning.core.lightning.LightningModule.load_from_checkpoint
|
||||
:noindex:
|
||||
|
||||
Restoring Training State
|
||||
========================
|
||||
|
||||
If you don't just want to load weights, but instead restore the full training,
|
||||
do the following:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
model = LitModel()
|
||||
trainer = Trainer()
|
||||
|
||||
# automatically restores model, epoch, step, LR schedulers, apex, etc...
|
||||
trainer.fit(model, ckpt_path="some/path/to/my_checkpoint.ckpt")
|
|
@ -137,7 +137,7 @@ The :meth:`~pytorch_lightning.core.lightning.LightningModule.log` method has a f
|
|||
* ``sync_dist_group``: The DDP group to sync across.
|
||||
* ``add_dataloader_idx``: If True, appends the index of the current dataloader to the name
|
||||
(when using multiple dataloaders). If False, user needs to give unique names for each dataloader to not mix the values.
|
||||
* ``batch_size``: Current batch_size used for accumulating logs logged with ``on_step=True``.
|
||||
* ``batch_size``: Current batch_size used for accumulating logs logged with ``on_epoch=True``.
|
||||
This will be directly inferred from the loaded batch, but for some data structures you might need to explicitly provide it.
|
||||
* ``rank_zero_only``: Whether the value will be logged only on rank 0. This will prevent synchronization which
|
||||
would produce a deadlock as not all processes would perform this log call.
|
||||
|
|
|
@ -73,9 +73,8 @@ PyTorch Lightning
|
|||
advanced/multi_gpu
|
||||
advanced/advanced_gpu
|
||||
advanced/mixed_precision
|
||||
common/weights_loading
|
||||
common/checkpointing
|
||||
advanced/fault_tolerant_training
|
||||
advanced/checkpoint_io
|
||||
common/optimizers
|
||||
advanced/profiler
|
||||
advanced/plugins_registry
|
||||
|
|
|
@ -207,7 +207,7 @@ The :class:`~pytorch_lightning.trainer.Trainer` automates:
|
|||
* Epoch and batch iteration
|
||||
* Calling of optimizer.step(), backward, zero_grad()
|
||||
* Calling of .eval(), enabling/disabling grads
|
||||
* :doc:`weights loading <../common/weights_loading>`
|
||||
* :doc:`checkpoint saving and loading <../common/checkpointing>`
|
||||
* Tensorboard (see :doc:`loggers <../common/loggers>` options)
|
||||
* :doc:`Multi-GPU <../advanced/multi_gpu>` support
|
||||
* :doc:`TPU <../advanced/tpu>`
|
||||
|
@ -759,7 +759,7 @@ Once you define and train your first Lightning model, you might want to try othe
|
|||
- :ref:`Automatic truncated-back-propagation-through-time <common/lightning_module:truncated_bptt_steps>`
|
||||
- :ref:`Automatically scale your batch size <advanced/training_tricks:Auto scaling of batch size>`
|
||||
- :doc:`Automatically find a good learning rate <../advanced/lr_finder>`
|
||||
- :ref:`Load checkpoints directly from S3 <common/weights_loading:Checkpoint Loading>`
|
||||
- :ref:`Load checkpoints directly from S3 <common/checkpointing:Checkpoint Loading>`
|
||||
- :doc:`Scale to massive compute clusters <../clouds/cluster>`
|
||||
- :doc:`Use multiple dataloaders per train/val/test/predict loop <../guides/data>`
|
||||
- :ref:`Use multiple optimizers to do reinforcement learning or even GANs <common/optimizers:Use multiple optimizers (like GANs)>`
|
||||
|
|
|
@ -48,7 +48,7 @@ class ModelCheckpoint(Callback):
|
|||
Save the model periodically by monitoring a quantity. Every metric logged with
|
||||
:meth:`~pytorch_lightning.core.lightning.log` or :meth:`~pytorch_lightning.core.lightning.log_dict` in
|
||||
LightningModule is a candidate for the monitor key. For more information, see
|
||||
:ref:`weights_loading`.
|
||||
:ref:`checkpointing`.
|
||||
|
||||
After training finishes, use :attr:`best_model_path` to retrieve the path to the
|
||||
best checkpoint file and :attr:`best_model_score` to retrieve its score.
|
||||
|
@ -102,8 +102,7 @@ class ModelCheckpoint(Callback):
|
|||
``checkpoint_epoch=01-acc=80.ckp``. Is useful to set it to ``False`` when metric names contain ``/``
|
||||
as this will result in extra folders.
|
||||
save_weights_only: if ``True``, then only the model's weights will be
|
||||
saved (``model.save_weights(filepath)``), else the full model
|
||||
is saved (``model.save(filepath)``).
|
||||
saved. Otherwise, the optimizer states, lr-scheduler states, etc. are added to the checkpoint too.
|
||||
every_n_train_steps: Number of training steps between checkpoints.
|
||||
If ``every_n_train_steps == None or every_n_train_steps == 0``, we skip saving during training.
|
||||
To disable, set ``every_n_train_steps = 0``. This value must be ``None`` or non-negative.
|
||||
|
|
|
@ -62,9 +62,9 @@ class ModelIO:
|
|||
):
|
||||
r"""
|
||||
Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint
|
||||
it stores the arguments passed to `__init__` in the checkpoint under `hyper_parameters`
|
||||
it stores the arguments passed to ``__init__`` in the checkpoint under ``"hyper_parameters"``.
|
||||
|
||||
Any arguments specified through \*args and \*\*kwargs will override args stored in `hyper_parameters`.
|
||||
Any arguments specified through \*\*kwargs will override args stored in ``"hyper_parameters"``.
|
||||
|
||||
Args:
|
||||
checkpoint_path: Path to checkpoint. This can also be a URL, or file-like object
|
||||
|
@ -86,11 +86,11 @@ class ModelIO:
|
|||
These will be converted into a :class:`~dict` and passed into your
|
||||
:class:`LightningModule` for use.
|
||||
|
||||
If your model's `hparams` argument is :class:`~argparse.Namespace`
|
||||
If your model's ``hparams`` argument is :class:`~argparse.Namespace`
|
||||
and .yaml file has hierarchical structure, you need to refactor your model to treat
|
||||
`hparams` as :class:`~dict`.
|
||||
``hparams`` as :class:`~dict`.
|
||||
strict: Whether to strictly enforce that the keys in :attr:`checkpoint_path` match the keys
|
||||
returned by this module's state dict. Default: `True`.
|
||||
returned by this module's state dict.
|
||||
kwargs: Any extra keyword args needed to init the model. Can also be used to override saved
|
||||
hyperparameter values.
|
||||
|
||||
|
|
|
@ -1971,6 +1971,14 @@ class Trainer(
|
|||
return resume_from_checkpoint
|
||||
|
||||
def save_checkpoint(self, filepath: _PATH, weights_only: bool = False) -> None:
|
||||
r"""
|
||||
Runs routine to create a checkpoint.
|
||||
|
||||
Args:
|
||||
filepath: Path where checkpoint is saved.
|
||||
weights_only: If ``True``, will only save the model weights.
|
||||
|
||||
"""
|
||||
self.checkpoint_connector.save_checkpoint(filepath, weights_only)
|
||||
|
||||
"""
|
||||
|
|