Use class name in SWA info message (#8602)
This commit is contained in:
parent
ebd2e87752
commit
0dc0472e1f
|
@ -195,7 +195,10 @@ class StochasticWeightAveraging(Callback):
|
|||
scheduler_cfg = trainer.lr_schedulers[0]
|
||||
if scheduler_cfg["interval"] != "epoch" or scheduler_cfg["frequency"] != 1:
|
||||
rank_zero_warn(f"SWA is currently only supported every epoch. Found {scheduler_cfg}")
|
||||
rank_zero_info(f"Swapping scheduler {scheduler_cfg['scheduler']} for {self._swa_scheduler}")
|
||||
rank_zero_info(
|
||||
f"Swapping scheduler `{scheduler_cfg['scheduler'].__class__.__name__}`"
|
||||
f" for `{self._swa_scheduler.__class__.__name__}`"
|
||||
)
|
||||
trainer.lr_schedulers[0] = default_scheduler_cfg
|
||||
else:
|
||||
trainer.lr_schedulers.append(default_scheduler_cfg)
|
||||
|
|
|
@ -175,7 +175,7 @@ def test_swa_warns(tmpdir, caplog):
|
|||
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, stochastic_weight_avg=True)
|
||||
with caplog.at_level(level=logging.INFO), pytest.warns(UserWarning, match="SWA is currently only supported"):
|
||||
trainer.fit(model)
|
||||
assert "Swapping scheduler" in caplog.text
|
||||
assert "Swapping scheduler `StepLR` for `SWALR`" in caplog.text
|
||||
|
||||
|
||||
def test_swa_raises():
|
||||
|
|
Loading…
Reference in New Issue