mirror of
https://github.com/Lightning-AI/lightning.git
synced 2025-02-26 20:25:16 +00:00
Co-authored-by: louie.kim <louie.kim@kakaocorp.comlouie.kim@kakaocorp.com> Co-authored-by: Jirka <jirka.borovec@seznam.cz>
This commit is contained in:
parent
e3820da28a
commit
24de29974c
@ -197,7 +197,7 @@ class StochasticWeightAveraging(Callback):
|
|||||||
self.n_averaged = torch.tensor(0, dtype=torch.long, device=pl_module.device)
|
self.n_averaged = torch.tensor(0, dtype=torch.long, device=pl_module.device)
|
||||||
|
|
||||||
if self.swa_start <= trainer.current_epoch <= self.swa_end:
|
if self.swa_start <= trainer.current_epoch <= self.swa_end:
|
||||||
self.update_parameters(self._average_model, pl_module, self.n_averaged, self.avg_fn)
|
self.update_parameters(self._average_model, pl_module, self.n_averaged, self._avg_fn)
|
||||||
|
|
||||||
# Note: No > here in case the callback is saved with the model and training continues
|
# Note: No > here in case the callback is saved with the model and training continues
|
||||||
if trainer.current_epoch == self.swa_end + 1:
|
if trainer.current_epoch == self.swa_end + 1:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user