2020-05-05 02:16:54 +00:00
|
|
|
.. testsetup:: *
|
|
|
|
|
|
|
|
import os
|
|
|
|
from pytorch_lightning.trainer.trainer import Trainer
|
|
|
|
from pytorch_lightning.core.lightning import LightningModule
|
|
|
|
|
|
|
|
|
2020-03-02 22:12:22 +00:00
|
|
|
Saving and loading weights
|
|
|
|
==========================
|
|
|
|
|
|
|
|
Lightning can automate saving and loading checkpoints.
|
|
|
|
|
|
|
|
Checkpoint saving
|
|
|
|
-----------------
|
2020-03-03 02:50:38 +00:00
|
|
|
A Lightning checkpoint has everything needed to restore a training session, including:
|
|
|
|
|
|
|
|
- 16-bit scaling factor (apex)
|
|
|
|
- Current epoch
|
|
|
|
- Global step
|
|
|
|
- Model state_dict
|
|
|
|
- State of all optimizers
|
|
|
|
- State of all learning rate schedulers
|
|
|
|
- State of all callbacks
|
|
|
|
- The hyperparameters used for that model, if passed in as ``hparams`` (an ``argparse.Namespace``)
|
|
|
|
|
|
|
|
Automatic saving
|
|
|
|
^^^^^^^^^^^^^^^^
|
2020-03-02 22:12:22 +00:00
|
|
|
|
|
|
|
Checkpointing is enabled by default, and checkpoints are saved to the current working directory.
|
|
|
|
To change the checkpoint path, pass in:
|
|
|
|
|
2020-05-05 02:16:54 +00:00
|
|
|
.. testcode::
|
2020-03-02 22:12:22 +00:00
|
|
|
|
2020-05-05 02:16:54 +00:00
|
|
|
trainer = Trainer(default_save_path='/your/path/to/save/checkpoints')
|
2020-03-02 22:12:22 +00:00
|
|
|
|
|
|
|
To modify the checkpointing behavior, pass in your own callback:
|
|
|
|
|
2020-05-05 02:16:54 +00:00
|
|
|
.. testcode::
|
2020-03-02 22:12:22 +00:00
|
|
|
|
|
|
|
from pytorch_lightning.callbacks import ModelCheckpoint
|
|
|
|
|
|
|
|
# DEFAULTS used by the Trainer
|
|
|
|
checkpoint_callback = ModelCheckpoint(
|
|
|
|
filepath=os.getcwd(),
|
2020-03-28 15:30:57 +00:00
|
|
|
save_top_k=True,
|
2020-03-02 22:12:22 +00:00
|
|
|
verbose=True,
|
|
|
|
monitor='val_loss',
|
|
|
|
mode='min',
|
|
|
|
prefix=''
|
|
|
|
)
|
|
|
|
|
|
|
|
trainer = Trainer(checkpoint_callback=checkpoint_callback)
|
|
|
|
|
|
|
|
|
|
|
|
Or disable it by passing:
|
|
|
|
|
2020-05-05 02:16:54 +00:00
|
|
|
.. testcode::
|
2020-03-02 22:12:22 +00:00
|
|
|
|
2020-05-05 02:16:54 +00:00
|
|
|
trainer = Trainer(checkpoint_callback=False)
|
2020-03-02 22:12:22 +00:00
|
|
|
|
|
|
|
|
|
|
|
The Lightning checkpoint also saves the ``hparams`` (hyperparameters) passed into the ``LightningModule`` ``__init__``.
|
|
|
|
|
|
|
|
.. note:: ``hparams`` is a `Namespace <https://docs.python.org/3/library/argparse.html#argparse.Namespace>`_.
|
|
|
|
|
2020-05-05 02:16:54 +00:00
|
|
|
.. testcode::
|
2020-03-02 22:12:22 +00:00
|
|
|
|
|
|
|
from argparse import Namespace
|
|
|
|
|
|
|
|
# usually these come from command line args
|
2020-03-04 14:33:39 +00:00
|
|
|
args = Namespace(learning_rate=0.001)
|
2020-03-02 22:12:22 +00:00
|
|
|
|
|
|
|
# define you module to have hparams as the first arg
|
|
|
|
# this means your checkpoint will have everything that went into making
|
|
|
|
# this model (in this case, learning rate)
|
2020-05-05 02:16:54 +00:00
|
|
|
class MyLightningModule(LightningModule):
|
2020-03-02 22:12:22 +00:00
|
|
|
|
2020-05-05 02:16:54 +00:00
|
|
|
def __init__(self, hparams, *args, **kwargs):
|
2020-03-02 22:12:22 +00:00
|
|
|
self.hparams = hparams
|
|
|
|
|
2020-03-03 02:50:38 +00:00
|
|
|
Manual saving
|
|
|
|
^^^^^^^^^^^^^
|
2020-04-05 15:10:44 +00:00
|
|
|
You can manually save checkpoints and restore your model from the checkpointed state.
|
2020-03-03 02:50:38 +00:00
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
2020-05-05 02:16:54 +00:00
|
|
|
model = MyLightningModule(hparams)
|
2020-04-05 15:10:44 +00:00
|
|
|
trainer.fit(model)
|
|
|
|
trainer.save_checkpoint("example.ckpt")
|
|
|
|
new_model = MyModel.load_from_checkpoint(checkpoint_path="example.ckpt")
|
2020-03-03 02:50:38 +00:00
|
|
|
|
2020-03-02 22:12:22 +00:00
|
|
|
Checkpoint Loading
|
|
|
|
------------------
|
|
|
|
|
2020-03-30 22:28:51 +00:00
|
|
|
To load a model along with its weights, biases and hyperparameters, use the following method:
|
2020-03-02 22:12:22 +00:00
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
model = MyLightingModule.load_from_checkpoint(PATH)
|
|
|
|
model.eval()
|
|
|
|
y_hat = model(x)
|
|
|
|
|
2020-03-30 22:28:51 +00:00
|
|
|
The above only works if you used ``hparams`` in your model definition:
|
2020-03-29 19:29:48 +00:00
|
|
|
|
2020-05-05 02:16:54 +00:00
|
|
|
.. testcode::
|
2020-03-30 22:28:51 +00:00
|
|
|
|
2020-05-05 02:16:54 +00:00
|
|
|
class LitModel(LightningModule):
|
2020-03-30 22:28:51 +00:00
|
|
|
|
|
|
|
def __init__(self, hparams):
|
|
|
|
self.hparams = hparams
|
|
|
|
self.l1 = nn.Linear(hparams.in_dim, hparams.out_dim)
|
|
|
|
|
|
|
|
But if you don't, and instead pass individual parameters:
|
|
|
|
|
2020-05-05 02:16:54 +00:00
|
|
|
.. testcode::
|
2020-03-30 22:28:51 +00:00
|
|
|
|
2020-05-05 02:16:54 +00:00
|
|
|
class LitModel(LightningModule):
|
2020-03-30 22:28:51 +00:00
|
|
|
|
|
|
|
def __init__(self, in_dim, out_dim):
|
|
|
|
self.l1 = nn.Linear(in_dim, out_dim)
|
|
|
|
|
|
|
|
You can restore the model like this:
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
2020-05-05 02:16:54 +00:00
|
|
|
model = LitModel.load_from_checkpoint(PATH, in_dim=128, out_dim=10)
|
2020-03-30 22:28:51 +00:00
|
|
|
|
|
|
|
|
|
|
|
Restoring Training State
|
|
|
|
------------------------
|
|
|
|
|
|
|
|
If you want to restore the full training state rather than just the model weights,
|
|
|
|
do the following:
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
model = LitModel()
|
|
|
|
trainer = Trainer(resume_from_checkpoint='some/path/to/my_checkpoint.ckpt')
|
2020-03-29 19:29:48 +00:00
|
|
|
|
2020-03-30 22:28:51 +00:00
|
|
|
# automatically restores model, epoch, step, LR schedulers, apex, etc...
|
|
|
|
trainer.fit(model)
|