Fix the condition for calling update_learning_rates (#7032)
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent: 502adbced3
commit: 20f63377f8
@@ -374,6 +374,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 * Remove hardcoding of local rank in accelerator connector ([#6878](https://github.com/PyTorchLightning/pytorch-lightning/pull/6878))

+- Fixed incorrect number of calls to LR scheduler when `check_val_every_n_epoch > 1` ([#7032](https://github.com/PyTorchLightning/pytorch-lightning/pull/7032))
+
 ## [1.2.7] - 2021-04-06

 ### Fixed
@@ -478,7 +478,6 @@ class TrainLoop:
         train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader)
         dataloader_idx = 0
-        val_loop_called = False

         batch_idx = None
         is_last_batch = None
@@ -519,7 +518,6 @@ class TrainLoop:
                 self.trainer.validating = True
                 self.trainer._run_evaluation()
                 self.trainer.training = True
-                val_loop_called = True

             # -----------------------------------------
             # SAVE LOGGERS (ie: Tensorboard, etc...)
@@ -568,7 +566,7 @@ class TrainLoop:
         should_train_only = self.trainer.disable_validation or should_skip_eval

         # update epoch level lr_schedulers if no val loop outside train loop is triggered
-        if (val_loop_called and not should_check_val) or should_train_only:
+        if not should_check_val or should_train_only:
             self.trainer.optimizer_connector.update_learning_rates(interval='epoch')

         if should_train_only:
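Why the old condition was wrong, as a minimal standalone sketch (the flag names come from the hunk above; the truth-table values are illustrative and not part of the commit): when `check_val_every_n_epoch > 1`, an epoch that skips validation never sets `val_loop_called`, so the old condition evaluated to False and the epoch-level LR scheduler step was silently skipped.

# Minimal sketch of the decision logic before and after this commit.
# Flag names are taken from TrainLoop in the hunk above.

def old_condition(val_loop_called, should_check_val, should_train_only):
    return (val_loop_called and not should_check_val) or should_train_only

def new_condition(should_check_val, should_train_only):
    return not should_check_val or should_train_only

# An epoch where validation is skipped (check_val_every_n_epoch > 1):
# no val loop ran inside the epoch and none is scheduled after it.
assert not old_condition(val_loop_called=False, should_check_val=False, should_train_only=False)  # bug: LR step skipped
assert new_condition(should_check_val=False, should_train_only=False)  # fix: LR step runs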
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from unittest import mock
+
 import pytest
 import torch
 from torch import optim
@@ -577,21 +579,21 @@ def test_warn_invalid_scheduler_key_in_manual_optimization(tmpdir):
     trainer.fit(model)


-class TestModel(BoringModel):
-
-    def configure_optimizers(self):
-        # Adagrad creates state tensors immediately, model is not yet on GPU.
-        return optim.Adagrad(self.parameters())
-
-    def on_train_start(self, *args, **kwargs):
-        opt = self.optimizers()
-        _, state = next(iter(opt.state.items()))
-        assert state["sum"].device == torch.device("cuda", self.local_rank) == self.device
-
-
 @RunIf(min_gpus=2, special=True)
 def test_optimizer_state_on_device(tmpdir):
     """ Test that optimizers that create state initially at instantiation still end up with the state on the GPU. """

+    class TestModel(BoringModel):
+
+        def configure_optimizers(self):
+            # Adagrad creates state tensors immediately, model is not yet on GPU.
+            return optim.Adagrad(self.parameters())
+
+        def on_train_start(self, *args, **kwargs):
+            opt = self.optimizers()
+            _, state = next(iter(opt.state.items()))
+            assert state["sum"].device == torch.device("cuda", self.local_rank) == self.device
+
     model = TestModel()
     trainer = Trainer(
         default_root_dir=tmpdir,
@@ -600,3 +602,21 @@ def test_optimizer_state_on_device(tmpdir):
         fast_dev_run=True,
     )
     trainer.fit(model)
+
+
+@pytest.mark.parametrize("check_val_every_n_epoch", [1, 2])
+@mock.patch("torch.optim.lr_scheduler.StepLR.step")
+def test_lr_scheduler_epoch_step_frequency(mocked_sched, check_val_every_n_epoch, tmpdir):
+    epochs = 4
+    expected_steps = epochs + 1  # every LRScheduler gets called once at init
+
+    model = BoringModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        check_val_every_n_epoch=check_val_every_n_epoch,
+        max_epochs=epochs,
+    )
+    trainer.fit(model)
+    assert mocked_sched.call_count == expected_steps
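A side note on the `expected_steps = epochs + 1` arithmetic in the new test: in the PyTorch versions this PR targets, `_LRScheduler.__init__` performs an initial `step()` at construction, so the mocked `StepLR.step` is hit once at init plus once per epoch, independent of `check_val_every_n_epoch`. A minimal standalone check of that init behaviour (plain PyTorch, no Lightning; not part of the commit):

import torch
from torch import optim

# A single dummy parameter is enough to build an optimizer.
params = [torch.nn.Parameter(torch.zeros(1))]
opt = optim.SGD(params, lr=0.1)
sched = optim.lr_scheduler.StepLR(opt, step_size=1)

# _step_count is incremented by step(); it is already 1 right after
# construction because __init__ triggers the first step.
assert sched._step_count == 1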