diff --git a/CHANGELOG.md b/CHANGELOG.md index dff8a28e06..aa7c4f9b05 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -216,6 +216,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - `Trainer.logged_metrics` now always contains scalar tensors, even when a Python scalar was logged ([#11270](https://github.com/PyTorchLightning/pytorch-lightning/pull/11270)) +- The tuner now uses the checkpoint connector to copy and restore its state ([#11518](https://github.com/PyTorchLightning/pytorch-lightning/pull/11518)) + + - Changed `MisconfigurationException` to `ModuleNotFoundError` when `rich` isn't available ([#11360](https://github.com/PyTorchLightning/pytorch-lightning/pull/11360)) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index c3b3f2988e..5c437bfd88 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -222,7 +222,7 @@ class CheckpointConnector: assert self.trainer.state.fn is not None state_dict = self._loaded_checkpoint.get("loops") - if state_dict is not None and self.trainer.state.fn != TrainerFn.TUNING: + if state_dict is not None: if self.trainer.state.fn == TrainerFn.FITTING: self.trainer.fit_loop.load_state_dict(state_dict["fit_loop"]) elif self.trainer.state.fn == TrainerFn.VALIDATING: diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py index d4ffe7bc97..788395f676 100644 --- a/pytorch_lightning/tuner/batch_size_scaling.py +++ b/pytorch_lightning/tuner/batch_size_scaling.py @@ -14,14 +14,13 @@ import logging import os import uuid -from typing import Optional, Tuple +from typing import Any, Dict, Optional, Tuple from torch.utils.data import DataLoader import pytorch_lightning as pl from pytorch_lightning.loggers.base import DummyLogger from pytorch_lightning.utilities import rank_zero_warn -from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.data import has_len_all_ranks from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.memory import garbage_collection_cuda, is_oom_error @@ -59,15 +58,17 @@ def scale_batch_size( " Please disable the feature or incorporate the dataloader into the model." ) - # Arguments we adjust during the batch size finder, save for restoring - __scale_batch_dump_params(trainer) + # Save initial model, that is loaded after batch size is found + ckpt_path = os.path.join(trainer.default_root_dir, f".scale_batch_size_{uuid.uuid4()}.ckpt") + trainer.fit_loop.current_epoch -= 1 + trainer.fit_loop.global_step -= 1 + trainer.save_checkpoint(ckpt_path) + trainer.fit_loop.current_epoch += 1 + trainer.fit_loop.global_step += 1 + params = __scale_batch_dump_params(trainer) # Set to values that are required by the algorithm - __scale_batch_reset_params(trainer, model, steps_per_trial) - - # Save initial model, that is loaded after batch size is found - save_path = os.path.join(trainer.default_root_dir, f"scale_batch_size_temp_model_{uuid.uuid4()}.ckpt") - trainer.save_checkpoint(str(save_path)) + __scale_batch_reset_params(trainer, steps_per_trial) if trainer.progress_bar_callback: trainer.progress_bar_callback.disable() @@ -85,37 +86,28 @@ def scale_batch_size( log.info(f"Finished batch size finder, will continue with full run using batch size {new_size}") # Restore initial state of model - if trainer.is_global_zero: - trainer.checkpoint_connector.restore(str(save_path)) - fs = get_filesystem(str(save_path)) - if fs.exists(save_path): - fs.rm(save_path) + trainer.checkpoint_connector.restore(ckpt_path) + trainer.strategy.remove_checkpoint(ckpt_path) + __scale_batch_restore_params(trainer, params) - # Finish by resetting variables so trainer is ready to fit model - __scale_batch_restore_params(trainer) if trainer.progress_bar_callback: trainer.progress_bar_callback.enable() return new_size -def __scale_batch_dump_params(trainer: "pl.Trainer") -> None: - # Prevent going into infinite loop - trainer.__dumped_params = { - "auto_lr_find": trainer.auto_lr_find, - "current_epoch": trainer.current_epoch, - "global_step": trainer.global_step, - "max_steps": trainer.max_steps, +def __scale_batch_dump_params(trainer: "pl.Trainer") -> Dict[str, Any]: + return { + "max_steps": trainer.fit_loop.max_steps, "logger": trainer.logger, "callbacks": trainer.callbacks, - "checkpoint_callback": trainer.checkpoint_callback, "auto_scale_batch_size": trainer.auto_scale_batch_size, + "auto_lr_find": trainer.auto_lr_find, "limit_train_batches": trainer.limit_train_batches, - "model": trainer.model, } -def __scale_batch_reset_params(trainer: "pl.Trainer", model: "pl.LightningModule", steps_per_trial: int) -> None: +def __scale_batch_reset_params(trainer: "pl.Trainer", steps_per_trial: int) -> None: trainer.auto_scale_batch_size = None # prevent recursion trainer.auto_lr_find = False # avoid lr find being called multiple times trainer.fit_loop.current_epoch = 0 @@ -123,21 +115,15 @@ def __scale_batch_reset_params(trainer: "pl.Trainer", model: "pl.LightningModule trainer.logger = DummyLogger() if trainer.logger is not None else None trainer.callbacks = [] # not needed before full run trainer.limit_train_batches = 1.0 - trainer.optimizers, trainer.strategy.lr_schedulers = [], [] # required for saving - trainer.model = model # required for saving -def __scale_batch_restore_params(trainer: "pl.Trainer") -> None: - trainer.auto_lr_find = trainer.__dumped_params["auto_lr_find"] - trainer.fit_loop.current_epoch = trainer.__dumped_params["current_epoch"] - trainer.fit_loop.global_step = trainer.__dumped_params["global_step"] - trainer.fit_loop.max_steps = trainer.__dumped_params["max_steps"] - trainer.logger = trainer.__dumped_params["logger"] - trainer.callbacks = trainer.__dumped_params["callbacks"] - trainer.auto_scale_batch_size = trainer.__dumped_params["auto_scale_batch_size"] - trainer.limit_train_batches = trainer.__dumped_params["limit_train_batches"] - trainer.model = trainer.__dumped_params["model"] - del trainer.__dumped_params +def __scale_batch_restore_params(trainer: "pl.Trainer", params: Dict[str, Any]) -> None: + trainer.auto_scale_batch_size = params["auto_scale_batch_size"] + trainer.auto_lr_find = params["auto_lr_find"] + trainer.fit_loop.max_steps = params["max_steps"] + trainer.logger = params["logger"] + trainer.callbacks = params["callbacks"] + trainer.limit_train_batches = params["limit_train_batches"] def _run_power_scaling( diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py index a15e65ef98..0be49535e0 100644 --- a/pytorch_lightning/tuner/lr_finder.py +++ b/pytorch_lightning/tuner/lr_finder.py @@ -16,7 +16,7 @@ import logging import os import uuid from functools import wraps -from typing import Optional, Sequence +from typing import Any, Dict, Optional, Sequence import numpy as np import torch @@ -27,7 +27,6 @@ from pytorch_lightning.callbacks import Callback from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers, _set_scheduler_opt_idx from pytorch_lightning.loggers.base import DummyLogger from pytorch_lightning.utilities import rank_zero_warn -from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr from pytorch_lightning.utilities.types import LRSchedulerConfig @@ -203,36 +202,25 @@ def lr_find( if update_attr: lr_attr_name = _determine_lr_attr_name(trainer, model) - save_path = os.path.join(trainer.default_root_dir, f"lr_find_temp_model_{uuid.uuid4()}.ckpt") + # Save initial model, that is loaded after learning rate is found + ckpt_path = os.path.join(trainer.default_root_dir, f".lr_find_{uuid.uuid4()}.ckpt") + trainer.fit_loop.current_epoch -= 1 + trainer.fit_loop.global_step -= 1 + trainer.save_checkpoint(ckpt_path) + trainer.fit_loop.current_epoch += 1 + trainer.fit_loop.global_step += 1 + params = __lr_finder_dump_params(trainer) - __lr_finder_dump_params(trainer, model) - - # Prevent going into infinite loop - trainer.auto_lr_find = False + # Set to values that are required by the algorithm + __lr_finder_reset_params(trainer, num_training, early_stop_threshold) # Initialize lr finder object (stores results) lr_finder = _LRFinder(mode, min_lr, max_lr, num_training) - # Use special lr logger callback - trainer.callbacks = [_LRCallback(num_training, early_stop_threshold, progress_bar_refresh_rate=1)] - - # No logging - trainer.logger = DummyLogger() if trainer.logger is not None else None - - # Max step set to number of iterations - trainer.fit_loop.max_steps = num_training - # Disable standard progress bar for fit if trainer.progress_bar_callback: trainer.progress_bar_callback.disable() - # Required for saving the model - trainer.optimizers, trainer.strategy.lr_schedulers = [], [] - trainer.model = model - - # Dump model checkpoint - trainer.save_checkpoint(str(save_path)) - # Configure optimizer and scheduler trainer.strategy.setup_optimizers = lr_finder._exchange_scheduler(trainer, model) @@ -247,15 +235,11 @@ def lr_find( lr_finder.results.update({"lr": trainer.callbacks[0].lrs, "loss": trainer.callbacks[0].losses}) lr_finder._total_batch_idx = trainer.fit_loop.total_batch_idx # for debug purpose - # Reset model state - if trainer.is_global_zero: - trainer.checkpoint_connector.restore(str(save_path)) - fs = get_filesystem(str(save_path)) - if fs.exists(save_path): - fs.rm(save_path) + # Restore initial state of model + trainer.checkpoint_connector.restore(ckpt_path) + trainer.strategy.remove_checkpoint(ckpt_path) + __lr_finder_restore_params(trainer, params) - # Finish by resetting variables so trainer is ready to fit model - __lr_finder_restore_params(trainer, model) if trainer.progress_bar_callback: trainer.progress_bar_callback.enable() @@ -270,27 +254,31 @@ def lr_find( return lr_finder -def __lr_finder_dump_params(trainer, model): - # Prevent going into infinite loop - trainer.__dumped_params = { +def __lr_finder_dump_params(trainer: "pl.Trainer") -> Dict[str, Any]: + return { "auto_lr_find": trainer.auto_lr_find, "callbacks": trainer.callbacks, "logger": trainer.logger, - "global_step": trainer.global_step, - "max_steps": trainer.max_steps, - "checkpoint_callback": trainer.checkpoint_callback, - "current_epoch": trainer.current_epoch, + "max_steps": trainer.fit_loop.max_steps, } -def __lr_finder_restore_params(trainer, model): - trainer.auto_lr_find = trainer.__dumped_params["auto_lr_find"] - trainer.logger = trainer.__dumped_params["logger"] - trainer.callbacks = trainer.__dumped_params["callbacks"] - trainer.fit_loop.global_step = trainer.__dumped_params["global_step"] - trainer.fit_loop.max_steps = trainer.__dumped_params["max_steps"] - trainer.fit_loop.current_epoch = trainer.__dumped_params["current_epoch"] - del trainer.__dumped_params +def __lr_finder_reset_params(trainer: "pl.Trainer", num_training: int, early_stop_threshold: float) -> None: + # avoid lr find being called multiple times + trainer.auto_lr_find = False + # Use special lr logger callback + trainer.callbacks = [_LRCallback(num_training, early_stop_threshold, progress_bar_refresh_rate=1)] + # No logging + trainer.logger = DummyLogger() if trainer.logger is not None else None + # Max step set to number of iterations + trainer.fit_loop.max_steps = num_training + + +def __lr_finder_restore_params(trainer: "pl.Trainer", params: Dict[str, Any]) -> None: + trainer.auto_lr_find = params["auto_lr_find"] + trainer.callbacks = params["callbacks"] + trainer.logger = params["logger"] + trainer.fit_loop.max_steps = params["max_steps"] class _LRCallback(Callback): diff --git a/pytorch_lightning/tuner/tuning.py b/pytorch_lightning/tuner/tuning.py index 6a3a6b66d0..f64183a92b 100644 --- a/pytorch_lightning/tuner/tuning.py +++ b/pytorch_lightning/tuner/tuning.py @@ -41,6 +41,8 @@ class Tuner: # return a dict instead of a tuple so BC is not broken if a new tuning procedure is added result = {} + self.trainer.strategy.connect(model) + # Run auto batch size scaling if self.trainer.auto_scale_batch_size: if isinstance(self.trainer.auto_scale_batch_size, str): diff --git a/tests/tuner/test_lr_finder.py b/tests/tuner/test_lr_finder.py index cad4bb2a78..62d729d3d4 100644 --- a/tests/tuner/test_lr_finder.py +++ b/tests/tuner/test_lr_finder.py @@ -65,7 +65,7 @@ def test_model_reset_correctly(tmpdir): torch.eq(before_state_dict[key], after_state_dict[key]) ), "Model was not reset correctly after learning rate finder" - assert not any(f for f in os.listdir(tmpdir) if f.startswith("lr_find_temp_model")) + assert not any(f for f in os.listdir(tmpdir) if f.startswith(".lr_find")) def test_trainer_reset_correctly(tmpdir): diff --git a/tests/tuner/test_scale_batch_size.py b/tests/tuner/test_scale_batch_size.py index c1d1de052d..31d3dd3dd3 100644 --- a/tests/tuner/test_scale_batch_size.py +++ b/tests/tuner/test_scale_batch_size.py @@ -92,7 +92,7 @@ def test_model_reset_correctly(tmpdir): torch.eq(before_state_dict[key], after_state_dict[key]) ), "Model was not reset correctly after scaling batch size" - assert not any(f for f in os.listdir(tmpdir) if f.startswith("scale_batch_size_temp_model")) + assert not any(f for f in os.listdir(tmpdir) if f.startswith(".scale_batch_size")) def test_trainer_reset_correctly(tmpdir):