Fix checkpoint values when saving and resetting the tuner state (#11518)
This commit is contained in:
parent
1a25363617
commit
075b8801c9
|
@ -216,6 +216,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- `Trainer.logged_metrics` now always contains scalar tensors, even when a Python scalar was logged ([#11270](https://github.com/PyTorchLightning/pytorch-lightning/pull/11270))
|
||||
|
||||
|
||||
- The tuner now uses the checkpoint connector to copy and restore its state ([#11518](https://github.com/PyTorchLightning/pytorch-lightning/pull/11518))
|
||||
|
||||
|
||||
- Changed `MisconfigurationException` to `ModuleNotFoundError` when `rich` isn't available ([#11360](https://github.com/PyTorchLightning/pytorch-lightning/pull/11360))
|
||||
|
||||
|
||||
|
|
|
@ -222,7 +222,7 @@ class CheckpointConnector:
|
|||
|
||||
assert self.trainer.state.fn is not None
|
||||
state_dict = self._loaded_checkpoint.get("loops")
|
||||
if state_dict is not None and self.trainer.state.fn != TrainerFn.TUNING:
|
||||
if state_dict is not None:
|
||||
if self.trainer.state.fn == TrainerFn.FITTING:
|
||||
self.trainer.fit_loop.load_state_dict(state_dict["fit_loop"])
|
||||
elif self.trainer.state.fn == TrainerFn.VALIDATING:
|
||||
|
|
|
@ -14,14 +14,13 @@
|
|||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from typing import Optional, Tuple
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.loggers.base import DummyLogger
|
||||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
from pytorch_lightning.utilities.cloud_io import get_filesystem
|
||||
from pytorch_lightning.utilities.data import has_len_all_ranks
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.memory import garbage_collection_cuda, is_oom_error
|
||||
|
@ -59,15 +58,17 @@ def scale_batch_size(
|
|||
" Please disable the feature or incorporate the dataloader into the model."
|
||||
)
|
||||
|
||||
# Arguments we adjust during the batch size finder, save for restoring
|
||||
__scale_batch_dump_params(trainer)
|
||||
# Save initial model, that is loaded after batch size is found
|
||||
ckpt_path = os.path.join(trainer.default_root_dir, f".scale_batch_size_{uuid.uuid4()}.ckpt")
|
||||
trainer.fit_loop.current_epoch -= 1
|
||||
trainer.fit_loop.global_step -= 1
|
||||
trainer.save_checkpoint(ckpt_path)
|
||||
trainer.fit_loop.current_epoch += 1
|
||||
trainer.fit_loop.global_step += 1
|
||||
params = __scale_batch_dump_params(trainer)
|
||||
|
||||
# Set to values that are required by the algorithm
|
||||
__scale_batch_reset_params(trainer, model, steps_per_trial)
|
||||
|
||||
# Save initial model, that is loaded after batch size is found
|
||||
save_path = os.path.join(trainer.default_root_dir, f"scale_batch_size_temp_model_{uuid.uuid4()}.ckpt")
|
||||
trainer.save_checkpoint(str(save_path))
|
||||
__scale_batch_reset_params(trainer, steps_per_trial)
|
||||
|
||||
if trainer.progress_bar_callback:
|
||||
trainer.progress_bar_callback.disable()
|
||||
|
@ -85,37 +86,28 @@ def scale_batch_size(
|
|||
log.info(f"Finished batch size finder, will continue with full run using batch size {new_size}")
|
||||
|
||||
# Restore initial state of model
|
||||
if trainer.is_global_zero:
|
||||
trainer.checkpoint_connector.restore(str(save_path))
|
||||
fs = get_filesystem(str(save_path))
|
||||
if fs.exists(save_path):
|
||||
fs.rm(save_path)
|
||||
trainer.checkpoint_connector.restore(ckpt_path)
|
||||
trainer.strategy.remove_checkpoint(ckpt_path)
|
||||
__scale_batch_restore_params(trainer, params)
|
||||
|
||||
# Finish by resetting variables so trainer is ready to fit model
|
||||
__scale_batch_restore_params(trainer)
|
||||
if trainer.progress_bar_callback:
|
||||
trainer.progress_bar_callback.enable()
|
||||
|
||||
return new_size
|
||||
|
||||
|
||||
def __scale_batch_dump_params(trainer: "pl.Trainer") -> None:
|
||||
# Prevent going into infinite loop
|
||||
trainer.__dumped_params = {
|
||||
"auto_lr_find": trainer.auto_lr_find,
|
||||
"current_epoch": trainer.current_epoch,
|
||||
"global_step": trainer.global_step,
|
||||
"max_steps": trainer.max_steps,
|
||||
def __scale_batch_dump_params(trainer: "pl.Trainer") -> Dict[str, Any]:
|
||||
return {
|
||||
"max_steps": trainer.fit_loop.max_steps,
|
||||
"logger": trainer.logger,
|
||||
"callbacks": trainer.callbacks,
|
||||
"checkpoint_callback": trainer.checkpoint_callback,
|
||||
"auto_scale_batch_size": trainer.auto_scale_batch_size,
|
||||
"auto_lr_find": trainer.auto_lr_find,
|
||||
"limit_train_batches": trainer.limit_train_batches,
|
||||
"model": trainer.model,
|
||||
}
|
||||
|
||||
|
||||
def __scale_batch_reset_params(trainer: "pl.Trainer", model: "pl.LightningModule", steps_per_trial: int) -> None:
|
||||
def __scale_batch_reset_params(trainer: "pl.Trainer", steps_per_trial: int) -> None:
|
||||
trainer.auto_scale_batch_size = None # prevent recursion
|
||||
trainer.auto_lr_find = False # avoid lr find being called multiple times
|
||||
trainer.fit_loop.current_epoch = 0
|
||||
|
@ -123,21 +115,15 @@ def __scale_batch_reset_params(trainer: "pl.Trainer", model: "pl.LightningModule
|
|||
trainer.logger = DummyLogger() if trainer.logger is not None else None
|
||||
trainer.callbacks = [] # not needed before full run
|
||||
trainer.limit_train_batches = 1.0
|
||||
trainer.optimizers, trainer.strategy.lr_schedulers = [], [] # required for saving
|
||||
trainer.model = model # required for saving
|
||||
|
||||
|
||||
def __scale_batch_restore_params(trainer: "pl.Trainer") -> None:
|
||||
trainer.auto_lr_find = trainer.__dumped_params["auto_lr_find"]
|
||||
trainer.fit_loop.current_epoch = trainer.__dumped_params["current_epoch"]
|
||||
trainer.fit_loop.global_step = trainer.__dumped_params["global_step"]
|
||||
trainer.fit_loop.max_steps = trainer.__dumped_params["max_steps"]
|
||||
trainer.logger = trainer.__dumped_params["logger"]
|
||||
trainer.callbacks = trainer.__dumped_params["callbacks"]
|
||||
trainer.auto_scale_batch_size = trainer.__dumped_params["auto_scale_batch_size"]
|
||||
trainer.limit_train_batches = trainer.__dumped_params["limit_train_batches"]
|
||||
trainer.model = trainer.__dumped_params["model"]
|
||||
del trainer.__dumped_params
|
||||
def __scale_batch_restore_params(trainer: "pl.Trainer", params: Dict[str, Any]) -> None:
|
||||
trainer.auto_scale_batch_size = params["auto_scale_batch_size"]
|
||||
trainer.auto_lr_find = params["auto_lr_find"]
|
||||
trainer.fit_loop.max_steps = params["max_steps"]
|
||||
trainer.logger = params["logger"]
|
||||
trainer.callbacks = params["callbacks"]
|
||||
trainer.limit_train_batches = params["limit_train_batches"]
|
||||
|
||||
|
||||
def _run_power_scaling(
|
||||
|
|
|
@ -16,7 +16,7 @@ import logging
|
|||
import os
|
||||
import uuid
|
||||
from functools import wraps
|
||||
from typing import Optional, Sequence
|
||||
from typing import Any, Dict, Optional, Sequence
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -27,7 +27,6 @@ from pytorch_lightning.callbacks import Callback
|
|||
from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers, _set_scheduler_opt_idx
|
||||
from pytorch_lightning.loggers.base import DummyLogger
|
||||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
from pytorch_lightning.utilities.cloud_io import get_filesystem
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr
|
||||
from pytorch_lightning.utilities.types import LRSchedulerConfig
|
||||
|
@ -203,36 +202,25 @@ def lr_find(
|
|||
if update_attr:
|
||||
lr_attr_name = _determine_lr_attr_name(trainer, model)
|
||||
|
||||
save_path = os.path.join(trainer.default_root_dir, f"lr_find_temp_model_{uuid.uuid4()}.ckpt")
|
||||
# Save initial model, that is loaded after learning rate is found
|
||||
ckpt_path = os.path.join(trainer.default_root_dir, f".lr_find_{uuid.uuid4()}.ckpt")
|
||||
trainer.fit_loop.current_epoch -= 1
|
||||
trainer.fit_loop.global_step -= 1
|
||||
trainer.save_checkpoint(ckpt_path)
|
||||
trainer.fit_loop.current_epoch += 1
|
||||
trainer.fit_loop.global_step += 1
|
||||
params = __lr_finder_dump_params(trainer)
|
||||
|
||||
__lr_finder_dump_params(trainer, model)
|
||||
|
||||
# Prevent going into infinite loop
|
||||
trainer.auto_lr_find = False
|
||||
# Set to values that are required by the algorithm
|
||||
__lr_finder_reset_params(trainer, num_training, early_stop_threshold)
|
||||
|
||||
# Initialize lr finder object (stores results)
|
||||
lr_finder = _LRFinder(mode, min_lr, max_lr, num_training)
|
||||
|
||||
# Use special lr logger callback
|
||||
trainer.callbacks = [_LRCallback(num_training, early_stop_threshold, progress_bar_refresh_rate=1)]
|
||||
|
||||
# No logging
|
||||
trainer.logger = DummyLogger() if trainer.logger is not None else None
|
||||
|
||||
# Max step set to number of iterations
|
||||
trainer.fit_loop.max_steps = num_training
|
||||
|
||||
# Disable standard progress bar for fit
|
||||
if trainer.progress_bar_callback:
|
||||
trainer.progress_bar_callback.disable()
|
||||
|
||||
# Required for saving the model
|
||||
trainer.optimizers, trainer.strategy.lr_schedulers = [], []
|
||||
trainer.model = model
|
||||
|
||||
# Dump model checkpoint
|
||||
trainer.save_checkpoint(str(save_path))
|
||||
|
||||
# Configure optimizer and scheduler
|
||||
trainer.strategy.setup_optimizers = lr_finder._exchange_scheduler(trainer, model)
|
||||
|
||||
|
@ -247,15 +235,11 @@ def lr_find(
|
|||
lr_finder.results.update({"lr": trainer.callbacks[0].lrs, "loss": trainer.callbacks[0].losses})
|
||||
lr_finder._total_batch_idx = trainer.fit_loop.total_batch_idx # for debug purpose
|
||||
|
||||
# Reset model state
|
||||
if trainer.is_global_zero:
|
||||
trainer.checkpoint_connector.restore(str(save_path))
|
||||
fs = get_filesystem(str(save_path))
|
||||
if fs.exists(save_path):
|
||||
fs.rm(save_path)
|
||||
# Restore initial state of model
|
||||
trainer.checkpoint_connector.restore(ckpt_path)
|
||||
trainer.strategy.remove_checkpoint(ckpt_path)
|
||||
__lr_finder_restore_params(trainer, params)
|
||||
|
||||
# Finish by resetting variables so trainer is ready to fit model
|
||||
__lr_finder_restore_params(trainer, model)
|
||||
if trainer.progress_bar_callback:
|
||||
trainer.progress_bar_callback.enable()
|
||||
|
||||
|
@ -270,27 +254,31 @@ def lr_find(
|
|||
return lr_finder
|
||||
|
||||
|
||||
def __lr_finder_dump_params(trainer, model):
|
||||
# Prevent going into infinite loop
|
||||
trainer.__dumped_params = {
|
||||
def __lr_finder_dump_params(trainer: "pl.Trainer") -> Dict[str, Any]:
|
||||
return {
|
||||
"auto_lr_find": trainer.auto_lr_find,
|
||||
"callbacks": trainer.callbacks,
|
||||
"logger": trainer.logger,
|
||||
"global_step": trainer.global_step,
|
||||
"max_steps": trainer.max_steps,
|
||||
"checkpoint_callback": trainer.checkpoint_callback,
|
||||
"current_epoch": trainer.current_epoch,
|
||||
"max_steps": trainer.fit_loop.max_steps,
|
||||
}
|
||||
|
||||
|
||||
def __lr_finder_restore_params(trainer, model):
|
||||
trainer.auto_lr_find = trainer.__dumped_params["auto_lr_find"]
|
||||
trainer.logger = trainer.__dumped_params["logger"]
|
||||
trainer.callbacks = trainer.__dumped_params["callbacks"]
|
||||
trainer.fit_loop.global_step = trainer.__dumped_params["global_step"]
|
||||
trainer.fit_loop.max_steps = trainer.__dumped_params["max_steps"]
|
||||
trainer.fit_loop.current_epoch = trainer.__dumped_params["current_epoch"]
|
||||
del trainer.__dumped_params
|
||||
def __lr_finder_reset_params(trainer: "pl.Trainer", num_training: int, early_stop_threshold: float) -> None:
|
||||
# avoid lr find being called multiple times
|
||||
trainer.auto_lr_find = False
|
||||
# Use special lr logger callback
|
||||
trainer.callbacks = [_LRCallback(num_training, early_stop_threshold, progress_bar_refresh_rate=1)]
|
||||
# No logging
|
||||
trainer.logger = DummyLogger() if trainer.logger is not None else None
|
||||
# Max step set to number of iterations
|
||||
trainer.fit_loop.max_steps = num_training
|
||||
|
||||
|
||||
def __lr_finder_restore_params(trainer: "pl.Trainer", params: Dict[str, Any]) -> None:
|
||||
trainer.auto_lr_find = params["auto_lr_find"]
|
||||
trainer.callbacks = params["callbacks"]
|
||||
trainer.logger = params["logger"]
|
||||
trainer.fit_loop.max_steps = params["max_steps"]
|
||||
|
||||
|
||||
class _LRCallback(Callback):
|
||||
|
|
|
@ -41,6 +41,8 @@ class Tuner:
|
|||
# return a dict instead of a tuple so BC is not broken if a new tuning procedure is added
|
||||
result = {}
|
||||
|
||||
self.trainer.strategy.connect(model)
|
||||
|
||||
# Run auto batch size scaling
|
||||
if self.trainer.auto_scale_batch_size:
|
||||
if isinstance(self.trainer.auto_scale_batch_size, str):
|
||||
|
|
|
@ -65,7 +65,7 @@ def test_model_reset_correctly(tmpdir):
|
|||
torch.eq(before_state_dict[key], after_state_dict[key])
|
||||
), "Model was not reset correctly after learning rate finder"
|
||||
|
||||
assert not any(f for f in os.listdir(tmpdir) if f.startswith("lr_find_temp_model"))
|
||||
assert not any(f for f in os.listdir(tmpdir) if f.startswith(".lr_find"))
|
||||
|
||||
|
||||
def test_trainer_reset_correctly(tmpdir):
|
||||
|
|
|
@ -92,7 +92,7 @@ def test_model_reset_correctly(tmpdir):
|
|||
torch.eq(before_state_dict[key], after_state_dict[key])
|
||||
), "Model was not reset correctly after scaling batch size"
|
||||
|
||||
assert not any(f for f in os.listdir(tmpdir) if f.startswith("scale_batch_size_temp_model"))
|
||||
assert not any(f for f in os.listdir(tmpdir) if f.startswith(".scale_batch_size"))
|
||||
|
||||
|
||||
def test_trainer_reset_correctly(tmpdir):
|
||||
|
|
Loading…
Reference in New Issue