split `restore_training_state` into logical parts [1 / 2] (#7901)
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
parent 111287b4f9
commit d209b68979
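The hunk below introduces the individual restore methods; the rewiring of `restore_training_state` itself is left to part 2 / 2 of the split. For orientation only, a hedged sketch of how the pieces compose once a checkpoint has been pre-loaded into the connector (assumes `trainer` is a constructed `pytorch_lightning.Trainer` and that its checkpoint connector is reachable as `trainer.checkpoint_connector`; this snippet is not part of the diff):

# Illustrative sketch only -- not part of this commit.
connector = trainer.checkpoint_connector          # assumed attribute name
connector.restore_callbacks()                     # callback state + checkpoint schema check
connector.restore_progress()                      # global_step and current_epoch
connector.restore_optimizers_and_schedulers()     # optimizer and lr scheduler states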
@@ -207,6 +207,92 @@ class CheckpointConnector:
        for scheduler, lrs_state in zip(self.trainer.lr_schedulers, lr_schedulers):
            scheduler['scheduler'].load_state_dict(lrs_state)

    def restore_callbacks(self) -> None:
        """ Restores all callbacks from the pre-loaded checkpoint. """
        if not self._loaded_checkpoint:
            return

        if any(key in self._loaded_checkpoint for key in DEPRECATED_CHECKPOINT_KEYS):
            raise ValueError(
                "The checkpoint you're attempting to load follows an"
                " outdated schema. You can upgrade to the current schema by running"
                " `python -m pytorch_lightning.utilities.upgrade_checkpoint --file model.ckpt`"
                " where `model.ckpt` is your checkpoint file."
            )
        self.trainer.on_load_checkpoint(self._loaded_checkpoint)

    def restore_progress(self) -> None:
        """
        Restores the training progress from the pre-loaded checkpoint. This currently includes only the global step
        and current epoch.
        """
        if not self._loaded_checkpoint:
            return

        self.trainer.train_loop.global_step = self._loaded_checkpoint['global_step']
        self.trainer.train_loop.current_epoch = self._loaded_checkpoint['epoch']

        # crash if max_epochs is lower than the current epoch from the checkpoint
        if self.trainer.max_epochs is not None and self.trainer.current_epoch > self.trainer.max_epochs:
            raise MisconfigurationException(
                f"You restored a checkpoint with current_epoch={self.trainer.current_epoch},"
                f" but you have set Trainer(max_epochs={self.trainer.max_epochs})."
            )

        # Division deals with global step stepping once per accumulated batch
        # Inequality deals with different global step for odd vs even num_training_batches
        n_accum = 1 if self.trainer.accumulate_grad_batches is None else self.trainer.accumulate_grad_batches
        expected_steps = self.trainer.num_training_batches / n_accum
        if self.trainer.num_training_batches != 0 and self.trainer.global_step % expected_steps > 1:
            rank_zero_warn(
                "You're resuming from a checkpoint that ended mid-epoch."
                " Training will start from the beginning of the next epoch."
                " This can cause unreliable results if further training is done,"
                " consider using an end of epoch checkpoint."
            )

    def restore_optimizers_and_schedulers(self) -> None:
        """ Restores the optimizers and learning rate scheduler states from the pre-loaded checkpoint. """
        if not self._load_optimizer_states or not self._loaded_checkpoint:
            return

        # validation
        if "optimizer_states" not in self._loaded_checkpoint or "lr_schedulers" not in self._loaded_checkpoint:
            raise KeyError(
                "Trying to restore training state but checkpoint contains only the model."
                " This is probably due to `ModelCheckpoint.save_weights_only` being set to `True`."
            )
        self.restore_optimizers()
        self.restore_lr_schedulers()

    def restore_optimizers(self) -> None:
        """ Restores the optimizer states from the pre-loaded checkpoint. """
        if not self._load_optimizer_states or not self._loaded_checkpoint:
            return

        # restore the optimizers
        optimizer_states = self._loaded_checkpoint['optimizer_states']
        for optimizer, opt_state in zip(self.trainer.optimizers, optimizer_states):
            optimizer.load_state_dict(opt_state)

            # move optimizer to GPU 1 weight at a time
            # avoids OOM
            if self.trainer.root_gpu is not None:
                for state in optimizer.state.values():
                    for k, v in state.items():
                        if isinstance(v, torch.Tensor):
                            state[k] = v.cuda(self.trainer.root_gpu)

    def restore_lr_schedulers(self) -> None:
        """ Restores the learning rate scheduler states from the pre-loaded checkpoint. """
        if not self._load_optimizer_states or not self._loaded_checkpoint:
            return

        # restore the lr schedulers
        lr_schedulers = self._loaded_checkpoint['lr_schedulers']
        for scheduler, lrs_state in zip(self.trainer.lr_schedulers, lr_schedulers):
            scheduler['scheduler'].load_state_dict(lrs_state)

    # ----------------------------------
    # PRIVATE OPS
    # ----------------------------------
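To make the mid-epoch check in `restore_progress` concrete, here is a small worked example of the arithmetic with made-up numbers (plain Python, independent of Lightning):

# Worked example of the mid-epoch detection above (illustrative values only).
num_training_batches = 100        # batches per epoch
accumulate_grad_batches = 4       # optimizer steps once every 4 batches
expected_steps = num_training_batches / accumulate_grad_batches  # 25.0 optimizer steps per epoch

global_step = 83                  # restored from the checkpoint
print(global_step % expected_steps > 1)   # True  -> 83 % 25.0 == 8.0, mid-epoch, warning fires

global_step = 75                  # exactly three full epochs of optimizer steps
print(global_step % expected_steps > 1)   # False -> end-of-epoch checkpoint, no warning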
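For reference, these are the checkpoint fields the new methods read, shown as a minimal sketch of the payload; a real Lightning checkpoint also carries the model `state_dict`, callback states, and more, so this is not a complete schema:

# Minimal sketch of the keys consumed by the methods above (not a full checkpoint).
checkpoint = {
    'epoch': 3,                  # -> trainer.train_loop.current_epoch
    'global_step': 1200,         # -> trainer.train_loop.global_step
    'optimizer_states': [...],   # one state_dict per optimizer, in order
    'lr_schedulers': [...],      # one state_dict per configured lr scheduler
}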