From 06ea3a05716a6d1f4a96cfb25021accdd18d8146 Mon Sep 17 00:00:00 2001 From: Alexander Jipa Date: Fri, 7 Jun 2024 10:52:58 -0400 Subject: [PATCH] Fix resetting epoch loop restarting flag in LearningRateFinder (#19819) --- src/lightning/pytorch/CHANGELOG.md | 3 +++ src/lightning/pytorch/tuner/lr_finder.py | 1 + tests/tests_pytorch/tuner/test_lr_finder.py | 1 + 3 files changed, 5 insertions(+) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 2b76b36902..1d6c660910 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -63,6 +63,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed an issue causing ValueError for certain object such as TorchMetrics when dumping hyperparameters to YAML ([#19804](https://github.com/Lightning-AI/pytorch-lightning/pull/19804)) +- Fixed resetting `epoch_loop.restarting` to avoid full validation run after `LearningRateFinder` ([#19818](https://github.com/Lightning-AI/pytorch-lightning/issues/19818)) + + ## [2.2.2] - 2024-04-11 ### Fixed diff --git a/src/lightning/pytorch/tuner/lr_finder.py b/src/lightning/pytorch/tuner/lr_finder.py index 4997e23070..17a2063e50 100644 --- a/src/lightning/pytorch/tuner/lr_finder.py +++ b/src/lightning/pytorch/tuner/lr_finder.py @@ -301,6 +301,7 @@ def _lr_find( trainer._checkpoint_connector.restore(ckpt_path) trainer.strategy.remove_checkpoint(ckpt_path) trainer.fit_loop.restarting = False # reset restarting flag as checkpoint restoring sets it to True + trainer.fit_loop.epoch_loop.restarting = False # reset restarting flag as checkpoint restoring sets it to True trainer.fit_loop.epoch_loop.val_loop._combined_loader = None return lr_finder diff --git a/tests/tests_pytorch/tuner/test_lr_finder.py b/tests/tests_pytorch/tuner/test_lr_finder.py index a0d1d70aa3..a31be67911 100644 --- a/tests/tests_pytorch/tuner/test_lr_finder.py +++ b/tests/tests_pytorch/tuner/test_lr_finder.py @@ -434,6 +434,7 @@ def test_lr_finder_callback_restarting(tmp_path): super().lr_find(trainer, pl_module) pl_module._expected_max_steps = None assert not trainer.fit_loop.restarting + assert not trainer.fit_loop.epoch_loop.restarting def on_train_epoch_start(self, trainer, pl_module): if trainer.current_epoch in self.milestones or trainer.current_epoch == 0: