updated tests and docs
parent d3f19c8321
commit 47a691f158
@@ -259,6 +259,7 @@ tensorboard --logdir /some/path

- [Model saving](https://williamfalcon.github.io/pytorch-lightning/Trainer/Checkpointing/#model-saving)
- [Model loading](https://williamfalcon.github.io/pytorch-lightning/LightningModule/methods/#load-from-metrics)
- [Restoring training session](https://williamfalcon.github.io/pytorch-lightning/Trainer/Checkpointing/#restoring-training-session)

###### Computing cluster (SLURM)
@@ -18,5 +18,22 @@ checkpoint_callback = ModelCheckpoint(
trainer = Trainer(checkpoint_callback=checkpoint_callback)
```

---
### Restoring training session
You might want to not only load a model but also continue training it. Use this method to
restore the trainer state as well. Training will continue from the epoch and global step where you last left off.
However, the dataloaders will start from the first batch again (if you shuffled, this shouldn't matter).

Lightning will restore the session if you pass an experiment with the same version and there's a saved checkpoint.
``` {.python}
from test_tube import Experiment

exp = Experiment(version=a_previous_version_with_a_saved_checkpoint)
trainer = Trainer(experiment=exp, checkpoint_callback=checkpoint_callback)

# the trainer is now restored
```
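To make the flow above concrete, here is a rough end-to-end sketch of resuming a run. The experiment name, version number, and checkpoint directory are illustrative, the exact import paths for `Trainer` and `ModelCheckpoint` can differ between versions, and `MyLightningModule` stands in for your own LightningModule; what matters is that the `Experiment` version and the checkpoint directory match the original run.

``` {.python}
from test_tube import Experiment
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint  # import path may vary by version

# re-attach to the experiment that produced the checkpoint (same name and version)
exp = Experiment(save_dir='/some/path/logs', name='my_exp', version=3)

# point the checkpoint callback at the directory the original run saved to
checkpoint_callback = ModelCheckpoint('/some/path/checkpoints')

model = MyLightningModule()  # illustrative: your own LightningModule
trainer = Trainer(experiment=exp, checkpoint_callback=checkpoint_callback)

# fit() picks up the saved checkpoint and continues from the stored epoch and global step
trainer.fit(model)
```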
@@ -21,6 +21,7 @@ But of course the fun is in all the advanced things it can do:

- [Model saving](https://williamfalcon.github.io/pytorch-lightning/Trainer/Checkpointing/#model-saving)
- [Model loading](https://williamfalcon.github.io/pytorch-lightning/LightningModule/methods/#load-from-metrics)
- [Restoring training session](https://williamfalcon.github.io/pytorch-lightning/Trainer/Checkpointing/#restoring-training-session)

**Computing cluster (SLURM)**
@@ -28,6 +28,7 @@ one could be a seq-2-seq model, both (optionally) run by the same trainer file.

- [Model saving](https://williamfalcon.github.io/pytorch-lightning/Trainer/Checkpointing/#model-saving)
- [Model loading](https://williamfalcon.github.io/pytorch-lightning/LightningModule/methods/#load-from-metrics)
- [Restoring training session](https://williamfalcon.github.io/pytorch-lightning/Trainer/Checkpointing/#restoring-training-session)

###### Computing cluster (SLURM)
@@ -52,6 +52,77 @@ def test_amp_gpu_ddp():
    run_gpu_model_test(trainer_options, model, hparams)


def test_cpu_restore_training():
    """
    Verify continuing a training session on CPU
    :return:
    """
    hparams = get_hparams()
    model = LightningTestModel(hparams)

    save_dir = init_save_dir()

    # exp file to get meta
    test_exp_version = 10
    exp = get_exp(False, version=test_exp_version)
    exp.argparse(hparams)
    exp.save()

    trainer_options = dict(
        max_nb_epochs=1,
        experiment=exp,
        checkpoint_callback=ModelCheckpoint(save_dir)
    )

    # fit model
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)
    real_global_step = trainer.global_step

    # training complete
    assert result == 1, 'cpu restore training failed to complete'

    # predict with trained model before saving
    # make a prediction
    for batch in model.test_dataloader:
        break

    x, y = batch
    x = x.view(x.size(0), -1)

    model.eval()
    pred_before_saving = model(x)

    # wipe out trainer and model
    # retrain with not much data... this simulates picking training back up after SLURM
    # we want to see if the weights come back correctly
    new_exp = get_exp(False, version=test_exp_version)
    trainer_options = dict(
        max_nb_epochs=1,
        experiment=new_exp,
        checkpoint_callback=ModelCheckpoint(save_dir),
    )
    trainer = Trainer(**trainer_options)
    model = LightningTestModel(hparams)

    # set the epoch start hook so we can predict before the model does the full training
    def assert_pred_same():
        assert trainer.global_step == real_global_step and trainer.global_step > 0

        # predict with loaded model to make sure answers are the same
        trainer.model.eval()
        new_pred = trainer.model(x)
        assert torch.all(torch.eq(pred_before_saving, new_pred)).item() == 1

    model.on_epoch_start = assert_pred_same

    # by calling fit again, we trigger training, loading weights from the cluster
    # and our hook to predict using current model before any more weight updates
    trainer.fit(model)

    clear_save_dir()


def test_cpu_slurm_save_load():
    """
    Verify model save/load/checkpoint on CPU
@@ -610,10 +681,10 @@ def get_model():
    return model, hparams


-def get_exp(debug=True):
+def get_exp(debug=True, version=None):
    # set up exp object without actually saving logs
    root_dir = os.path.dirname(os.path.realpath(__file__))
-    exp = Experiment(debug=debug, save_dir=root_dir, name='tests_tt_dir')
+    exp = Experiment(debug=debug, save_dir=root_dir, name='tests_tt_dir', version=version)
    return exp
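The new `version` argument is what lets `test_cpu_restore_training` point a second run at the experiment files of the first one; a minimal sketch of that pattern (the version number is illustrative):

``` {.python}
# first run: logs and checkpoints are written under a fixed experiment version
exp = get_exp(False, version=10)

# later run: the same version re-attaches to those files,
# which lets the Trainer find the checkpoint and restore the session
new_exp = get_exp(False, version=10)
```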