.. testsetup:: *

    import os
    from pytorch_lightning.trainer.trainer import Trainer
    from pytorch_lightning.core.lightning import LightningModule

.. _weights_loading:

##########################
Saving and loading weights
##########################
Lightning automates saving and loading checkpoints. Checkpoints capture the exact value of all parameters used by a model.
Checkpointing your training allows you to resume a training process in case it was interrupted, fine-tune a model, or use a pre-trained model for inference without having to retrain the model.

*****************
Checkpoint saving
*****************

A Lightning checkpoint has everything needed to restore a training session, including:

- 16-bit scaling factor (apex)
- Current epoch
- Global step
- Model state_dict
- State of all optimizers
- State of all learning rate schedulers
- State of all callbacks
- The hyperparameters used for that model if passed in as hparams (argparse.Namespace)
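
A checkpoint is an ordinary file that `torch.load` can open, so you can inspect its contents directly. A minimal sketch (the path is a placeholder, and the exact key names can vary between Lightning versions):

.. code-block:: python

    import torch

    # placeholder path to a checkpoint written by Lightning
    checkpoint = torch.load('path/to/checkpoint.ckpt', map_location='cpu')

    # the top-level keys cover the items listed above (epoch, global step, state_dict, ...)
    print(checkpoint.keys())
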
Automatic saving
================

Lightning automatically saves a checkpoint for you in your current working directory, with the state of your last training epoch. This makes sure you can resume training in case it was interrupted.

To change the checkpoint path, pass in:

.. code-block:: python

    # saves checkpoints to '/your/path/to/save/checkpoints' at every epoch end
    trainer = Trainer(default_root_dir='/your/path/to/save/checkpoints')

You can customize the checkpointing behavior to monitor any quantity of your training or validation steps. For example, if you want to update your checkpoints based on your validation loss:

1. Calculate any metric or other quantity you wish to monitor, such as validation loss.
2. Log the quantity using the :func:`~pytorch_lightning.core.lightning.LightningModule.log` method, with a key such as `val_loss`.
3. Initialize the :class:`~pytorch_lightning.callbacks.ModelCheckpoint` callback, and set `monitor` to the key of your quantity.
4. Pass the callback to the `checkpoint_callback` :class:`~pytorch_lightning.trainer.Trainer` flag.

.. code-block:: python

    from pytorch_lightning.callbacks import ModelCheckpoint

    class LitAutoEncoder(pl.LightningModule):
        def validation_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self.backbone(x)

            # 1. calculate loss
            loss = F.cross_entropy(y_hat, y)

            # 2. log `val_loss`
            self.log('val_loss', loss)

    # 3. Init ModelCheckpoint callback, monitoring 'val_loss'
    checkpoint_callback = ModelCheckpoint(monitor='val_loss')

    # 4. Pass your callback to checkpoint_callback trainer flag
    trainer = Trainer(checkpoint_callback=checkpoint_callback)

You can also control more advanced options, like `save_top_k` to save the best k models and the `mode` of the monitored quantity (min/max/auto, where `auto` infers the mode from the name of the monitored quantity), `save_weights_only` to save only the model weights, or `period` to set the interval of epochs between checkpoints and avoid slowdowns.

.. code-block:: python

    from pytorch_lightning.callbacks import ModelCheckpoint

    class LitAutoEncoder(pl.LightningModule):
        def validation_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self.backbone(x)
            loss = F.cross_entropy(y_hat, y)
            self.log('val_loss', loss)

    # saves a file like: my/path/sample-mnist-epoch=02-val_loss=0.32.ckpt
    checkpoint_callback = ModelCheckpoint(
        monitor='val_loss',
        filepath='my/path/sample-mnist-{epoch:02d}-{val_loss:.2f}',
        save_top_k=3,
        mode='min',
    )

    trainer = Trainer(checkpoint_callback=checkpoint_callback)

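To make the `mode` option concrete, here is a sketch that keeps the checkpoint with the highest value of an accuracy-style metric. The `val_acc` key is an assumption; use whatever name your own `validation_step` logs.

.. code-block:: python

    from pytorch_lightning.callbacks import ModelCheckpoint

    # assumes your validation_step logs a metric under the key 'val_acc'
    checkpoint_callback = ModelCheckpoint(
        monitor='val_acc',   # higher is better, so maximize the monitored quantity
        save_top_k=1,        # keep only the single best checkpoint
        mode='max',
    )

    trainer = Trainer(checkpoint_callback=checkpoint_callback)
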
You can retrieve the checkpoint after training by calling:

.. code-block:: python

    checkpoint_callback = ModelCheckpoint(filepath='my/path/')
    trainer = Trainer(checkpoint_callback=checkpoint_callback)
    trainer.fit(model)
    checkpoint_callback.best_model_path

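Since `best_model_path` is a plain filesystem path, you can feed it back into `load_from_checkpoint`. A minimal sketch, reusing the `LitAutoEncoder` class from the examples above:

.. code-block:: python

    # load the best checkpoint produced during training
    best_model = LitAutoEncoder.load_from_checkpoint(checkpoint_callback.best_model_path)
    best_model.eval()
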
Disabling checkpoints
---------------------

You can disable checkpointing by passing:

.. testcode::

    trainer = Trainer(checkpoint_callback=False)

The Lightning checkpoint also saves the arguments passed into the LightningModule init
under the `module_arguments` key in the checkpoint.

.. code-block:: python

    class MyLightningModule(LightningModule):

        def __init__(self, learning_rate, *args, **kwargs):
            super().__init__()

    # all init args were saved to the checkpoint
    checkpoint = torch.load(CKPT_PATH)
    print(checkpoint['module_arguments'])
    # {'learning_rate': the_value}

Manual saving
=============

You can manually save checkpoints and restore your model from the checkpointed state.

.. code-block:: python

    model = MyLightningModule(hparams)
    trainer.fit(model)
    trainer.save_checkpoint("example.ckpt")
    new_model = MyLightningModule.load_from_checkpoint(checkpoint_path="example.ckpt")

******************
Checkpoint loading
******************

To load a model along with its weights, biases and `module_arguments`, use the following method:

.. code-block:: python

    model = MyLightningModule.load_from_checkpoint(PATH)

    print(model.learning_rate)
    # prints the learning_rate you used in this checkpoint

    model.eval()
    y_hat = model(x)

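For pure inference it is usually worth disabling gradient tracking as well. A minimal sketch (the input `x` stands in for your own data):

.. code-block:: python

    import torch

    model = MyLightningModule.load_from_checkpoint(PATH)
    model.eval()

    # no gradients are needed for inference
    with torch.no_grad():
        y_hat = model(x)
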
But if you don't want to use the values saved in the checkpoint, pass in your own here:

.. testcode::

    class LitModel(LightningModule):

        def __init__(self, in_dim, out_dim):
            super().__init__()
            self.save_hyperparameters()
            self.l1 = nn.Linear(self.hparams.in_dim, self.hparams.out_dim)

You can restore the model like this:

.. code-block:: python

    # if you train and save the model like this, it will use these values when loading
    # the weights, but you can overwrite them
    LitModel(in_dim=32, out_dim=10)

    # uses in_dim=32, out_dim=10
    model = LitModel.load_from_checkpoint(PATH)

    # uses in_dim=128, out_dim=10
    model = LitModel.load_from_checkpoint(PATH, in_dim=128, out_dim=10)

Restoring Training State
========================

If you don't just want to load weights, but instead restore the full training,
do the following:

.. code-block:: python

    model = LitModel()
    trainer = Trainer(resume_from_checkpoint='some/path/to/my_checkpoint.ckpt')

    # automatically restores model, epoch, step, LR schedulers, apex, etc...
    trainer.fit(model)
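
If you kept a :class:`~pytorch_lightning.callbacks.ModelCheckpoint` callback around from a previous run (as in the saving examples above), its `best_model_path` attribute is just such a path. A minimal sketch under that assumption:

.. code-block:: python

    # `checkpoint_callback` is the ModelCheckpoint instance used in a previous run
    model = LitModel()
    trainer = Trainer(resume_from_checkpoint=checkpoint_callback.best_model_path)
    trainer.fit(model)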