split `restore_training_state` into logical parts [2 / 2] (#7900)

Adrian Wälchli 2021-06-10 21:54:21 +02:00 committed by GitHub
parent d209b68979
commit c1eac483e9
3 changed files with 15 additions and 69 deletions
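
For orientation, here is how the refactored `restore_training_state` reads once the added lines in the diff below are assembled. This is a readability sketch, not new code: in the source the method is defined on `CheckpointConnector`, and the helpers it delegates to (`restore_callbacks`, `restore_progress`, `restore_optimizers_and_schedulers`) are not part of this diff, presumably having been introduced in part 1 of the split.

```python
# Assembled from the "+" lines of the first changed file below
# (sketch only; in the source this method lives on CheckpointConnector).
def restore_training_state(self, checkpoint: Dict[str, Any]) -> None:
    """
    Restore the trainer state from the pre-loaded checkpoint. This includes the precision settings, loop progress,
    optimizer states and learning rate scheduler states.
    """
    if not checkpoint:
        return

    # restore precision plugin (scaler etc.)
    self.trainer.precision_plugin.on_load_checkpoint(checkpoint)

    self.restore_callbacks()

    # restore progress (loops etc.)
    self.restore_progress()

    self.restore_optimizers_and_schedulers()
```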


@@ -15,7 +15,7 @@
 import os
 import re
 from pathlib import Path
-from typing import Optional, Union
+from typing import Any, Dict, Optional, Union
 
 import torch
@@ -115,7 +115,7 @@ class CheckpointConnector:
 
         # restore training state
         if self._loaded_checkpoint:
-            self.restore_training_state(self._loaded_checkpoint, self._load_optimizer_states)
+            self.restore_training_state(self._loaded_checkpoint)
 
         self.resume_end()
         return True
@@ -135,77 +135,23 @@ class CheckpointConnector:
         # restore model state_dict
         model.load_state_dict(checkpoint['state_dict'])
 
-    def restore_training_state(self, checkpoint, load_optimizer_states: bool = True):
+    def restore_training_state(self, checkpoint: Dict[str, Any]) -> None:
         """
-        Restore trainer state.
-        Model will get its change to update
-        :param checkpoint:
-        :return:
+        Restore the trainer state from the pre-loaded checkpoint. This includes the precision settings, loop progress,
+        optimizer states and learning rate scheduler states.
         """
-        # validation
-        if load_optimizer_states and ('optimizer_states' not in checkpoint or 'lr_schedulers' not in checkpoint):
-            raise KeyError(
-                'Trying to restore training state but checkpoint contains only the model.'
-                ' This is probably due to `ModelCheckpoint.save_weights_only` being set to `True`.'
-            )
-
-        if any([key in checkpoint for key in DEPRECATED_CHECKPOINT_KEYS]):
-            raise ValueError(
-                "The checkpoint you're attempting to load follows an"
-                " outdated schema. You can upgrade to the current schema by running"
-                " `python -m pytorch_lightning.utilities.upgrade_checkpoint --file model.ckpt`"
-                " where `model.ckpt` is your checkpoint file."
-            )
-
-        self.trainer.precision_plugin.on_load_checkpoint(checkpoint)
-
-        # restore callback states
-        self.trainer.on_load_checkpoint(checkpoint)
-
-        self.trainer.train_loop.global_step = checkpoint['global_step']
-        self.trainer.train_loop.current_epoch = checkpoint['epoch']
-
-        # crash if max_epochs is lower then the current epoch from the checkpoint
-        if self.trainer.max_epochs is not None and self.trainer.current_epoch > self.trainer.max_epochs:
-            m = f"""
-            you restored a checkpoint with current_epoch={self.trainer.current_epoch}
-            but the Trainer(max_epochs={self.trainer.max_epochs})
-            """
-            raise MisconfigurationException(m)
-
-        # Division deals with global step stepping once per accumulated batch
-        # Inequality deals with different global step for odd vs even num_training_batches
-        n_accum = 1 if self.trainer.accumulate_grad_batches is None else self.trainer.accumulate_grad_batches
-        expected_steps = self.trainer.num_training_batches / n_accum
-        if self.trainer.num_training_batches != 0 and self.trainer.global_step % expected_steps > 1:
-            rank_zero_warn(
-                "You're resuming from a checkpoint that ended mid-epoch."
-                " Training will start from the beginning of the next epoch."
-                " This can cause unreliable results if further training is done,"
-                " consider using an end of epoch checkpoint."
-            )
-
-        if not load_optimizer_states:
+        if not checkpoint:
             return
 
-        # restore the optimizers
-        optimizer_states = checkpoint['optimizer_states']
-        for optimizer, opt_state in zip(self.trainer.optimizers, optimizer_states):
-            optimizer.load_state_dict(opt_state)
+        # restore precision plugin (scaler etc.)
+        self.trainer.precision_plugin.on_load_checkpoint(checkpoint)
 
-            # move optimizer to GPU 1 weight at a time
-            # avoids OOM
-            if self.trainer.root_gpu is not None:
-                for state in optimizer.state.values():
-                    for k, v in state.items():
-                        if isinstance(v, torch.Tensor):
-                            state[k] = v.cuda(self.trainer.root_gpu)
+        self.restore_callbacks()
 
-        # restore the lr schedulers
-        lr_schedulers = checkpoint['lr_schedulers']
-        for scheduler, lrs_state in zip(self.trainer.lr_schedulers, lr_schedulers):
-            scheduler['scheduler'].load_state_dict(lrs_state)
+        # restore progress (loops etc.)
+        self.restore_progress()
+
+        self.restore_optimizers_and_schedulers()
 
     def restore_callbacks(self) -> None:
         """ Restores all callbacks from the pre-loaded checkpoint. """


@@ -86,7 +86,7 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
         callbacks=[early_stop_callback],
     )
 
-    with pytest.raises(MisconfigurationException, match=r'.*you restored a checkpoint with current_epoch*'):
+    with pytest.raises(MisconfigurationException, match=r'You restored a checkpoint with current_epoch'):
        new_trainer.fit(model)
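
The regex change above tracks the reworded message raised when a restored checkpoint's `current_epoch` exceeds `Trainer(max_epochs=...)`; that check used to sit inside `restore_training_state` (see the deleted lines in the first file) and presumably now lives in `restore_progress`. Below is a rough, self-contained sketch of the scenario this test exercises. The `TinyModel`, data, and Trainer settings are illustrative stand-ins rather than the repository's actual fixtures, and the sketch assumes the Lightning version this commit targets (where `resume_from_checkpoint` is still a Trainer argument).

```python
import os

import pytest
import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class TinyModel(pl.LightningModule):
    # illustrative stand-in for the model used in the real test suite
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


def test_resume_past_max_epochs_raises(tmpdir):
    model = TinyModel()
    train_data = DataLoader(TensorDataset(torch.randn(8, 32)), batch_size=4)

    # train for a few epochs and save a full checkpoint (optimizer + scheduler states included)
    trainer = pl.Trainer(default_root_dir=str(tmpdir), max_epochs=3, limit_train_batches=2)
    trainer.fit(model, train_data)
    ckpt_path = os.path.join(str(tmpdir), "resume.ckpt")
    trainer.save_checkpoint(ckpt_path)

    # resuming with max_epochs smaller than the checkpoint's current_epoch should fail
    new_trainer = pl.Trainer(
        default_root_dir=str(tmpdir),
        max_epochs=1,
        limit_train_batches=2,
        resume_from_checkpoint=ckpt_path,
    )
    with pytest.raises(MisconfigurationException, match='You restored a checkpoint with current_epoch'):
        new_trainer.fit(model, train_data)
```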


@@ -391,7 +391,7 @@ def test_model_checkpoint_only_weights(tmpdir):
 
     # assert restoring train state fails
     with pytest.raises(KeyError, match="checkpoint contains only the model"):
-        trainer.checkpoint_connector.restore_training_state(checkpoint)
+        trainer.checkpoint_connector.restore(new_weights_path)
 
 
 def test_model_freeze_unfreeze():
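
Similarly, the last hunk switches the weights-only assertion from calling `restore_training_state(checkpoint)` directly to going through the connector's `restore(path)` entry point, which now owns the "checkpoint contains only the model" failure. A rough sketch of that scenario, reusing the illustrative `TinyModel` and imports from the sketch above (the file name and test name here are hypothetical):

```python
def test_weights_only_checkpoint_cannot_restore_training_state(tmpdir):
    model = TinyModel()
    train_data = DataLoader(TensorDataset(torch.randn(8, 32)), batch_size=4)
    trainer = pl.Trainer(default_root_dir=str(tmpdir), max_epochs=1, limit_train_batches=2)
    trainer.fit(model, train_data)

    # a weights-only checkpoint stores the state_dict but no optimizer or lr-scheduler states
    weights_path = os.path.join(str(tmpdir), "weights_only.ckpt")
    trainer.save_checkpoint(weights_path, weights_only=True)

    # restoring full training state from it should fail, as the updated assertion expects
    with pytest.raises(KeyError, match="checkpoint contains only the model"):
        trainer.checkpoint_connector.restore(weights_path)
```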