Fix `lr_finder` suggesting too high learning rates (#7076)

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Akihiro Nitta 2021-04-23 19:59:40 +09:00 committed by GitHub
parent d534e53ec4
commit 92af363270
4 changed files with 47 additions and 2 deletions

@@ -304,6 +304,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed process rank not being available right away after `Trainer` instantiation ([#6941](https://github.com/PyTorchLightning/pytorch-lightning/pull/6941))
+- Fixed `lr_find` trying beyond `num_training` steps and suggesting a too high learning rate ([#7076](https://github.com/PyTorchLightning/pytorch-lightning/pull/7076))
 - Fixed logger creating incorrect version folder in DDP with repeated `Trainer.fit` calls ([#7077](https://github.com/PyTorchLightning/pytorch-lightning/pull/7077))
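
For context on the behaviour this entry fixes, here is a hedged usage sketch of the learning rate finder. The toy module, its data, and the `learning_rate` hyperparameter name are illustrative assumptions; `trainer.tuner.lr_find`, `min_lr`, `max_lr`, `num_training`, and `results["lr"]` appear in the tests added by this commit, and `lr_finder.suggestion()` is the usual helper for reading the suggested rate.

# Hedged usage sketch around this fix. ToyModel and its data are assumptions made
# for illustration only; they are not part of the commit.
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class ToyModel(pl.LightningModule):
    def __init__(self, learning_rate=0.1):
        super().__init__()
        self.save_hyperparameters()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(64, 32), torch.randn(64, 2)), batch_size=8)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=self.hparams.learning_rate)


trainer = pl.Trainer(max_epochs=1)
# Probes at most `num_training` candidate rates between min_lr and max_lr.
lr_finder = trainer.tuner.lr_find(ToyModel(), min_lr=1e-6, max_lr=1.0, num_training=20)
print(lr_finder.suggestion())        # rate at the steepest loss descent
print(max(lr_finder.results["lr"]))  # after this fix, candidates stay within [min_lr, max_lr]

Before this fix, the sweep could keep stepping past `num_training`, so the recorded candidates (and the suggestion derived from them) could land above `max_lr`.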

@@ -521,7 +521,7 @@ class TrainLoop:
             # max steps reached, end training
             if (
-                self.trainer.max_steps is not None and self.trainer.max_steps == self.trainer.global_step + 1
+                self.trainer.max_steps is not None and self.trainer.max_steps <= self.trainer.global_step + 1
                 and self._accumulated_batches_reached()
             ):
                 break
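
A standalone sketch of why the comparison above changes from `==` to `<=`: if the step counter can land past `max_steps` without ever being exactly one below it, an equality check never fires, while `<=` still terminates. The loop below is a hypothetical illustration, not the Lightning training loop.

# Hypothetical illustration (not Lightning code): an equality check on a step counter
# can be skipped over when the counter jumps, while `<=` always catches the limit.
def run(max_steps: int, increment: int, use_leq: bool) -> int:
    global_step = 0
    for _ in range(100):  # safety bound for the demo
        reached = (max_steps <= global_step + 1) if use_leq else (max_steps == global_step + 1)
        if reached:
            break
        global_step += increment  # the counter may advance by more than 1
    return global_step


print(run(max_steps=5, increment=3, use_leq=False))  # 300: `==` never fires, the loop runs out
print(run(max_steps=5, increment=3, use_leq=True))   # 6: stops once the limit is reached or passed

With `<=`, the loop also exits when the limit was already passed before the check runs, which appears to be how `lr_find` could previously run beyond `num_training` steps.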

@@ -180,7 +180,7 @@ def lr_find(
     # Prompt if we stopped early
     if trainer.global_step != num_training:
-        log.info('LR finder stopped early due to diverging loss.')
+        log.info(f'LR finder stopped early after {trainer.global_step} steps due to diverging loss.')
 
     # Transfer results from callback to lr finder object
     lr_finder.results.update({'lr': trainer.callbacks[0].lrs, 'loss': trainer.callbacks[0].losses})
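
The results transferred above are what the suggestion is later computed from. Below is a hedged, self-contained sketch of that idea, picking the rate at the steepest loss descent from synthetic `(lr, loss)` pairs; it illustrates the principle, not the library's exact `suggestion()` implementation.

# Hedged illustration: deriving a suggested rate from recorded (lr, loss) pairs shaped
# like lr_finder.results (keys 'lr' and 'loss', as filled in the hunk above).
# The synthetic curve and the gradient-argmin rule are assumptions for this sketch.
import numpy as np

lrs = np.logspace(-6, 0, 30)                # candidate rates, low to high
losses = np.concatenate([
    np.full(10, 1.0),                       # too small: loss barely moves
    np.linspace(1.0, 0.2, 10),              # sweet spot: loss falls quickly
    np.linspace(0.2, 5.0, 10),              # too large: loss diverges
])

steepest = np.gradient(losses).argmin()     # index of the most negative slope
print(f"suggested lr ≈ {lrs[steepest]:.1e}")

If the sweep records candidates beyond `max_lr` because it did not stop at `num_training`, the steepest descent, and hence the suggestion, can shift toward rates that are too high, which is the symptom named in the commit title and checked by the new tests.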

@@ -295,3 +295,45 @@ def test_lr_find_with_bs_scale(tmpdir):
     assert lr != before_lr
     assert isinstance(bs, int)
+
+
+def test_lr_candidates_between_min_and_max(tmpdir):
+    """Test that learning rate candidates are between min_lr and max_lr."""
+
+    class TestModel(BoringModel):
+        def __init__(self, learning_rate=0.1):
+            super().__init__()
+            self.save_hyperparameters()
+
+    model = TestModel()
+    trainer = Trainer(default_root_dir=tmpdir)
+
+    lr_min = 1e-8
+    lr_max = 1.0
+    lr_finder = trainer.tuner.lr_find(
+        model,
+        max_lr=lr_min,
+        min_lr=lr_max,
+        num_training=3,
+    )
+    lr_candidates = lr_finder.results["lr"]
+    assert all([lr_min <= lr <= lr_max for lr in lr_candidates])
+
+
+def test_lr_finder_ends_before_num_training(tmpdir):
+    """Tests learning rate finder ends before `num_training` steps."""
+
+    class TestModel(BoringModel):
+        def __init__(self, learning_rate=0.1):
+            super().__init__()
+            self.save_hyperparameters()
+
+        def training_step_end(self, outputs):
+            assert self.global_step < num_training
+            return outputs
+
+    model = TestModel()
+    trainer = Trainer(default_root_dir=tmpdir)
+
+    num_training = 3
+    _ = trainer.tuner.lr_find(
+        model=model,
+        num_training=num_training,
+    )