Use class name in SWA info message (#8602)

This commit is contained in:
Carlos Mocholí 2021-07-29 09:39:46 +02:00 committed by GitHub
parent ebd2e87752
commit 0dc0472e1f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 5 additions and 2 deletions

View File

@ -195,7 +195,10 @@ class StochasticWeightAveraging(Callback):
scheduler_cfg = trainer.lr_schedulers[0]
if scheduler_cfg["interval"] != "epoch" or scheduler_cfg["frequency"] != 1:
rank_zero_warn(f"SWA is currently only supported every epoch. Found {scheduler_cfg}")
rank_zero_info(f"Swapping scheduler {scheduler_cfg['scheduler']} for {self._swa_scheduler}")
rank_zero_info(
f"Swapping scheduler `{scheduler_cfg['scheduler'].__class__.__name__}`"
f" for `{self._swa_scheduler.__class__.__name__}`"
)
trainer.lr_schedulers[0] = default_scheduler_cfg
else:
trainer.lr_schedulers.append(default_scheduler_cfg)

View File

@ -175,7 +175,7 @@ def test_swa_warns(tmpdir, caplog):
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, stochastic_weight_avg=True)
with caplog.at_level(level=logging.INFO), pytest.warns(UserWarning, match="SWA is currently only supported"):
trainer.fit(model)
assert "Swapping scheduler" in caplog.text
assert "Swapping scheduler `StepLR` for `SWALR`" in caplog.text
def test_swa_raises():