diff --git a/README.md b/README.md
index b2d286ece8..5dea863937 100644
--- a/README.md
+++ b/README.md
@@ -248,6 +248,7 @@ tensorboard --logdir /some/path
 - [Model saving](https://williamfalcon.github.io/pytorch-lightning/Trainer/Checkpointing/#model-saving)
 - [Model loading](https://williamfalcon.github.io/pytorch-lightning/LightningModule/methods/#load-from-metrics)
+- [Restoring training session](https://williamfalcon.github.io/pytorch-lightning/Trainer/Checkpointing/#restoring-training-session)
 
 ###### Computing cluster (SLURM)
diff --git a/docs/Trainer/Checkpointing.md b/docs/Trainer/Checkpointing.md
index db0d7d501d..b0a5281d57 100644
--- a/docs/Trainer/Checkpointing.md
+++ b/docs/Trainer/Checkpointing.md
@@ -18,5 +18,22 @@ checkpoint_callback = ModelCheckpoint(
 trainer = Trainer(checkpoint_callback=checkpoint_callback)
 ```
+---
+### Restoring training session
+You might want to not only load a model but also continue training it. Restoring the trainer state as well lets
+training resume from the epoch and global step you last left off. Note, however, that the dataloaders restart
+from the first batch (if you shuffle, this shouldn't matter).
+
+Lightning restores the session automatically when you pass an experiment with the same version and a saved checkpoint exists.
+``` {.python}
+from test_tube import Experiment
+
+# use the experiment version that has a saved checkpoint
+exp = Experiment(version=a_previous_version_with_a_saved_checkpoint)
+trainer = Trainer(experiment=exp, checkpoint_callback=checkpoint_callback)
+
+# the trainer is now restored
+```
+
diff --git a/docs/Trainer/index.md b/docs/Trainer/index.md
index 88d85abe3b..7d363a2191 100644
--- a/docs/Trainer/index.md
+++ b/docs/Trainer/index.md
@@ -21,6 +21,7 @@ But of course the fun is in all the advanced things it can do:
 - [Model saving](https://williamfalcon.github.io/pytorch-lightning/Trainer/Checkpointing/#model-saving)
 - [Model loading](https://williamfalcon.github.io/pytorch-lightning/LightningModule/methods/#load-from-metrics)
+- [Restoring training session](https://williamfalcon.github.io/pytorch-lightning/Trainer/Checkpointing/#restoring-training-session)
 
 **Computing cluster (SLURM)**
diff --git a/docs/index.md b/docs/index.md
index 0107897dd6..45f3e25cc2 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -28,6 +28,7 @@ one could be a seq-2-seq model, both (optionally) ran by the same trainer file.
 - [Model saving](https://williamfalcon.github.io/pytorch-lightning/Trainer/Checkpointing/#model-saving)
 - [Model loading](https://williamfalcon.github.io/pytorch-lightning/LightningModule/methods/#load-from-metrics)
+- [Restoring training session](https://williamfalcon.github.io/pytorch-lightning/Trainer/Checkpointing/#restoring-training-session)
 
 ###### Computing cluster (SLURM)
diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py
index 574c8061c5..1309e4c5ff 100644
--- a/pytorch_lightning/models/trainer.py
+++ b/pytorch_lightning/models/trainer.py
@@ -1,5 +1,5 @@
 """
-The trainer handles all the logic for running a val loop, training loop, distributing, etc...
+The trainer handles all the logic for running a val loop, training loop, distributing, etc.. .
""" import os @@ -247,6 +247,32 @@ class Trainer(TrainerIO): """ raise ModuleNotFoundError(msg) + def restore_state_if_existing_checkpoint(self): + # restore trainer state and model if there is a weight for this experiment + last_epoch = -1 + last_ckpt_name = None + + # find last epoch + checkpoints = os.listdir(self.checkpoint_callback.filepath) + for name in checkpoints: + # ignore hpc ckpts + if 'hpc_' in name: + continue + + if '.ckpt' in name: + epoch = name.split('epoch_')[1] + epoch = int(re.sub('[^0-9]', '' ,epoch)) + + if epoch > last_epoch: + last_epoch = epoch + last_ckpt_name = name + + # restore last checkpoint + if last_ckpt_name is not None: + last_ckpt_path = os.path.join(self.checkpoint_callback.filepath, last_ckpt_name) + self.restore(last_ckpt_path, self.on_gpu) + print(f'model and trainer restored from checkpoint: {last_ckpt_path}') + @property def data_parallel(self): return self.use_dp or self.use_ddp @@ -609,9 +635,6 @@ We recommend you switch to ddp if you want to use amp ref_model.trainer = self ref_model.experiment = self.experiment - # run tiny validation to make sure program won't crash during val - _ = self.validate(model, self.val_dataloader, max_batches=self.nb_sanity_val_steps) - # save exp to get started if self.proc_rank == 0: self.experiment.save() @@ -620,14 +643,23 @@ We recommend you switch to ddp if you want to use amp # if cluster resets state, the model will update with the saved weights self.model = model + # restore training and model before hpc call + self.restore_state_if_existing_checkpoint() + # enable cluster checkpointing # also restores training state + # hpc checkpoint overrides any other checkpoints loaded before if self.cluster is not None: # pragma: no cover self.enable_auto_hpc_walltime_manager() + # run tiny validation to make sure program won't crash during val + ref_model.on_sanity_check_start() + _ = self.validate(model, self.val_dataloader, max_batches=self.nb_sanity_val_steps) + # --------------------------- # CORE TRAINING LOOP # --------------------------- + self.__train() def __train(self): diff --git a/pytorch_lightning/root_module/hooks.py b/pytorch_lightning/root_module/hooks.py index 00ece23482..7661d34461 100644 --- a/pytorch_lightning/root_module/hooks.py +++ b/pytorch_lightning/root_module/hooks.py @@ -2,6 +2,14 @@ import torch class ModelHooks(torch.nn.Module): + + def on_sanity_check_start(self): + """ + Called before starting validate + :return: + """ + pass + def on_batch_start(self, data_batch): pass diff --git a/pytorch_lightning/root_module/model_saving.py b/pytorch_lightning/root_module/model_saving.py index 0765142cd5..ffd76387e7 100644 --- a/pytorch_lightning/root_module/model_saving.py +++ b/pytorch_lightning/root_module/model_saving.py @@ -60,6 +60,22 @@ class TrainerIO(object): # do the actual save torch.save(checkpoint, filepath) + def restore(self, checkpoint_path, on_gpu): + + if on_gpu: + checkpoint = torch.load(checkpoint_path) + else: + checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage) + + # load training state (affects trainer only) + self.restore_training_state(checkpoint) + + # load model state + model = self.__get_model() + + # load the state_dict on the model automatically + model.load_state_dict(checkpoint['state_dict']) + def dump_checkpoint(self): checkpoint = { @@ -200,15 +216,15 @@ class TrainerIO(object): # call model hook model.on_hpc_load(checkpoint) - def max_ckpt_in_folder(self, path): + def max_ckpt_in_folder(self, path, name_key='ckpt_'): files = 
-        files = [x for x in files if 'ckpt_' in x]
+        files = [x for x in files if name_key in x]
         if len(files) == 0:
             return 0
 
         ckpt_vs = []
         for name in files:
-            name = name.split('ckpt_')[-1]
+            name = name.split(name_key)[-1]
             name = re.sub('[^0-9]', '', name)
             ckpt_vs.append(int(name))
 
diff --git a/tests/test_models.py b/tests/test_models.py
index cd03d8b411..746f97ef25 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -26,6 +26,73 @@ np.random.seed(SEED)
 # ------------------------------------------------------------------------
 # TESTS
 # ------------------------------------------------------------------------
+
+def test_cpu_restore_training():
+    """
+    Verify that a training session can be resumed on CPU
+    :return:
+    """
+    hparams = get_hparams()
+    model = LightningTestModel(hparams)
+
+    save_dir = init_save_dir()
+
+    # exp file to get meta
+    test_exp_version = 10
+    exp = get_exp(False, version=test_exp_version)
+    exp.argparse(hparams)
+    exp.save()
+
+    trainer_options = dict(
+        max_nb_epochs=2,
+        val_check_interval=0.50,
+        val_percent_check=0.2,
+        train_percent_check=0.2,
+        experiment=exp,
+        checkpoint_callback=ModelCheckpoint(save_dir)
+    )
+
+    # fit model
+    trainer = Trainer(**trainer_options)
+    result = trainer.fit(model)
+    real_global_epoch = trainer.current_epoch
+
+    # training complete
+    assert result == 1, 'cpu restore training failed to complete'
+
+    # wipe-out trainer and model
+    # retrain with not much data... this simulates picking training back up after slurm
+    # we want to see if the weights come back correctly
+    new_exp = get_exp(False, version=test_exp_version)
+    trainer_options = dict(
+        max_nb_epochs=2,
+        val_check_interval=0.50,
+        val_percent_check=0.2,
+        train_percent_check=0.2,
+        experiment=new_exp,
+        checkpoint_callback=ModelCheckpoint(save_dir),
+    )
+    trainer = Trainer(**trainer_options)
+    model = LightningTestModel(hparams)
+
+    # set the sanity-check hook so we can predict before the model does any more training
+    def assert_good_acc():
+        assert trainer.current_epoch == real_global_epoch and trainer.current_epoch > 0
+
+        # if model and state loaded correctly, predictions will be good even though we
+        # haven't trained with the new loaded model
+        trainer.model.eval()
+        run_prediction(trainer.val_dataloader, trainer.model)
+
+    model.on_sanity_check_start = assert_good_acc
+
+    # by calling fit again, we trigger training, loading of the saved weights
+    # and our hook to predict using the current model before any more weight updates
+    trainer.fit(model)
+
+    clear_save_dir()
+
+
 def test_amp_gpu_ddp():
     """
     Make sure DDP + AMP work
@@ -56,6 +123,8 @@ def test_amp_gpu_ddp():
 
     run_gpu_model_test(trainer_options, model, hparams)
 
+
+
 def test_cpu_slurm_save_load():
     """
     Verify model save/load/checkpoint on CPU
@@ -622,10 +691,10 @@ def get_model():
     return model, hparams
 
 
-def get_exp(debug=True):
+def get_exp(debug=True, version=None):
     # set up exp object without actually saving logs
     root_dir = os.path.dirname(os.path.realpath(__file__))
-    exp = Experiment(debug=debug, save_dir=root_dir, name='tests_tt_dir')
+    exp = Experiment(debug=debug, save_dir=root_dir, name='tests_tt_dir', version=version)
     return exp
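
For readers skimming the patch, the heart of the change is the checkpoint-filename scan in `restore_state_if_existing_checkpoint`. Below is a small, self-contained sketch of that scan; the helper name `find_last_checkpoint` is ours for illustration (not part of the patch), and the `epoch_<N>.ckpt` filename convention is assumed from the code above.

``` {.python}
import os
import re


def find_last_checkpoint(checkpoint_dir):
    """Illustrative sketch of the scan in Trainer.restore_state_if_existing_checkpoint:
    return the path of the non-HPC '.ckpt' file with the highest epoch number, or None."""
    last_epoch = -1
    last_ckpt_name = None

    for name in os.listdir(checkpoint_dir):
        # skip hpc checkpoints (they are restored separately and take precedence)
        # and anything that is not an 'epoch_<N>' style .ckpt file
        if 'hpc_' in name or '.ckpt' not in name or 'epoch_' not in name:
            continue

        # keep only the digits that follow 'epoch_' in the filename
        digits = re.sub('[^0-9]', '', name.split('epoch_')[1])
        if not digits:
            continue

        epoch = int(digits)
        if epoch > last_epoch:
            last_epoch = epoch
            last_ckpt_name = name

    return None if last_ckpt_name is None else os.path.join(checkpoint_dir, last_ckpt_name)
```

With this patch, `Trainer.fit()` runs that scan (via `restore_state_if_existing_checkpoint`) before the sanity-check validation, so resuming only requires passing an experiment with a previously used version and the same `ModelCheckpoint` directory; if an HPC checkpoint is found afterwards, it overrides whatever was loaded here.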