diff --git a/CHANGELOG.md b/CHANGELOG.md
index 033aa89625..d125426da8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -191,6 +191,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed an issue with unsupported torch.inference_mode() on hpu backends by making it use no_grad ([#13014](https://github.com/PyTorchLightning/pytorch-lightning/pull/13014))
 
+- Avoid redundant callback restore warning while tuning ([#13026](https://github.com/PyTorchLightning/pytorch-lightning/pull/13026))
+
+
 -
diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py
index 56422eddd8..316fc5a219 100644
--- a/pytorch_lightning/tuner/batch_size_scaling.py
+++ b/pytorch_lightning/tuner/batch_size_scaling.py
@@ -81,14 +81,15 @@ def scale_batch_size(
     garbage_collection_cuda()
     log.info(f"Finished batch size finder, will continue with full run using batch size {new_size}")
 
-    # Restore initial state of model
-    trainer._checkpoint_connector.restore(ckpt_path)
-    trainer.strategy.remove_checkpoint(ckpt_path)
     __scale_batch_restore_params(trainer, params)
 
     if trainer.progress_bar_callback:
         trainer.progress_bar_callback.enable()
 
+    # Restore initial state of model
+    trainer._checkpoint_connector.restore(ckpt_path)
+    trainer.strategy.remove_checkpoint(ckpt_path)
+
     return new_size
diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py
index 032d2e829c..9d63c8b952 100644
--- a/pytorch_lightning/tuner/lr_finder.py
+++ b/pytorch_lightning/tuner/lr_finder.py
@@ -231,9 +231,6 @@ def lr_find(
     lr_finder.results.update({"lr": trainer.callbacks[0].lrs, "loss": trainer.callbacks[0].losses})
     lr_finder._total_batch_idx = trainer.fit_loop.total_batch_idx  # for debug purpose
 
-    # Restore initial state of model
-    trainer._checkpoint_connector.restore(ckpt_path)
-    trainer.strategy.remove_checkpoint(ckpt_path)
     __lr_finder_restore_params(trainer, params)
 
     if trainer.progress_bar_callback:
@@ -247,6 +244,10 @@ def lr_find(
         lightning_setattr(model, lr_attr_name, lr)
         log.info(f"Learning rate set to {lr}")
 
+    # Restore initial state of model
+    trainer._checkpoint_connector.restore(ckpt_path)
+    trainer.strategy.remove_checkpoint(ckpt_path)
+
     return lr_finder
diff --git a/tests/tuner/test_lr_finder.py b/tests/tuner/test_lr_finder.py
index 62d729d3d4..daf2ecde2c 100644
--- a/tests/tuner/test_lr_finder.py
+++ b/tests/tuner/test_lr_finder.py
@@ -22,6 +22,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel
 from tests.helpers.datamodules import ClassifDataModule
 from tests.helpers.simple_models import ClassificationModel
+from tests.helpers.utils import no_warning_call
 
 
 def test_error_on_more_than_1_optimizer(tmpdir):
@@ -87,9 +88,11 @@ def test_trainer_reset_correctly(tmpdir):
         "max_steps",
     ]
     expected = {ca: getattr(trainer, ca) for ca in changed_attributes}
-    trainer.tuner.lr_find(model, num_training=5)
-    actual = {ca: getattr(trainer, ca) for ca in changed_attributes}
 
+    with no_warning_call(UserWarning, match="Please add the following callbacks"):
+        trainer.tuner.lr_find(model, num_training=5)
+
+    actual = {ca: getattr(trainer, ca) for ca in changed_attributes}
     assert actual == expected
     assert model.trainer == trainer
diff --git a/tests/tuner/test_scale_batch_size.py b/tests/tuner/test_scale_batch_size.py
index 13112ee9f4..41efe15994 100644
--- a/tests/tuner/test_scale_batch_size.py
+++ b/tests/tuner/test_scale_batch_size.py
@@ -25,6 +25,7 @@ from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringDataModule, BoringModel, RandomDataset
 from tests.helpers.runif import RunIf
+from tests.helpers.utils import no_warning_call
 
 
 class BatchSizeDataModule(BoringDataModule):
@@ -114,9 +115,11 @@ def test_trainer_reset_correctly(tmpdir):
         "global_step",
     ]
     expected = {ca: getattr(trainer, ca) for ca in changed_attributes}
-    trainer.tuner.scale_batch_size(model, max_trials=5)
-    actual = {ca: getattr(trainer, ca) for ca in changed_attributes}
 
+    with no_warning_call(UserWarning, match="Please add the following callbacks"):
+        trainer.tuner.scale_batch_size(model, max_trials=5)
+
+    actual = {ca: getattr(trainer, ca) for ca in changed_attributes}
     assert actual == expected
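Note: both tests rely on `no_warning_call` from `tests.helpers.utils`, which this diff imports but does not show. As a rough, assumption-based sketch of the behaviour those tests expect (not the helper's actual implementation), such a context manager could look like this:

import re
import warnings
from contextlib import contextmanager


@contextmanager
def no_warning_call(expected_warning=Warning, match=None):
    # Hypothetical sketch: record all warnings emitted inside the block.
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        yield
    # Fail if any recorded warning matches the given category and message pattern.
    for w in record:
        if issubclass(w.category, expected_warning) and (match is None or re.search(match, str(w.message))):
            raise AssertionError(f"Unexpected warning raised: {w.message}")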