When a model is training, its performance changes as it continues to see more data. It is best practice to save the state of a model throughout the training process. This gives you a version of the model, *a checkpoint*, at each key point during training. Once training has completed, use the checkpoint that corresponds to the best performance you observed during training.
Checkpoints also enable your training to resume from where it was if the training process is interrupted.
PyTorch Lightning checkpoints are fully usable in plain PyTorch.
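For example, because the ``.ckpt`` file is an ordinary PyTorch-serialized dictionary, you can read it back with plain ``torch.load`` (a minimal sketch; the path is a placeholder):

.. code-block:: python

    import torch

    # a Lightning checkpoint is a regular file that plain torch.load can read
    checkpoint = torch.load("/path/to/checkpoint.ckpt", map_location="cpu")

    # the LightningModule weights live under the "state_dict" key
    state_dict = checkpoint["state_dict"]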
----
************************
Contents of a checkpoint
************************
A Lightning checkpoint contains a dump of the model's entire internal state. Unlike plain PyTorch, Lightning saves *everything* you need to restore a model even in the most complex distributed training environments.
Inside a Lightning checkpoint you'll find:
- 16-bit scaling factor (if using 16-bit precision training)
- Current epoch
- Global step
- LightningModule's state_dict
- State of all optimizers
- State of all learning rate schedulers
- State of all callbacks
- The hyperparameters (init arguments) with which the model was created, if ``self.save_hyperparameters()`` was called
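You can inspect these entries directly with ``torch.load`` (a brief sketch; the exact set of keys can vary with your Lightning version and Trainer configuration):

.. code-block:: python

    import torch

    checkpoint = torch.load("/path/to/checkpoint.ckpt", map_location="cpu")

    # e.g. "epoch", "global_step", "state_dict", "optimizer_states", ...
    print(checkpoint.keys())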
----

*****************
Save a checkpoint
*****************

Lightning automatically saves a checkpoint for you in your current working directory, with the state of your last training epoch. This ensures you can resume training in case it was interrupted.
.. code-block:: python

    # simply by using the Trainer you get automatic checkpointing
    trainer = Trainer()
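If a run is interrupted, you can continue from the saved state by passing the checkpoint path to ``fit``. A minimal sketch, assuming ``model`` is your LightningModule and the path points at an existing checkpoint:

.. code-block:: python

    # restores the model weights, epoch, optimizer and scheduler state, then continues training
    trainer = Trainer()
    trainer.fit(model, ckpt_path="some/path/my_checkpoint.ckpt")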
To change the checkpoint path, use the ``default_root_dir`` argument:
.. code-block:: python

    # saves checkpoints to 'some/path/' at every epoch end
    trainer = Trainer(default_root_dir="some/path/")
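For finer control over where checkpoints go and which ones are kept, you can configure the ``ModelCheckpoint`` callback yourself. A short sketch, assuming the model logs a ``val_loss`` metric:

.. code-block:: python

    from pytorch_lightning.callbacks import ModelCheckpoint

    # keep the three best checkpoints according to the logged "val_loss" metric
    checkpoint_callback = ModelCheckpoint(dirpath="some/path/", monitor="val_loss", save_top_k=3)
    trainer = Trainer(callbacks=[checkpoint_callback])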
----
*******************************
LightningModule from checkpoint
*******************************
To load a LightningModule along with its weights and hyperparameters, use the following method:
.. code-block:: python

    model = MyLightningModule.load_from_checkpoint("/path/to/checkpoint.ckpt")

    # disable randomness, dropout, etc...
    model.eval()

    # predict with the model
    y_hat = model(x)
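``load_from_checkpoint`` also accepts a ``map_location`` argument, which is useful when a checkpoint trained on GPU has to be loaded on a machine without one:

.. code-block:: python

    # load a GPU-trained checkpoint onto the CPU
    model = MyLightningModule.load_from_checkpoint("/path/to/checkpoint.ckpt", map_location="cpu")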
----
Save hyperparameters
====================
The LightningModule allows you to automatically save all the hyperparameters passed to ``__init__`` simply by calling ``self.save_hyperparameters()``.
If you used ``self.save_hyperparameters()`` in the ``__init__`` method of the LightningModule, the saved hyperparameters are restored automatically when loading, and you can override them to initialize the model with different values.
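For example, the following sketch (with a hypothetical ``learning_rate`` argument) saves the hyperparameter to the checkpoint, restores it when loading, and overrides it at load time:

.. code-block:: python

    import pytorch_lightning as pl


    class MyLightningModule(pl.LightningModule):
        def __init__(self, learning_rate):
            super().__init__()
            # stores learning_rate in the checkpoint and exposes it via self.hparams
            self.save_hyperparameters()
            ...


    # uses the learning_rate stored in the checkpoint
    model = MyLightningModule.load_from_checkpoint("/path/to/checkpoint.ckpt")
    print(model.hparams.learning_rate)

    # override the stored value at load time
    model = MyLightningModule.load_from_checkpoint("/path/to/checkpoint.ckpt", learning_rate=2e-4)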
In some cases, you may also pass entire PyTorch modules to the ``__init__`` method, which you don't want to save as hyperparameters due to their large size. If you didn't call ``self.save_hyperparameters()``, or you excluded them via ``save_hyperparameters(ignore=...)``, then you must pass the missing positional or keyword arguments when calling the ``load_from_checkpoint`` method:
.. code-block:: python

    class LitAutoEncoder(pl.LightningModule):
        def __init__(self, encoder, decoder):
            super().__init__()
            # the nn.Modules are not saved as hyperparameters; only their weights end up in the state_dict
            self.encoder = encoder
            self.decoder = decoder
            ...


    # pass the missing nn.Modules back in when loading
    model = LitAutoEncoder.load_from_checkpoint(PATH, encoder=encoder, decoder=decoder)