diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 6d41c846af..c2c76a915c 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -207,6 +207,92 @@ class CheckpointConnector: for scheduler, lrs_state in zip(self.trainer.lr_schedulers, lr_schedulers): scheduler['scheduler'].load_state_dict(lrs_state) + def restore_callbacks(self) -> None: + """ Restores all callbacks from the pre-loaded checkpoint. """ + if not self._loaded_checkpoint: + return + + if any(key in self._loaded_checkpoint for key in DEPRECATED_CHECKPOINT_KEYS): + raise ValueError( + "The checkpoint you're attempting to load follows an" + " outdated schema. You can upgrade to the current schema by running" + " `python -m pytorch_lightning.utilities.upgrade_checkpoint --file model.ckpt`" + " where `model.ckpt` is your checkpoint file." + ) + self.trainer.on_load_checkpoint(self._loaded_checkpoint) + + def restore_progress(self) -> None: + """ + Restores the training progress from the pre-loaded checkpoint. This currently includes only the global step + and current epoch. + """ + if not self._loaded_checkpoint: + return + + self.trainer.train_loop.global_step = self._loaded_checkpoint['global_step'] + self.trainer.train_loop.current_epoch = self._loaded_checkpoint['epoch'] + + # crash if max_epochs is lower then the current epoch from the checkpoint + if self.trainer.max_epochs is not None and self.trainer.current_epoch > self.trainer.max_epochs: + raise MisconfigurationException( + f"You restored a checkpoint with current_epoch={self.trainer.current_epoch}," + f" but you have set Trainer(max_epochs={self.trainer.max_epochs})." + ) + + # Division deals with global step stepping once per accumulated batch + # Inequality deals with different global step for odd vs even num_training_batches + n_accum = 1 if self.trainer.accumulate_grad_batches is None else self.trainer.accumulate_grad_batches + expected_steps = self.trainer.num_training_batches / n_accum + if self.trainer.num_training_batches != 0 and self.trainer.global_step % expected_steps > 1: + rank_zero_warn( + "You're resuming from a checkpoint that ended mid-epoch." + " Training will start from the beginning of the next epoch." + " This can cause unreliable results if further training is done," + " consider using an end of epoch checkpoint." + ) + + def restore_optimizers_and_schedulers(self) -> None: + """ Restores the optimizers and learning rate scheduler states from the pre-loaded checkpoint. """ + if not self._load_optimizer_states or not self._loaded_checkpoint: + return + + # validation + if "optimizer_states" not in self._loaded_checkpoint or "lr_schedulers" not in self._loaded_checkpoint: + raise KeyError( + "Trying to restore training state but checkpoint contains only the model." + " This is probably due to `ModelCheckpoint.save_weights_only` being set to `True`." + ) + self.restore_optimizers() + self.restore_lr_schedulers() + + def restore_optimizers(self) -> None: + """ Restores the optimizer states from the pre-loaded checkpoint. """ + if not self._load_optimizer_states or not self._loaded_checkpoint: + return + + # restore the optimizers + optimizer_states = self._loaded_checkpoint['optimizer_states'] + for optimizer, opt_state in zip(self.trainer.optimizers, optimizer_states): + optimizer.load_state_dict(opt_state) + + # move optimizer to GPU 1 weight at a time + # avoids OOM + if self.trainer.root_gpu is not None: + for state in optimizer.state.values(): + for k, v in state.items(): + if isinstance(v, torch.Tensor): + state[k] = v.cuda(self.trainer.root_gpu) + + def restore_lr_schedulers(self) -> None: + """ Restores the learning rate scheduler states from the pre-loaded checkpoint. """ + if not self._load_optimizer_states or not self._loaded_checkpoint: + return + + # restore the lr schedulers + lr_schedulers = self._loaded_checkpoint['lr_schedulers'] + for scheduler, lrs_state in zip(self.trainer.lr_schedulers, lr_schedulers): + scheduler['scheduler'].load_state_dict(lrs_state) + # ---------------------------------- # PRIVATE OPS # ----------------------------------