From c1eac483e93494a173f5db7527df0008f551c734 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Thu, 10 Jun 2021 21:54:21 +0200
Subject: [PATCH] split `restore_training_state` into logical parts [2 / 2] (#7900)

---
 .../connectors/checkpoint_connector.py | 80 +++----------------
 tests/callbacks/test_early_stopping.py |  2 +-
 tests/trainer/test_trainer.py          |  2 +-
 3 files changed, 15 insertions(+), 69 deletions(-)

diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index c2c76a915c..5c64dc4a87 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -15,7 +15,7 @@
 import os
 import re
 from pathlib import Path
-from typing import Optional, Union
+from typing import Any, Dict, Optional, Union

 import torch

@@ -115,7 +115,7 @@ class CheckpointConnector:

         # restore training state
         if self._loaded_checkpoint:
-            self.restore_training_state(self._loaded_checkpoint, self._load_optimizer_states)
+            self.restore_training_state(self._loaded_checkpoint)

         self.resume_end()
         return True
@@ -135,77 +135,23 @@ class CheckpointConnector:
         # restore model state_dict
         model.load_state_dict(checkpoint['state_dict'])

-    def restore_training_state(self, checkpoint, load_optimizer_states: bool = True):
+    def restore_training_state(self, checkpoint: Dict[str, Any]) -> None:
         """
-        Restore trainer state.
-        Model will get its change to update
-        :param checkpoint:
-        :return:
+        Restore the trainer state from the pre-loaded checkpoint. This includes the precision settings, loop progress,
+        optimizer states and learning rate scheduler states.
         """
-
-        # validation
-        if load_optimizer_states and ('optimizer_states' not in checkpoint or 'lr_schedulers' not in checkpoint):
-            raise KeyError(
-                'Trying to restore training state but checkpoint contains only the model.'
-                ' This is probably due to `ModelCheckpoint.save_weights_only` being set to `True`.'
-            )
-
-        if any([key in checkpoint for key in DEPRECATED_CHECKPOINT_KEYS]):
-            raise ValueError(
-                "The checkpoint you're attempting to load follows an"
-                " outdated schema. You can upgrade to the current schema by running"
-                " `python -m pytorch_lightning.utilities.upgrade_checkpoint --file model.ckpt`"
-                " where `model.ckpt` is your checkpoint file."
-            )
-
-        self.trainer.precision_plugin.on_load_checkpoint(checkpoint)
-
-        # restore callback states
-        self.trainer.on_load_checkpoint(checkpoint)
-
-        self.trainer.train_loop.global_step = checkpoint['global_step']
-        self.trainer.train_loop.current_epoch = checkpoint['epoch']
-
-        # crash if max_epochs is lower then the current epoch from the checkpoint
-        if self.trainer.max_epochs is not None and self.trainer.current_epoch > self.trainer.max_epochs:
-            m = f"""
-            you restored a checkpoint with current_epoch={self.trainer.current_epoch}
-            but the Trainer(max_epochs={self.trainer.max_epochs})
-            """
-            raise MisconfigurationException(m)
-
-        # Division deals with global step stepping once per accumulated batch
-        # Inequality deals with different global step for odd vs even num_training_batches
-        n_accum = 1 if self.trainer.accumulate_grad_batches is None else self.trainer.accumulate_grad_batches
-        expected_steps = self.trainer.num_training_batches / n_accum
-        if self.trainer.num_training_batches != 0 and self.trainer.global_step % expected_steps > 1:
-            rank_zero_warn(
-                "You're resuming from a checkpoint that ended mid-epoch."
-                " Training will start from the beginning of the next epoch."
-                " This can cause unreliable results if further training is done,"
-                " consider using an end of epoch checkpoint."
-            )
-
-        if not load_optimizer_states:
+        if not checkpoint:
             return

-        # restore the optimizers
-        optimizer_states = checkpoint['optimizer_states']
-        for optimizer, opt_state in zip(self.trainer.optimizers, optimizer_states):
-            optimizer.load_state_dict(opt_state)
+        # restore precision plugin (scaler etc.)
+        self.trainer.precision_plugin.on_load_checkpoint(checkpoint)

-            # move optimizer to GPU 1 weight at a time
-            # avoids OOM
-            if self.trainer.root_gpu is not None:
-                for state in optimizer.state.values():
-                    for k, v in state.items():
-                        if isinstance(v, torch.Tensor):
-                            state[k] = v.cuda(self.trainer.root_gpu)
+        self.restore_callbacks()

-        # restore the lr schedulers
-        lr_schedulers = checkpoint['lr_schedulers']
-        for scheduler, lrs_state in zip(self.trainer.lr_schedulers, lr_schedulers):
-            scheduler['scheduler'].load_state_dict(lrs_state)
+        # restore progress (loops etc.)
+        self.restore_progress()
+
+        self.restore_optimizers_and_schedulers()

     def restore_callbacks(self) -> None:
         """ Restores all callbacks from the pre-loaded checkpoint. """
diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py
index e62ddb90ff..d7a6f15459 100644
--- a/tests/callbacks/test_early_stopping.py
+++ b/tests/callbacks/test_early_stopping.py
@@ -86,7 +86,7 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
         callbacks=[early_stop_callback],
     )

-    with pytest.raises(MisconfigurationException, match=r'.*you restored a checkpoint with current_epoch*'):
+    with pytest.raises(MisconfigurationException, match=r'You restored a checkpoint with current_epoch'):
         new_trainer.fit(model)


diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index d5e3ea919c..76e98329d2 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -391,7 +391,7 @@ def test_model_checkpoint_only_weights(tmpdir):

     # assert restoring train state fails
     with pytest.raises(KeyError, match="checkpoint contains only the model"):
-        trainer.checkpoint_connector.restore_training_state(checkpoint)
+        trainer.checkpoint_connector.restore(new_weights_path)


 def test_model_freeze_unfreeze():
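Note on the refactor (illustrative, not part of the patch): the sketch below is a minimal, self-contained Python stand-in for the control flow that `restore_training_state` follows after this change, where one monolithic method is split into focused helpers. `SketchTrainer`, `SketchPrecisionPlugin` and `SketchCheckpointConnector` are hypothetical names used only for this example; the real implementation lives in `pytorch_lightning/trainer/connectors/checkpoint_connector.py`.

# Illustrative sketch only -- simplified stand-ins, not the Lightning classes themselves.
from typing import Any, Dict


class SketchPrecisionPlugin:
    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # In Lightning this step would e.g. reload an AMP grad-scaler state from the checkpoint.
        print("precision state restored")


class SketchTrainer:
    def __init__(self) -> None:
        self.precision_plugin = SketchPrecisionPlugin()


class SketchCheckpointConnector:
    def __init__(self, trainer: SketchTrainer) -> None:
        self.trainer = trainer

    def restore_training_state(self, checkpoint: Dict[str, Any]) -> None:
        # Mirrors the shape of the refactored method: bail out on an empty checkpoint,
        # then restore each concern through a dedicated helper.
        if not checkpoint:
            return
        self.trainer.precision_plugin.on_load_checkpoint(checkpoint)
        self.restore_callbacks()
        self.restore_progress()
        self.restore_optimizers_and_schedulers()

    def restore_callbacks(self) -> None:
        print("callback states restored")

    def restore_progress(self) -> None:
        print("global_step / current_epoch restored")

    def restore_optimizers_and_schedulers(self) -> None:
        print("optimizer and lr scheduler states restored")


if __name__ == "__main__":
    connector = SketchCheckpointConnector(SketchTrainer())
    connector.restore_training_state({"state_dict": {}, "epoch": 3, "global_step": 120})

Splitting the restore path this way lets each concern (precision, callbacks, loop progress, optimizers/schedulers) be restored or skipped independently, which is what allows the trainer-facing call site to shrink to a single `restore_training_state(checkpoint)` call in the patch above.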