diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py
index 2ee77ada4c..4d439a3512 100644
--- a/pytorch_lightning/trainer/training_io.py
+++ b/pytorch_lightning/trainer/training_io.py
@@ -125,6 +125,8 @@ class TrainerIOMixin(ABC):
         self.early_stop_callback = None
         self.lr_schedulers = None
         self.optimizers = None
+        self.num_training_batches = None
+        self.accumulate_grad_batches = None
 
     def get_model(self):
         is_dp_module = isinstance(self.model, (LightningDistributedDataParallel,
@@ -305,10 +307,9 @@ class TrainerIOMixin(ABC):
         self.restore_training_state(checkpoint)
 
     def dump_checkpoint(self):
-
         checkpoint = {
-            'epoch': self.current_epoch,
-            'global_step': self.global_step
+            'epoch': self.current_epoch + 1,
+            'global_step': self.global_step + 1,
         }
 
         if self.checkpoint_callback is not None and self.checkpoint_callback is not False:
@@ -388,6 +389,17 @@ class TrainerIOMixin(ABC):
         self.global_step = checkpoint['global_step']
         self.current_epoch = checkpoint['epoch']
 
+        # Division deals with global step stepping once per accumulated batch
+        # Inequality deals with different global step for odd vs even num_training_batches
+        n_accum = 1 if self.accumulate_grad_batches is None else self.accumulate_grad_batches
+        expected_steps = self.num_training_batches / n_accum
+        if self.num_training_batches != 0 and self.global_step % expected_steps > 1:
+            warnings.warn(
+                "You're resuming from a checkpoint that ended mid-epoch. "
+                "This can cause unreliable results if further training is done, "
+                "consider using an end of epoch checkpoint. "
+            )
+
         # restore the optimizers
         optimizer_states = checkpoint['optimizer_states']
         for optimizer, opt_state in zip(self.optimizers, optimizer_states):
diff --git a/tests/test_restore_models.py b/tests/test_restore_models.py
index 92ab71364c..2bd033b097 100644
--- a/tests/test_restore_models.py
+++ b/tests/test_restore_models.py
@@ -214,8 +214,8 @@ def test_dp_resume(tmpdir):
     trainer.is_slurm_managing_tasks = True
     result = trainer.fit(model)
 
-    # track epoch before saving
-    real_global_epoch = trainer.current_epoch
+    # track epoch before saving. Increment since we finished the current epoch, don't want to rerun
+    real_global_epoch = trainer.current_epoch + 1
 
     # correct result and ok accuracy
     assert result == 1, 'amp + dp model failed to complete'
@@ -282,7 +282,8 @@ def test_cpu_restore_training(tmpdir):
     # fit model
     trainer = Trainer(**trainer_options)
     result = trainer.fit(model)
-    real_global_epoch = trainer.current_epoch
+    # Increment since we've finished the current epoch, don't want to rerun
+    real_global_epoch = trainer.current_epoch + 1
 
     # traning complete
    assert result == 1, 'amp + ddp model failed to complete'
diff --git a/tests/test_trainer.py b/tests/test_trainer.py
index 283b332aa0..231ef9508a 100644
--- a/tests/test_trainer.py
+++ b/tests/test_trainer.py
@@ -413,6 +413,74 @@ def test_multiple_val_dataloader(tmpdir):
     tutils.run_prediction(dataloader, trainer.model)
 
 
+def test_resume_from_checkpoint_epoch_restored(tmpdir):
+    """Verify resuming from checkpoint runs the right number of epochs"""
+    import types
+
+    tutils.reset_seed()
+
+    hparams = tutils.get_hparams()
+
+    def new_model():
+        # Create a model that tracks epochs and batches seen
+        model = LightningTestModel(hparams)
+        model.num_epochs_seen = 0
+        model.num_batches_seen = 0
+
+        def increment_epoch(self):
+            self.num_epochs_seen += 1
+
+        def increment_batch(self, _):
+            self.num_batches_seen += 1
+
+        # Bind the increment_epoch function on_epoch_end so that the
+        # model keeps track of the number of epochs it has seen.
+        model.on_epoch_end = types.MethodType(increment_epoch, model)
+        model.on_batch_start = types.MethodType(increment_batch, model)
+        return model
+
+    model = new_model()
+
+    trainer_options = dict(
+        show_progress_bar=False,
+        max_epochs=2,
+        train_percent_check=0.65,
+        val_percent_check=1,
+        checkpoint_callback=ModelCheckpoint(tmpdir, save_top_k=-1),
+        logger=False,
+        default_save_path=tmpdir,
+        early_stop_callback=False,
+        val_check_interval=0.5,
+    )
+
+    # fit model
+    trainer = Trainer(**trainer_options)
+    trainer.fit(model)
+
+    training_batches = trainer.num_training_batches
+
+    assert model.num_epochs_seen == 2
+    assert model.num_batches_seen == training_batches * 2
+
+    # Other checkpoints can be uncommented if/when resuming mid-epoch is supported
+    checkpoints = [
+        # os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_0.ckpt"),
+        os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_0_v0.ckpt"),
+        # os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_1.ckpt"),
+        os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_1_v0.ckpt"),
+    ]
+
+    for check in checkpoints:
+        next_model = new_model()
+        state = torch.load(check)
+
+        # Resume training
+        trainer_options['max_epochs'] = 4
+        new_trainer = Trainer(**trainer_options, resume_from_checkpoint=check)
+        new_trainer.fit(next_model)
+        assert state['global_step'] + next_model.num_batches_seen == training_batches * 4
+
+
 def test_multiple_test_dataloader(tmpdir):
     """Verify multiple test_dataloader."""
     tutils.reset_seed()