Avoid redundant callback restore warning while tuning (#13026)

Rohit Gupta 2022-05-11 19:41:04 +05:30 committed by GitHub
parent 1ca7330e17
commit 9881bf2a2c
5 changed files with 21 additions and 10 deletions


@@ -191,6 +191,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an issue with unsupported torch.inference_mode() on hpu backends by making it use no_grad ([#13014](https://github.com/PyTorchLightning/pytorch-lightning/pull/13014))
+- Avoid redundant callback restore warning while tuning ([#13026](https://github.com/PyTorchLightning/pytorch-lightning/pull/13026))


@@ -81,14 +81,15 @@ def scale_batch_size(
     garbage_collection_cuda()
     log.info(f"Finished batch size finder, will continue with full run using batch size {new_size}")
-    # Restore initial state of model
-    trainer._checkpoint_connector.restore(ckpt_path)
-    trainer.strategy.remove_checkpoint(ckpt_path)
     __scale_batch_restore_params(trainer, params)
     if trainer.progress_bar_callback:
         trainer.progress_bar_callback.enable()
+    # Restore initial state of model
+    trainer._checkpoint_connector.restore(ckpt_path)
+    trainer.strategy.remove_checkpoint(ckpt_path)
     return new_size


@@ -231,9 +231,6 @@ def lr_find(
     lr_finder.results.update({"lr": trainer.callbacks[0].lrs, "loss": trainer.callbacks[0].losses})
     lr_finder._total_batch_idx = trainer.fit_loop.total_batch_idx  # for debug purpose
-    # Restore initial state of model
-    trainer._checkpoint_connector.restore(ckpt_path)
-    trainer.strategy.remove_checkpoint(ckpt_path)
     __lr_finder_restore_params(trainer, params)
     if trainer.progress_bar_callback:
@@ -247,6 +244,10 @@ def lr_find(
         lightning_setattr(model, lr_attr_name, lr)
         log.info(f"Learning rate set to {lr}")
+    # Restore initial state of model
+    trainer._checkpoint_connector.restore(ckpt_path)
+    trainer.strategy.remove_checkpoint(ckpt_path)
     return lr_finder
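The same reordering is applied in the learning-rate finder: the checkpoint restore now happens after __lr_finder_restore_params and the optional update_attr handling, immediately before return lr_finder. A hedged sketch of the call site, reusing the hypothetical DemoModel and imports from the previous example:

trainer = Trainer(max_epochs=1)
# lr_find stores the swept learning rates and losses in lr_finder.results
# (the "lr" and "loss" keys seen in the diff above); suggestion() may return
# None if too few points were recorded.
lr_finder = trainer.tuner.lr_find(DemoModel(), num_training=30)
print(lr_finder.results["lr"][:3], lr_finder.results["loss"][:3])
print(lr_finder.suggestion())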


@@ -22,6 +22,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel
 from tests.helpers.datamodules import ClassifDataModule
 from tests.helpers.simple_models import ClassificationModel
+from tests.helpers.utils import no_warning_call
 def test_error_on_more_than_1_optimizer(tmpdir):
@@ -87,9 +88,11 @@ def test_trainer_reset_correctly(tmpdir):
         "max_steps",
     ]
     expected = {ca: getattr(trainer, ca) for ca in changed_attributes}
-    trainer.tuner.lr_find(model, num_training=5)
-    actual = {ca: getattr(trainer, ca) for ca in changed_attributes}
+    with no_warning_call(UserWarning, match="Please add the following callbacks"):
+        trainer.tuner.lr_find(model, num_training=5)
+    actual = {ca: getattr(trainer, ca) for ca in changed_attributes}
     assert actual == expected
     assert model.trainer == trainer
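Both test files now wrap the tuner call in no_warning_call, a helper imported from tests.helpers.utils, so the test fails if the callback-restore warning reappears. Its implementation is not shown in this diff; the following is only an assumed sketch of how such a context manager can work: record every warning raised in the block and fail if any matches the given category and message pattern.

import re
import warnings
from contextlib import contextmanager


@contextmanager
def no_warning_call(expected_warning=Warning, match=None):
    # Assumed re-implementation for illustration; the real helper in
    # tests/helpers/utils.py may differ.
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        yield
    for w in record:
        if not issubclass(w.category, expected_warning):
            continue
        if match is None or re.search(match, str(w.message)):
            raise AssertionError(f"Unexpected warning raised: {w.message}")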


@@ -25,6 +25,7 @@ from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringDataModule, BoringModel, RandomDataset
 from tests.helpers.runif import RunIf
+from tests.helpers.utils import no_warning_call
 class BatchSizeDataModule(BoringDataModule):
@@ -114,9 +115,11 @@ def test_trainer_reset_correctly(tmpdir):
         "global_step",
     ]
     expected = {ca: getattr(trainer, ca) for ca in changed_attributes}
-    trainer.tuner.scale_batch_size(model, max_trials=5)
-    actual = {ca: getattr(trainer, ca) for ca in changed_attributes}
+    with no_warning_call(UserWarning, match="Please add the following callbacks"):
+        trainer.tuner.scale_batch_size(model, max_trials=5)
+    actual = {ca: getattr(trainer, ca) for ca in changed_attributes}
     assert actual == expected
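For the opposite assertion, e.g. when reproducing the pre-fix behaviour (an assumption here, since the old behaviour is only implied by this commit), pytest's built-in warns context manager can be used instead of the custom helper; trainer and model are assumed to be set up as in the sketches above.

import pytest

with pytest.warns(UserWarning, match="Please add the following callbacks"):
    trainer.tuner.scale_batch_size(model, max_trials=5)  # pre-fix behaviour (assumed)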