.. testsetup:: *

    import os
    from pytorch_lightning.trainer.trainer import Trainer
    from pytorch_lightning.core.lightning import LightningModule

.. _weights_loading:

##########################
Saving and loading weights
##########################

Lightning automates saving and loading checkpoints. Checkpoints capture the exact value of all parameters used by a model.

Checkpointing your training allows you to resume training after an interruption, fine-tune a model, or use a pre-trained model for inference without having to retrain it.

*****************
Checkpoint saving
*****************

A Lightning checkpoint has everything needed to restore a training session including:

- 16-bit scaling factor (apex)
- Current epoch
- Global step
- Model state_dict
- State of all optimizers
- State of all learning rate schedulers
- State of all callbacks
- The hyperparameters used for that model if passed in as hparams (argparse.Namespace)
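
If you want to see what one of these checkpoints actually contains, you can open it with `torch.load` and inspect its keys. A minimal sketch, assuming `CKPT_PATH` points to an existing checkpoint file (the exact key names can vary between Lightning versions):

.. code-block:: python

    import torch

    # a Lightning checkpoint is a plain dictionary that can be inspected directly
    checkpoint = torch.load(CKPT_PATH)
    print(checkpoint.keys())  # e.g. epoch, global step, state_dict, optimizer states, ...

    # the raw model weights live under the 'state_dict' key
    print(checkpoint['state_dict'].keys())
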
Automatic saving
================

Lightning automatically saves a checkpoint for you in your current working directory, with the state of your last training epoch. This ensures you can resume training if it is interrupted.

To change the checkpoint path, pass in:

.. code-block:: python

    # saves checkpoints to '/your/path/to/save/checkpoints' at every epoch end
    trainer = Trainer(default_root_dir='/your/path/to/save/checkpoints')

You can customize the checkpointing behavior to monitor any quantity of your training or validation steps. For example, if you want to update your checkpoints based on your validation loss:

1. Calculate any metric or other quantity you wish to monitor, such as validation loss.
2. Log the quantity using the :func:`~pytorch_lightning.core.lightning.LightningModule.log` method, with a key such as `val_loss`.
3. Initialize the :class:`~pytorch_lightning.callbacks.ModelCheckpoint` callback, and set `monitor` to be the key of your quantity.
4. Pass the callback to the `checkpoint_callback` :class:`~pytorch_lightning.trainer.Trainer` flag.

.. code-block:: python

    from pytorch_lightning.callbacks import ModelCheckpoint

    class LitAutoEncoder(pl.LightningModule):
        def validation_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self.backbone(x)

            # 1. calculate loss
            loss = F.cross_entropy(y_hat, y)

            # 2. log `val_loss`
            self.log('val_loss', loss)

    # 3. Init ModelCheckpoint callback, monitoring 'val_loss'
    checkpoint_callback = ModelCheckpoint(monitor='val_loss')

    # 4. Pass your callback to checkpoint_callback trainer flag
    trainer = Trainer(checkpoint_callback=checkpoint_callback)

You can also control more advanced options, like `save_top_k` to save the best k models, `mode` for the monitored quantity (min/max/auto, where the mode is automatically inferred from the name of the monitored quantity), `save_weights_only` to save only the model weights, or `period` to set the interval of epochs between checkpoints, to avoid slowdowns.

.. code-block:: python

    from pytorch_lightning.callbacks import ModelCheckpoint

    class LitAutoEncoder(pl.LightningModule):
        def validation_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self.backbone(x)
            loss = F.cross_entropy(y_hat, y)
            self.log('val_loss', loss)

    # saves a file like: my/path/sample-mnist-epoch=02-val_loss=0.32.ckpt
    checkpoint_callback = ModelCheckpoint(
        monitor='val_loss',
        filepath='my/path/sample-mnist-{epoch:02d}-{val_loss:.2f}',
        save_top_k=3,
        mode='min')

    trainer = Trainer(checkpoint_callback=checkpoint_callback)

You can retrieve the best checkpoint after training by accessing the callback's `best_model_path` attribute:

.. code-block:: python

    checkpoint_callback = ModelCheckpoint(filepath='my/path/')
    trainer = Trainer(checkpoint_callback=checkpoint_callback)
    trainer.fit(model)
    checkpoint_callback.best_model_path
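
If you then want to use that best checkpoint, for example for evaluation, you can load it back into a fresh module with `load_from_checkpoint`. A minimal sketch, reusing the `LitAutoEncoder` defined above:

.. code-block:: python

    # load the best checkpoint found during training into a new model instance
    best_model = LitAutoEncoder.load_from_checkpoint(checkpoint_callback.best_model_path)
    best_model.eval()
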
Disabling checkpoints
---------------------

You can disable checkpointing by passing

.. testcode::

    trainer = Trainer(checkpoint_callback=False)

The Lightning checkpoint also saves the arguments passed into the LightningModule init
under the `module_arguments` key in the checkpoint.

.. code-block:: python

    class MyLightningModule(LightningModule):

        def __init__(self, learning_rate, *args, **kwargs):
            super().__init__()

    # all init args were saved to the checkpoint
    checkpoint = torch.load(CKPT_PATH)
    print(checkpoint['module_arguments'])
    # {'learning_rate': the_value}

Manual saving
=============

You can manually save checkpoints and restore your model from the checkpointed state.

.. code-block:: python

    model = MyLightningModule(hparams)
    trainer.fit(model)
    trainer.save_checkpoint("example.ckpt")
    new_model = MyLightningModule.load_from_checkpoint(checkpoint_path="example.ckpt")


******************
Checkpoint loading
******************

To load a model along with its weights, biases and `module_arguments`, use the following method:

.. code-block:: python

    model = MyLightningModule.load_from_checkpoint(PATH)

    print(model.learning_rate)
    # prints the learning_rate you used in this checkpoint

    model.eval()
    y_hat = model(x)
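
When running inference like this, it is also common to disable gradient tracking. A minimal sketch, assuming `x` is an input batch:

.. code-block:: python

    model.eval()
    # no gradients are needed for inference, which saves memory and compute
    with torch.no_grad():
        y_hat = model(x)
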
But if you don't want to use the values saved in the checkpoint, pass in your own here:

.. testcode::

    class LitModel(LightningModule):

        def __init__(self, in_dim, out_dim):
            super().__init__()
            self.save_hyperparameters()
            self.l1 = nn.Linear(self.hparams.in_dim, self.hparams.out_dim)

You can restore the model like this:

.. code-block:: python

    # if you train and save the model like this it will use these values when loading
    # the weights. But you can overwrite this
    LitModel(in_dim=32, out_dim=10)

    # uses in_dim=32, out_dim=10
    model = LitModel.load_from_checkpoint(PATH)

    # uses in_dim=128, out_dim=10
    model = LitModel.load_from_checkpoint(PATH, in_dim=128, out_dim=10)

Restoring Training State
========================

If you don't just want to load weights, but instead restore the full training state,
do the following:

.. code-block:: python

    model = LitModel()
    trainer = Trainer(resume_from_checkpoint='some/path/to/my_checkpoint.ckpt')

    # automatically restores model, epoch, step, LR schedulers, apex, etc...
    trainer.fit(model)