From 47a691f1583c317b497d389d68a4bd7b2cc814d3 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Wed, 7 Aug 2019 07:09:37 -0400
Subject: [PATCH] updated tests and docs

---
 README.md                     |  1 +
 docs/Trainer/Checkpointing.md | 17 ++++++++
 docs/Trainer/index.md         |  1 +
 docs/index.md                 |  1 +
 tests/test_models.py          | 75 ++++++++++++++++++++++++++++++++++-
 5 files changed, 93 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index 7542eeef3a..76ced2327f 100644
--- a/README.md
+++ b/README.md
@@ -259,6 +259,7 @@ tensorboard --logdir /some/path
 
 - [Model saving](https://williamfalcon.github.io/pytorch-lightning/Trainer/Checkpointing/#model-saving)
 - [Model loading](https://williamfalcon.github.io/pytorch-lightning/LightningModule/methods/#load-from-metrics)
+- [Restoring training session](https://williamfalcon.github.io/pytorch-lightning/Trainer/Checkpointing/#restoring-training-session)
 
 ###### Computing cluster (SLURM)
 
diff --git a/docs/Trainer/Checkpointing.md b/docs/Trainer/Checkpointing.md
index db0d7d501d..b0a5281d57 100644
--- a/docs/Trainer/Checkpointing.md
+++ b/docs/Trainer/Checkpointing.md
@@ -18,5 +18,22 @@ checkpoint_callback = ModelCheckpoint(
 
 trainer = Trainer(checkpoint_callback=checkpoint_callback)
 ```
+---
+### Restoring training session
+You might want to not only load a model but also continue training it. Restoring the training session
+also restores the trainer state, so training continues from the epoch and global step where you left off.
+However, the dataloaders will start from the first batch again (if you shuffled, this shouldn't matter).
+
+Lightning will restore the session if you pass an experiment with the same version and there's a saved checkpoint.
+``` {.python}
+from test_tube import Experiment
+
+exp = Experiment(version=a_previous_version_with_a_saved_checkpoint)
+
+# pass the experiment (and the checkpoint callback from above) to the trainer
+trainer = Trainer(experiment=exp, checkpoint_callback=checkpoint_callback)
+# the trainer is now restored
+```
+
diff --git a/docs/Trainer/index.md b/docs/Trainer/index.md
index 88d85abe3b..7d363a2191 100644
--- a/docs/Trainer/index.md
+++ b/docs/Trainer/index.md
@@ -21,6 +21,7 @@ But of course the fun is in all the advanced things it can do:
 
 - [Model saving](https://williamfalcon.github.io/pytorch-lightning/Trainer/Checkpointing/#model-saving)
 - [Model loading](https://williamfalcon.github.io/pytorch-lightning/LightningModule/methods/#load-from-metrics)
+- [Restoring training session](https://williamfalcon.github.io/pytorch-lightning/Trainer/Checkpointing/#restoring-training-session)
 
 **Computing cluster (SLURM)**
 
diff --git a/docs/index.md b/docs/index.md
index 0107897dd6..45f3e25cc2 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -28,6 +28,7 @@ one could be a seq-2-seq model, both (optionally) ran by the same trainer file.
 - [Model saving](https://williamfalcon.github.io/pytorch-lightning/Trainer/Checkpointing/#model-saving)
 - [Model loading](https://williamfalcon.github.io/pytorch-lightning/LightningModule/methods/#load-from-metrics)
+- [Restoring training session](https://williamfalcon.github.io/pytorch-lightning/Trainer/Checkpointing/#restoring-training-session)
 
 ###### Computing cluster (SLURM)
 
diff --git a/tests/test_models.py b/tests/test_models.py
index 73ba2e43c1..1bc6b8dd0a 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -52,6 +52,77 @@ def test_amp_gpu_ddp():
     run_gpu_model_test(trainer_options, model, hparams)
 
 
+def test_cpu_restore_training():
+    """
+    Verify continuing a training session on CPU
+    :return:
+    """
+    hparams = get_hparams()
+    model = LightningTestModel(hparams)
+
+    save_dir = init_save_dir()
+
+    # exp file to get meta
+    test_exp_version = 10
+    exp = get_exp(False, version=test_exp_version)
+    exp.argparse(hparams)
+    exp.save()
+
+    trainer_options = dict(
+        max_nb_epochs=1,
+        experiment=exp,
+        checkpoint_callback=ModelCheckpoint(save_dir)
+    )
+
+    # fit model
+    trainer = Trainer(**trainer_options)
+    result = trainer.fit(model)
+    real_global_step = trainer.global_step
+
+    # training complete
+    assert result == 1, 'cpu model failed to complete'
+
+    # predict with trained model before saving
+    # make a prediction
+    for batch in model.test_dataloader:
+        break
+
+    x, y = batch
+    x = x.view(x.size(0), -1)
+
+    model.eval()
+    pred_before_saving = model(x)
+
+    # wipe out trainer and model
+    # retrain with not much data... this simulates picking training back up after slurm
+    # we want to see if the weights come back correctly
+    new_exp = get_exp(False, version=test_exp_version)
+    trainer_options = dict(
+        max_nb_epochs=1,
+        experiment=new_exp,
+        checkpoint_callback=ModelCheckpoint(save_dir),
+    )
+    trainer = Trainer(**trainer_options)
+    model = LightningTestModel(hparams)
+
+    # set the epoch start hook so we can predict before the model does the full training
+    def assert_pred_same():
+        assert trainer.global_step == real_global_step and trainer.global_step > 0
+
+        # predict with the loaded model to make sure the answers are the same
+        trainer.model.eval()
+        new_pred = trainer.model(x)
+        assert torch.all(torch.eq(pred_before_saving, new_pred)).item() == 1
+
+    model.on_epoch_start = assert_pred_same
+
+    # by calling fit again, we trigger training, loading weights from the cluster
+    # and our hook to predict using the current model before any more weight updates
+    trainer.fit(model)
+
+    clear_save_dir()
+
+
 def test_cpu_slurm_save_load():
     """
     Verify model save/load/checkpoint on CPU
@@ -610,10 +681,10 @@ def get_model():
     return model, hparams
 
 
-def get_exp(debug=True):
+def get_exp(debug=True, version=None):
     # set up exp object without actually saving logs
     root_dir = os.path.dirname(os.path.realpath(__file__))
-    exp = Experiment(debug=debug, save_dir=root_dir, name='tests_tt_dir')
+    exp = Experiment(debug=debug, save_dir=root_dir, name='tests_tt_dir', version=version)
     return exp
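
For reference, a minimal sketch of the restore flow that `test_cpu_restore_training` exercises: train once under a fixed experiment version, then build a fresh `Trainer` with the same version and checkpoint directory so it picks the session back up. It is written as if it sat alongside the helpers in `tests/test_models.py` (`get_hparams`, `get_exp`, `init_save_dir`, `clear_save_dir`, `LightningTestModel`, `Trainer`, and `ModelCheckpoint` are assumed to be in scope there); the function name is hypothetical and the sketch is illustrative, not part of the patch.

``` {.python}
# Sketch only: assumes the helpers defined in tests/test_models.py are in scope.
def restore_training_sketch():
    hparams = get_hparams()
    save_dir = init_save_dir()
    exp_version = 10  # reusing this version later is what triggers the restore

    # first session: train once and checkpoint under a known experiment version
    exp = get_exp(False, version=exp_version)
    exp.argparse(hparams)
    exp.save()
    trainer = Trainer(experiment=exp,
                      checkpoint_callback=ModelCheckpoint(save_dir),
                      max_nb_epochs=1)
    trainer.fit(LightningTestModel(hparams))

    # second session (e.g. after a SLURM requeue): same experiment version plus a
    # saved checkpoint, so the trainer restores epoch, global step, and weights
    # before doing any new weight updates
    new_exp = get_exp(False, version=exp_version)
    trainer = Trainer(experiment=new_exp,
                      checkpoint_callback=ModelCheckpoint(save_dir),
                      max_nb_epochs=1)
    trainer.fit(LightningTestModel(hparams))

    clear_save_dir()
```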