split `restore_training_state` into logical parts [2 / 2] (#7900)
This commit is contained in:
parent
d209b68979
commit
c1eac483e9
|
@ -15,7 +15,7 @@
|
|||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
@ -115,7 +115,7 @@ class CheckpointConnector:
|
|||
|
||||
# restore training state
|
||||
if self._loaded_checkpoint:
|
||||
self.restore_training_state(self._loaded_checkpoint, self._load_optimizer_states)
|
||||
self.restore_training_state(self._loaded_checkpoint)
|
||||
|
||||
self.resume_end()
|
||||
return True
|
||||
|
@ -135,77 +135,23 @@ class CheckpointConnector:
|
|||
# restore model state_dict
|
||||
model.load_state_dict(checkpoint['state_dict'])
|
||||
|
||||
def restore_training_state(self, checkpoint, load_optimizer_states: bool = True):
|
||||
def restore_training_state(self, checkpoint: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Restore trainer state.
|
||||
Model will get its change to update
|
||||
:param checkpoint:
|
||||
:return:
|
||||
Restore the trainer state from the pre-loaded checkpoint. This includes the precision settings, loop progress,
|
||||
optimizer states and learning rate scheduler states.
|
||||
"""
|
||||
|
||||
# validation
|
||||
if load_optimizer_states and ('optimizer_states' not in checkpoint or 'lr_schedulers' not in checkpoint):
|
||||
raise KeyError(
|
||||
'Trying to restore training state but checkpoint contains only the model.'
|
||||
' This is probably due to `ModelCheckpoint.save_weights_only` being set to `True`.'
|
||||
)
|
||||
|
||||
if any([key in checkpoint for key in DEPRECATED_CHECKPOINT_KEYS]):
|
||||
raise ValueError(
|
||||
"The checkpoint you're attempting to load follows an"
|
||||
" outdated schema. You can upgrade to the current schema by running"
|
||||
" `python -m pytorch_lightning.utilities.upgrade_checkpoint --file model.ckpt`"
|
||||
" where `model.ckpt` is your checkpoint file."
|
||||
)
|
||||
|
||||
self.trainer.precision_plugin.on_load_checkpoint(checkpoint)
|
||||
|
||||
# restore callback states
|
||||
self.trainer.on_load_checkpoint(checkpoint)
|
||||
|
||||
self.trainer.train_loop.global_step = checkpoint['global_step']
|
||||
self.trainer.train_loop.current_epoch = checkpoint['epoch']
|
||||
|
||||
# crash if max_epochs is lower then the current epoch from the checkpoint
|
||||
if self.trainer.max_epochs is not None and self.trainer.current_epoch > self.trainer.max_epochs:
|
||||
m = f"""
|
||||
you restored a checkpoint with current_epoch={self.trainer.current_epoch}
|
||||
but the Trainer(max_epochs={self.trainer.max_epochs})
|
||||
"""
|
||||
raise MisconfigurationException(m)
|
||||
|
||||
# Division deals with global step stepping once per accumulated batch
|
||||
# Inequality deals with different global step for odd vs even num_training_batches
|
||||
n_accum = 1 if self.trainer.accumulate_grad_batches is None else self.trainer.accumulate_grad_batches
|
||||
expected_steps = self.trainer.num_training_batches / n_accum
|
||||
if self.trainer.num_training_batches != 0 and self.trainer.global_step % expected_steps > 1:
|
||||
rank_zero_warn(
|
||||
"You're resuming from a checkpoint that ended mid-epoch."
|
||||
" Training will start from the beginning of the next epoch."
|
||||
" This can cause unreliable results if further training is done,"
|
||||
" consider using an end of epoch checkpoint."
|
||||
)
|
||||
|
||||
if not load_optimizer_states:
|
||||
if not checkpoint:
|
||||
return
|
||||
|
||||
# restore the optimizers
|
||||
optimizer_states = checkpoint['optimizer_states']
|
||||
for optimizer, opt_state in zip(self.trainer.optimizers, optimizer_states):
|
||||
optimizer.load_state_dict(opt_state)
|
||||
# restore precision plugin (scaler etc.)
|
||||
self.trainer.precision_plugin.on_load_checkpoint(checkpoint)
|
||||
|
||||
# move optimizer to GPU 1 weight at a time
|
||||
# avoids OOM
|
||||
if self.trainer.root_gpu is not None:
|
||||
for state in optimizer.state.values():
|
||||
for k, v in state.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
state[k] = v.cuda(self.trainer.root_gpu)
|
||||
self.restore_callbacks()
|
||||
|
||||
# restore the lr schedulers
|
||||
lr_schedulers = checkpoint['lr_schedulers']
|
||||
for scheduler, lrs_state in zip(self.trainer.lr_schedulers, lr_schedulers):
|
||||
scheduler['scheduler'].load_state_dict(lrs_state)
|
||||
# restore progress (loops etc.)
|
||||
self.restore_progress()
|
||||
|
||||
self.restore_optimizers_and_schedulers()
|
||||
|
||||
def restore_callbacks(self) -> None:
|
||||
""" Restores all callbacks from the pre-loaded checkpoint. """
|
||||
|
|
|
@ -86,7 +86,7 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
|
|||
callbacks=[early_stop_callback],
|
||||
)
|
||||
|
||||
with pytest.raises(MisconfigurationException, match=r'.*you restored a checkpoint with current_epoch*'):
|
||||
with pytest.raises(MisconfigurationException, match=r'You restored a checkpoint with current_epoch'):
|
||||
new_trainer.fit(model)
|
||||
|
||||
|
||||
|
|
|
@ -391,7 +391,7 @@ def test_model_checkpoint_only_weights(tmpdir):
|
|||
|
||||
# assert restoring train state fails
|
||||
with pytest.raises(KeyError, match="checkpoint contains only the model"):
|
||||
trainer.checkpoint_connector.restore_training_state(checkpoint)
|
||||
trainer.checkpoint_connector.restore(new_weights_path)
|
||||
|
||||
|
||||
def test_model_freeze_unfreeze():
|
||||
|
|
Loading…
Reference in New Issue