split `restore_training_state` into logical parts [1 / 2] (#7901)

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
This commit is contained in:
Adrian Wälchli 2021-06-10 17:36:02 +02:00 committed by GitHub
parent 111287b4f9
commit d209b68979
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 86 additions and 0 deletions

View File

@ -207,6 +207,92 @@ class CheckpointConnector:
for scheduler, lrs_state in zip(self.trainer.lr_schedulers, lr_schedulers):
scheduler['scheduler'].load_state_dict(lrs_state)
def restore_callbacks(self) -> None:
    """
    Hands the pre-loaded checkpoint to the trainer so callbacks can restore their state.

    Does nothing when no checkpoint has been pre-loaded (``self._loaded_checkpoint`` is falsy).

    Raises:
        ValueError: If the checkpoint contains any of the keys listed in
            ``DEPRECATED_CHECKPOINT_KEYS``, i.e. it was written with an outdated schema.
    """
    checkpoint = self._loaded_checkpoint
    if not checkpoint:
        return

    # reject checkpoints that still carry keys from the old schema
    for deprecated_key in DEPRECATED_CHECKPOINT_KEYS:
        if deprecated_key in checkpoint:
            raise ValueError(
                "The checkpoint you're attempting to load follows an"
                " outdated schema. You can upgrade to the current schema by running"
                " `python -m pytorch_lightning.utilities.upgrade_checkpoint --file model.ckpt`"
                " where `model.ckpt` is your checkpoint file."
            )

    self.trainer.on_load_checkpoint(checkpoint)
def restore_progress(self) -> None:
    """
    Restores the training progress (global step and current epoch) from the pre-loaded checkpoint.

    Does nothing when no checkpoint has been pre-loaded. After the counters are written to
    ``trainer.train_loop``, two sanity checks run against the trainer configuration.

    Raises:
        MisconfigurationException: If ``Trainer(max_epochs=...)`` is set and is lower than
            the epoch reported by the trainer after restoring.
    """
    checkpoint = self._loaded_checkpoint
    if not checkpoint:
        return

    trainer = self.trainer
    trainer.train_loop.global_step = checkpoint['global_step']
    trainer.train_loop.current_epoch = checkpoint['epoch']

    # fail early if max_epochs is lower than the current epoch from the checkpoint
    max_epochs = trainer.max_epochs
    if max_epochs is not None and trainer.current_epoch > max_epochs:
        raise MisconfigurationException(
            f"You restored a checkpoint with current_epoch={trainer.current_epoch},"
            f" but you have set Trainer(max_epochs={trainer.max_epochs})."
        )

    # The global step advances once per accumulated batch, hence the division.
    # The `> 1` tolerance covers the differing global step for odd vs even num_training_batches.
    grad_accumulation = trainer.accumulate_grad_batches
    if grad_accumulation is None:
        grad_accumulation = 1
    steps_per_epoch = trainer.num_training_batches / grad_accumulation

    # with zero training batches there is no epoch boundary to compare against
    if trainer.num_training_batches == 0:
        return

    if trainer.global_step % steps_per_epoch > 1:
        rank_zero_warn(
            "You're resuming from a checkpoint that ended mid-epoch."
            " Training will start from the beginning of the next epoch."
            " This can cause unreliable results if further training is done,"
            " consider using an end of epoch checkpoint."
        )
def restore_optimizers_and_schedulers(self) -> None:
    """
    Restores optimizer and learning rate scheduler states from the pre-loaded checkpoint.

    Does nothing when optimizer-state loading is disabled (``self._load_optimizer_states`` is falsy)
    or when no checkpoint has been pre-loaded. Otherwise delegates to ``restore_optimizers`` and then
    ``restore_lr_schedulers``.

    Raises:
        KeyError: If the checkpoint lacks the ``"optimizer_states"`` or the ``"lr_schedulers"`` entry.
    """
    if not (self._load_optimizer_states and self._loaded_checkpoint):
        return

    # both entries must be present before anything is restored
    required_keys = ("optimizer_states", "lr_schedulers")
    if not all(key in self._loaded_checkpoint for key in required_keys):
        raise KeyError(
            "Trying to restore training state but checkpoint contains only the model."
            " This is probably due to `ModelCheckpoint.save_weights_only` being set to `True`."
        )

    self.restore_optimizers()
    self.restore_lr_schedulers()
def restore_optimizers(self) -> None:
    """
    Loads the saved optimizer states from the pre-loaded checkpoint into the trainer's optimizers.

    Does nothing when optimizer-state loading is disabled or no checkpoint has been pre-loaded.
    Optimizers and saved states are paired positionally with ``zip``, so any surplus on either
    side is ignored.
    """
    if not (self._load_optimizer_states and self._loaded_checkpoint):
        return

    saved_states = self._loaded_checkpoint['optimizer_states']
    for optimizer, saved_state in zip(self.trainer.optimizers, saved_states):
        optimizer.load_state_dict(saved_state)

        if self.trainer.root_gpu is None:
            continue

        # transfer the optimizer state to the GPU one tensor at a time to avoid OOM
        for param_state in optimizer.state.values():
            for name, value in param_state.items():
                if isinstance(value, torch.Tensor):
                    param_state[name] = value.cuda(self.trainer.root_gpu)
def restore_lr_schedulers(self) -> None:
    """
    Loads the saved learning rate scheduler states from the pre-loaded checkpoint.

    Does nothing when optimizer-state loading is disabled or no checkpoint has been pre-loaded.
    Each entry of ``trainer.lr_schedulers`` is indexed with ``'scheduler'`` to reach the object
    whose ``load_state_dict`` receives the saved state; entries and states are paired positionally.
    """
    if not (self._load_optimizer_states and self._loaded_checkpoint):
        return

    saved_states = self._loaded_checkpoint['lr_schedulers']
    for scheduler_config, saved_state in zip(self.trainer.lr_schedulers, saved_states):
        scheduler_config['scheduler'].load_state_dict(saved_state)
# ----------------------------------
# PRIVATE OPS
# ----------------------------------