Fix checkpoint values when saving and resetting the tuner state (#11518)

This commit is contained in:
Carlos Mocholí 2022-01-20 19:54:40 +01:00 committed by GitHub
parent 1a25363617
commit 075b8801c9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 67 additions and 88 deletions

View File

@ -216,6 +216,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- `Trainer.logged_metrics` now always contains scalar tensors, even when a Python scalar was logged ([#11270](https://github.com/PyTorchLightning/pytorch-lightning/pull/11270))
- The tuner now uses the checkpoint connector to copy and restore its state ([#11518](https://github.com/PyTorchLightning/pytorch-lightning/pull/11518))
- Changed `MisconfigurationException` to `ModuleNotFoundError` when `rich` isn't available ([#11360](https://github.com/PyTorchLightning/pytorch-lightning/pull/11360))

View File

@ -222,7 +222,7 @@ class CheckpointConnector:
assert self.trainer.state.fn is not None
state_dict = self._loaded_checkpoint.get("loops")
if state_dict is not None and self.trainer.state.fn != TrainerFn.TUNING:
if state_dict is not None:
if self.trainer.state.fn == TrainerFn.FITTING:
self.trainer.fit_loop.load_state_dict(state_dict["fit_loop"])
elif self.trainer.state.fn == TrainerFn.VALIDATING:

View File

@ -14,14 +14,13 @@
import logging
import os
import uuid
from typing import Optional, Tuple
from typing import Any, Dict, Optional, Tuple
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.loggers.base import DummyLogger
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.data import has_len_all_ranks
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.memory import garbage_collection_cuda, is_oom_error
@ -59,15 +58,17 @@ def scale_batch_size(
" Please disable the feature or incorporate the dataloader into the model."
)
# Arguments we adjust during the batch size finder, save for restoring
__scale_batch_dump_params(trainer)
# Save initial model, that is loaded after batch size is found
ckpt_path = os.path.join(trainer.default_root_dir, f".scale_batch_size_{uuid.uuid4()}.ckpt")
trainer.fit_loop.current_epoch -= 1
trainer.fit_loop.global_step -= 1
trainer.save_checkpoint(ckpt_path)
trainer.fit_loop.current_epoch += 1
trainer.fit_loop.global_step += 1
params = __scale_batch_dump_params(trainer)
# Set to values that are required by the algorithm
__scale_batch_reset_params(trainer, model, steps_per_trial)
# Save initial model, that is loaded after batch size is found
save_path = os.path.join(trainer.default_root_dir, f"scale_batch_size_temp_model_{uuid.uuid4()}.ckpt")
trainer.save_checkpoint(str(save_path))
__scale_batch_reset_params(trainer, steps_per_trial)
if trainer.progress_bar_callback:
trainer.progress_bar_callback.disable()
@ -85,37 +86,28 @@ def scale_batch_size(
log.info(f"Finished batch size finder, will continue with full run using batch size {new_size}")
# Restore initial state of model
if trainer.is_global_zero:
trainer.checkpoint_connector.restore(str(save_path))
fs = get_filesystem(str(save_path))
if fs.exists(save_path):
fs.rm(save_path)
trainer.checkpoint_connector.restore(ckpt_path)
trainer.strategy.remove_checkpoint(ckpt_path)
__scale_batch_restore_params(trainer, params)
# Finish by resetting variables so trainer is ready to fit model
__scale_batch_restore_params(trainer)
if trainer.progress_bar_callback:
trainer.progress_bar_callback.enable()
return new_size
def __scale_batch_dump_params(trainer: "pl.Trainer") -> None:
# Prevent going into infinite loop
trainer.__dumped_params = {
"auto_lr_find": trainer.auto_lr_find,
"current_epoch": trainer.current_epoch,
"global_step": trainer.global_step,
"max_steps": trainer.max_steps,
def __scale_batch_dump_params(trainer: "pl.Trainer") -> Dict[str, Any]:
return {
"max_steps": trainer.fit_loop.max_steps,
"logger": trainer.logger,
"callbacks": trainer.callbacks,
"checkpoint_callback": trainer.checkpoint_callback,
"auto_scale_batch_size": trainer.auto_scale_batch_size,
"auto_lr_find": trainer.auto_lr_find,
"limit_train_batches": trainer.limit_train_batches,
"model": trainer.model,
}
def __scale_batch_reset_params(trainer: "pl.Trainer", model: "pl.LightningModule", steps_per_trial: int) -> None:
def __scale_batch_reset_params(trainer: "pl.Trainer", steps_per_trial: int) -> None:
trainer.auto_scale_batch_size = None # prevent recursion
trainer.auto_lr_find = False # avoid lr find being called multiple times
trainer.fit_loop.current_epoch = 0
@ -123,21 +115,15 @@ def __scale_batch_reset_params(trainer: "pl.Trainer", model: "pl.LightningModule
trainer.logger = DummyLogger() if trainer.logger is not None else None
trainer.callbacks = [] # not needed before full run
trainer.limit_train_batches = 1.0
trainer.optimizers, trainer.strategy.lr_schedulers = [], [] # required for saving
trainer.model = model # required for saving
def __scale_batch_restore_params(trainer: "pl.Trainer") -> None:
trainer.auto_lr_find = trainer.__dumped_params["auto_lr_find"]
trainer.fit_loop.current_epoch = trainer.__dumped_params["current_epoch"]
trainer.fit_loop.global_step = trainer.__dumped_params["global_step"]
trainer.fit_loop.max_steps = trainer.__dumped_params["max_steps"]
trainer.logger = trainer.__dumped_params["logger"]
trainer.callbacks = trainer.__dumped_params["callbacks"]
trainer.auto_scale_batch_size = trainer.__dumped_params["auto_scale_batch_size"]
trainer.limit_train_batches = trainer.__dumped_params["limit_train_batches"]
trainer.model = trainer.__dumped_params["model"]
del trainer.__dumped_params
def __scale_batch_restore_params(trainer: "pl.Trainer", params: Dict[str, Any]) -> None:
trainer.auto_scale_batch_size = params["auto_scale_batch_size"]
trainer.auto_lr_find = params["auto_lr_find"]
trainer.fit_loop.max_steps = params["max_steps"]
trainer.logger = params["logger"]
trainer.callbacks = params["callbacks"]
trainer.limit_train_batches = params["limit_train_batches"]
def _run_power_scaling(

View File

@ -16,7 +16,7 @@ import logging
import os
import uuid
from functools import wraps
from typing import Optional, Sequence
from typing import Any, Dict, Optional, Sequence
import numpy as np
import torch
@ -27,7 +27,6 @@ from pytorch_lightning.callbacks import Callback
from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers, _set_scheduler_opt_idx
from pytorch_lightning.loggers.base import DummyLogger
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr
from pytorch_lightning.utilities.types import LRSchedulerConfig
@ -203,36 +202,25 @@ def lr_find(
if update_attr:
lr_attr_name = _determine_lr_attr_name(trainer, model)
save_path = os.path.join(trainer.default_root_dir, f"lr_find_temp_model_{uuid.uuid4()}.ckpt")
# Save initial model, that is loaded after learning rate is found
ckpt_path = os.path.join(trainer.default_root_dir, f".lr_find_{uuid.uuid4()}.ckpt")
trainer.fit_loop.current_epoch -= 1
trainer.fit_loop.global_step -= 1
trainer.save_checkpoint(ckpt_path)
trainer.fit_loop.current_epoch += 1
trainer.fit_loop.global_step += 1
params = __lr_finder_dump_params(trainer)
__lr_finder_dump_params(trainer, model)
# Prevent going into infinite loop
trainer.auto_lr_find = False
# Set to values that are required by the algorithm
__lr_finder_reset_params(trainer, num_training, early_stop_threshold)
# Initialize lr finder object (stores results)
lr_finder = _LRFinder(mode, min_lr, max_lr, num_training)
# Use special lr logger callback
trainer.callbacks = [_LRCallback(num_training, early_stop_threshold, progress_bar_refresh_rate=1)]
# No logging
trainer.logger = DummyLogger() if trainer.logger is not None else None
# Max step set to number of iterations
trainer.fit_loop.max_steps = num_training
# Disable standard progress bar for fit
if trainer.progress_bar_callback:
trainer.progress_bar_callback.disable()
# Required for saving the model
trainer.optimizers, trainer.strategy.lr_schedulers = [], []
trainer.model = model
# Dump model checkpoint
trainer.save_checkpoint(str(save_path))
# Configure optimizer and scheduler
trainer.strategy.setup_optimizers = lr_finder._exchange_scheduler(trainer, model)
@ -247,15 +235,11 @@ def lr_find(
lr_finder.results.update({"lr": trainer.callbacks[0].lrs, "loss": trainer.callbacks[0].losses})
lr_finder._total_batch_idx = trainer.fit_loop.total_batch_idx # for debug purpose
# Reset model state
if trainer.is_global_zero:
trainer.checkpoint_connector.restore(str(save_path))
fs = get_filesystem(str(save_path))
if fs.exists(save_path):
fs.rm(save_path)
# Restore initial state of model
trainer.checkpoint_connector.restore(ckpt_path)
trainer.strategy.remove_checkpoint(ckpt_path)
__lr_finder_restore_params(trainer, params)
# Finish by resetting variables so trainer is ready to fit model
__lr_finder_restore_params(trainer, model)
if trainer.progress_bar_callback:
trainer.progress_bar_callback.enable()
@ -270,27 +254,31 @@ def lr_find(
return lr_finder
def __lr_finder_dump_params(trainer, model):
# Prevent going into infinite loop
trainer.__dumped_params = {
def __lr_finder_dump_params(trainer: "pl.Trainer") -> Dict[str, Any]:
return {
"auto_lr_find": trainer.auto_lr_find,
"callbacks": trainer.callbacks,
"logger": trainer.logger,
"global_step": trainer.global_step,
"max_steps": trainer.max_steps,
"checkpoint_callback": trainer.checkpoint_callback,
"current_epoch": trainer.current_epoch,
"max_steps": trainer.fit_loop.max_steps,
}
def __lr_finder_restore_params(trainer, model):
trainer.auto_lr_find = trainer.__dumped_params["auto_lr_find"]
trainer.logger = trainer.__dumped_params["logger"]
trainer.callbacks = trainer.__dumped_params["callbacks"]
trainer.fit_loop.global_step = trainer.__dumped_params["global_step"]
trainer.fit_loop.max_steps = trainer.__dumped_params["max_steps"]
trainer.fit_loop.current_epoch = trainer.__dumped_params["current_epoch"]
del trainer.__dumped_params
def __lr_finder_reset_params(trainer: "pl.Trainer", num_training: int, early_stop_threshold: float) -> None:
# avoid lr find being called multiple times
trainer.auto_lr_find = False
# Use special lr logger callback
trainer.callbacks = [_LRCallback(num_training, early_stop_threshold, progress_bar_refresh_rate=1)]
# No logging
trainer.logger = DummyLogger() if trainer.logger is not None else None
# Max step set to number of iterations
trainer.fit_loop.max_steps = num_training
def __lr_finder_restore_params(trainer: "pl.Trainer", params: Dict[str, Any]) -> None:
trainer.auto_lr_find = params["auto_lr_find"]
trainer.callbacks = params["callbacks"]
trainer.logger = params["logger"]
trainer.fit_loop.max_steps = params["max_steps"]
class _LRCallback(Callback):

View File

@ -41,6 +41,8 @@ class Tuner:
# return a dict instead of a tuple so BC is not broken if a new tuning procedure is added
result = {}
self.trainer.strategy.connect(model)
# Run auto batch size scaling
if self.trainer.auto_scale_batch_size:
if isinstance(self.trainer.auto_scale_batch_size, str):

View File

@ -65,7 +65,7 @@ def test_model_reset_correctly(tmpdir):
torch.eq(before_state_dict[key], after_state_dict[key])
), "Model was not reset correctly after learning rate finder"
assert not any(f for f in os.listdir(tmpdir) if f.startswith("lr_find_temp_model"))
assert not any(f for f in os.listdir(tmpdir) if f.startswith(".lr_find"))
def test_trainer_reset_correctly(tmpdir):

View File

@ -92,7 +92,7 @@ def test_model_reset_correctly(tmpdir):
torch.eq(before_state_dict[key], after_state_dict[key])
), "Model was not reset correctly after scaling batch size"
assert not any(f for f in os.listdir(tmpdir) if f.startswith("scale_batch_size_temp_model"))
assert not any(f for f in os.listdir(tmpdir) if f.startswith(".scale_batch_size"))
def test_trainer_reset_correctly(tmpdir):