Fix `lr_finder` suggesting too high learning rates (#7076)
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
@@ -304,6 +304,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed process rank not being available right away after `Trainer` instantiation ([#6941](https://github.com/PyTorchLightning/pytorch-lightning/pull/6941))

+- Fixed `lr_find` trying beyond `num_training` steps and suggesting a too high learning rate ([#7076](https://github.com/PyTorchLightning/pytorch-lightning/pull/7076))
+
 - Fixed logger creating incorrect version folder in DDP with repeated `Trainer.fit` calls ([#7077](https://github.com/PyTorchLightning/pytorch-lightning/pull/7077))
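For context, this fix lands in `Trainer.tuner.lr_find`. A minimal usage sketch, assuming any small `LightningModule` (`TinyModel` below is hypothetical, and the search range is illustrative, not part of this commit):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer


class TinyModel(LightningModule):
    # Hypothetical minimal module, only here to make the sketch self-contained.
    def __init__(self, learning_rate=0.1):
        super().__init__()
        self.save_hyperparameters()  # exposes hparams.learning_rate, which lr_find updates
        self.layer = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=self.hparams.learning_rate)

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(64, 32), torch.randn(64, 1)), batch_size=8)


trainer = Trainer(default_root_dir=".")
lr_finder = trainer.tuner.lr_find(TinyModel(), min_lr=1e-6, max_lr=1e-1, num_training=100)

print(lr_finder.results["lr"])    # learning rate candidates actually tried
print(lr_finder.results["loss"])  # corresponding losses
print(lr_finder.suggestion())     # suggested learning rate
```

Before this fix, the search could run past `num_training` steps, so the suggestion could be drawn from learning rates above the intended range.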
@@ -521,7 +521,7 @@ class TrainLoop:
         # max steps reached, end training
         if (
-            self.trainer.max_steps is not None and self.trainer.max_steps == self.trainer.global_step + 1
+            self.trainer.max_steps is not None and self.trainer.max_steps <= self.trainer.global_step + 1
             and self._accumulated_batches_reached()
         ):
             break
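The behavioral change above: with `==`, the loop stops only on the exact step where `global_step + 1` equals `max_steps`; if the counter ever lands past that boundary, the check never fires again and training runs on, which is how the LR finder could exceed `num_training`. Using `<=` also catches the overshoot. A standalone sketch of the predicate (illustrative, not Lightning's actual loop code):

```python
def should_stop(max_steps, global_step, accumulated_batches_reached=True):
    # Stop once the step counter has reached *or passed* the limit.
    # An exact `==` comparison misses the case where the counter
    # skips past the boundary and would then never trigger.
    return (
        max_steps is not None
        and max_steps <= global_step + 1
        and accumulated_batches_reached
    )


assert should_stop(max_steps=3, global_step=2)       # exact boundary: stop
assert should_stop(max_steps=3, global_step=7)       # overshoot: still stops
assert not should_stop(max_steps=3, global_step=1)   # limit not reached yet
```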
@@ -180,7 +180,7 @@ def lr_find(
     # Prompt if we stopped early
     if trainer.global_step != num_training:
-        log.info('LR finder stopped early due to diverging loss.')
+        log.info(f'LR finder stopped early after {trainer.global_step} steps due to diverging loss.')

     # Transfer results from callback to lr finder object
     lr_finder.results.update({'lr': trainer.callbacks[0].lrs, 'loss': trainer.callbacks[0].losses})
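The early stop reported by this message is triggered by the LR finder's internal callback once the loss diverges, leaving `global_step` short of `num_training`. Conceptually the check looks like this (a sketch with illustrative names, not the actual callback; `4.0` matches `lr_find`'s default `early_stop_threshold`):

```python
def diverged(current_loss, best_loss, early_stop_threshold=4.0):
    # Abandon the search once the smoothed loss blows up past a
    # multiple of the best loss seen so far.
    return early_stop_threshold is not None and current_loss > early_stop_threshold * best_loss
```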
@@ -295,3 +295,45 @@ def test_lr_find_with_bs_scale(tmpdir):
     assert lr != before_lr
     assert isinstance(bs, int)
+
+
+def test_lr_candidates_between_min_and_max(tmpdir):
+    """Test that learning rate candidates are between min_lr and max_lr."""
+
+    class TestModel(BoringModel):
+
+        def __init__(self, learning_rate=0.1):
+            super().__init__()
+            self.save_hyperparameters()
+
+    model = TestModel()
+    trainer = Trainer(default_root_dir=tmpdir)
+
+    lr_min = 1e-8
+    lr_max = 1.0
+    lr_finder = trainer.tuner.lr_find(
+        model,
+        min_lr=lr_min,
+        max_lr=lr_max,
+        num_training=3,
+    )
+    lr_candidates = lr_finder.results["lr"]
+    assert all(lr_min <= lr <= lr_max for lr in lr_candidates)
+
+
+def test_lr_finder_ends_before_num_training(tmpdir):
+    """Tests learning rate finder ends before `num_training` steps."""
+
+    class TestModel(BoringModel):
+
+        def __init__(self, learning_rate=0.1):
+            super().__init__()
+            self.save_hyperparameters()
+
+        def training_step_end(self, outputs):
+            assert self.global_step < num_training
+            return outputs
+
+    model = TestModel()
+    trainer = Trainer(default_root_dir=tmpdir)
+    num_training = 3
+    _ = trainer.tuner.lr_find(
+        model=model,
+        num_training=num_training,
+    )