updated tests and docs

William Falcon 2019-08-07 07:09:37 -04:00
parent d3f19c8321
commit 47a691f158
5 changed files with 93 additions and 2 deletions


@@ -259,6 +259,7 @@ tensorboard --logdir /some/path
- [Model saving](https://williamfalcon.github.io/pytorch-lightning/Trainer/Checkpointing/#model-saving)
- [Model loading](https://williamfalcon.github.io/pytorch-lightning/LightningModule/methods/#load-from-metrics)
- [Restoring training session](https://williamfalcon.github.io/pytorch-lightning/Trainer/Checkpointing/#restoring-training-session)
###### Computing cluster (SLURM)


@@ -18,5 +18,22 @@ checkpoint_callback = ModelCheckpoint(
trainer = Trainer(checkpoint_callback=checkpoint_callback)
```
---
### Restoring training session
You might want to not only load a model but also continue training it. Use this method to
restore the trainer state as well. Training will resume from the epoch and global step where you left off.
However, the dataloaders will start from the first batch again (if you shuffled, this shouldn't matter).
Lightning will restore the session if you pass an experiment with the same version and a saved checkpoint exists.
``` {.python}
from test_tube import Experiment
exp = Experiment(version=a_previous_version_with_a_saved_checkpoint)
trainer = Trainer(experiment=exp)
# the trainer is now restored
```
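
For a fuller sketch of the same flow (the `CoolModel` module, the paths, and version `10` are placeholders made up for illustration, and the `Trainer`/`ModelCheckpoint` imports are omitted just as in the snippets above):
``` {.python}
from test_tube import Experiment

# placeholder LightningModule and paths -- substitute your own
model = CoolModel()
checkpoint_callback = ModelCheckpoint('/path/to/checkpoints')

# reuse the version that wrote the checkpoint so Lightning can find it
exp = Experiment(save_dir='/path/to/logs', version=10)

trainer = Trainer(experiment=exp, checkpoint_callback=checkpoint_callback)

# fit restores the weights and trainer state, then continues training
trainer.fit(model)
```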


@@ -21,6 +21,7 @@ But of course the fun is in all the advanced things it can do:
- [Model saving](https://williamfalcon.github.io/pytorch-lightning/Trainer/Checkpointing/#model-saving)
- [Model loading](https://williamfalcon.github.io/pytorch-lightning/LightningModule/methods/#load-from-metrics)
- [Restoring training session](https://williamfalcon.github.io/pytorch-lightning/Trainer/Checkpointing/#restoring-training-session)
**Computing cluster (SLURM)**


@@ -28,6 +28,7 @@ one could be a seq-2-seq model, both (optionally) ran by the same trainer file.
- [Model saving](https://williamfalcon.github.io/pytorch-lightning/Trainer/Checkpointing/#model-saving)
- [Model loading](https://williamfalcon.github.io/pytorch-lightning/LightningModule/methods/#load-from-metrics)
- [Restoring training session](https://williamfalcon.github.io/pytorch-lightning/Trainer/Checkpointing/#restoring-training-session)
###### Computing cluster (SLURM)


@@ -52,6 +52,77 @@ def test_amp_gpu_ddp():
    run_gpu_model_test(trainer_options, model, hparams)


def test_cpu_restore_training():
    """
    Verify continuing a training session on CPU
    :return:
    """
    hparams = get_hparams()
    model = LightningTestModel(hparams)

    save_dir = init_save_dir()

    # exp file to get meta
    test_exp_version = 10
    exp = get_exp(False, version=test_exp_version)
    exp.argparse(hparams)
    exp.save()

    trainer_options = dict(
        max_nb_epochs=1,
        experiment=exp,
        checkpoint_callback=ModelCheckpoint(save_dir)
    )

    # fit model
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)
    real_global_step = trainer.global_step

    # training complete
    assert result == 1, 'cpu restore training failed to complete'

    # predict with trained model before saving
    # make a prediction
    for batch in model.test_dataloader:
        break

    x, y = batch
    x = x.view(x.size(0), -1)

    model.eval()
    pred_before_saving = model(x)

    # wipe-out trainer and model
    # retrain with not much data... this simulates picking training back up after slurm
    # we want to see if the weights come back correctly
    new_exp = get_exp(False, version=test_exp_version)
    trainer_options = dict(
        max_nb_epochs=1,
        experiment=new_exp,
        checkpoint_callback=ModelCheckpoint(save_dir),
    )
    trainer = Trainer(**trainer_options)
    model = LightningTestModel(hparams)

    # set the epoch start hook so we can predict before the model does the full training
    def assert_pred_same():
        assert trainer.global_step == real_global_step and trainer.global_step > 0

        # predict with loaded model to make sure answers are the same
        trainer.model.eval()
        new_pred = trainer.model(x)
        assert torch.all(torch.eq(pred_before_saving, new_pred)).item() == 1

    model.on_epoch_start = assert_pred_same

    # by calling fit again, we trigger training, loading weights from the cluster
    # and our hook to predict using current model before any more weight updates
    trainer.fit(model)

    clear_save_dir()


def test_cpu_slurm_save_load():
    """
    Verify model save/load/checkpoint on CPU
@@ -610,10 +681,10 @@ def get_model():
    return model, hparams


def get_exp(debug=True):
def get_exp(debug=True, version=None):
    # set up exp object without actually saving logs
    root_dir = os.path.dirname(os.path.realpath(__file__))
    exp = Experiment(debug=debug, save_dir=root_dir, name='tests_tt_dir')
    exp = Experiment(debug=debug, save_dir=root_dir, name='tests_tt_dir', version=version)
    return exp