diff --git a/CHANGELOG.md b/CHANGELOG.md
index dc8e8b799e..c13ba8b8d3 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -304,6 +304,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed process rank not being available right away after `Trainer` instantiation ([#6941](https://github.com/PyTorchLightning/pytorch-lightning/pull/6941))
 
 
+- Fixed `lr_find` trying beyond `num_training` steps and suggesting a too high learning rate ([#7076](https://github.com/PyTorchLightning/pytorch-lightning/pull/7076))
+
+
 - Fixed logger creating incorrect version folder in DDP with repeated `Trainer.fit` calls ([#7077](https://github.com/PyTorchLightning/pytorch-lightning/pull/7077))
 
 
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index b7f4c58444..79ed39676a 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -521,7 +521,7 @@ class TrainLoop:
 
             # max steps reached, end training
             if (
-                self.trainer.max_steps is not None and self.trainer.max_steps == self.trainer.global_step + 1
+                self.trainer.max_steps is not None and self.trainer.max_steps <= self.trainer.global_step + 1
                 and self._accumulated_batches_reached()
            ):
                break
diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py
index e3ccef9aa7..2d122a3d30 100644
--- a/pytorch_lightning/tuner/lr_finder.py
+++ b/pytorch_lightning/tuner/lr_finder.py
@@ -180,7 +180,7 @@ def lr_find(
 
     # Prompt if we stopped early
     if trainer.global_step != num_training:
-        log.info('LR finder stopped early due to diverging loss.')
+        log.info(f'LR finder stopped early after {trainer.global_step} steps due to diverging loss.')
 
     # Transfer results from callback to lr finder object
     lr_finder.results.update({'lr': trainer.callbacks[0].lrs, 'loss': trainer.callbacks[0].losses})
diff --git a/tests/trainer/test_lr_finder.py b/tests/tuner/test_lr_finder.py
similarity index 88%
rename from tests/trainer/test_lr_finder.py
rename to tests/tuner/test_lr_finder.py
index 44510eb161..dcefedb4a2 100644
--- a/tests/trainer/test_lr_finder.py
+++ b/tests/tuner/test_lr_finder.py
@@ -295,3 +295,45 @@ def test_lr_find_with_bs_scale(tmpdir):
 
     assert lr != before_lr
     assert isinstance(bs, int)
+
+
+def test_lr_candidates_between_min_and_max(tmpdir):
+    """Test that learning rate candidates are between min_lr and max_lr."""
+    class TestModel(BoringModel):
+        def __init__(self, learning_rate=0.1):
+            super().__init__()
+            self.save_hyperparameters()
+
+    model = TestModel()
+    trainer = Trainer(default_root_dir=tmpdir)
+
+    lr_min = 1e-8
+    lr_max = 1.0
+    lr_finder = trainer.tuner.lr_find(
+        model,
+        max_lr=lr_min,
+        min_lr=lr_max,
+        num_training=3,
+    )
+    lr_candidates = lr_finder.results["lr"]
+    assert all([lr_min <= lr <= lr_max for lr in lr_candidates])
+
+
+def test_lr_finder_ends_before_num_training(tmpdir):
+    """Tests learning rate finder ends before `num_training` steps."""
+    class TestModel(BoringModel):
+        def __init__(self, learning_rate=0.1):
+            super().__init__()
+            self.save_hyperparameters()
+
+        def training_step_end(self, outputs):
+            assert self.global_step < num_training
+            return outputs
+
+    model = TestModel()
+    trainer = Trainer(default_root_dir=tmpdir)
+    num_training = 3
+    _ = trainer.tuner.lr_find(
+        model=model,
+        num_training=num_training,
+    )
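
For context, a minimal usage sketch of the API touched by this change (not part of the diff): TinyModel, its data, and the chosen learning-rate range are illustrative assumptions; only `trainer.tuner.lr_find`, `lr_finder.results`, and `lr_finder.suggestion()` come from the code above and Lightning's tuner.

# Hedged sketch: running the LR finder with an explicit search range,
# mirroring the parameters exercised in the new tests. TinyModel is
# hypothetical and exists only to make the example self-contained.
import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer


class TinyModel(LightningModule):
    def __init__(self, learning_rate=0.1):
        super().__init__()
        self.save_hyperparameters()  # exposes self.hparams.learning_rate to the tuner
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self.layer(x), y)

    def train_dataloader(self):
        x = torch.randn(256, 32)
        y = torch.randint(0, 2, (256,))
        return DataLoader(TensorDataset(x, y), batch_size=32)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=self.hparams.learning_rate)


if __name__ == "__main__":
    model = TinyModel()
    trainer = Trainer(default_root_dir="lr_find_demo")

    # Sweep at most `num_training` candidate learning rates between min_lr and max_lr.
    # With the fix above, the sweep stops after num_training steps and every candidate
    # recorded in lr_finder.results["lr"] stays inside [min_lr, max_lr].
    lr_finder = trainer.tuner.lr_find(model, min_lr=1e-6, max_lr=1.0, num_training=50)
    print(lr_finder.suggestion())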