diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py
index 42dd67b724..d18c9dcffd 100644
--- a/pytorch_lightning/callbacks/stochastic_weight_avg.py
+++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py
@@ -197,7 +197,7 @@ class StochasticWeightAveraging(Callback):
             self.n_averaged = torch.tensor(0, dtype=torch.long, device=pl_module.device)
 
         if self.swa_start <= trainer.current_epoch <= self.swa_end:
-            self.update_parameters(self._average_model, pl_module, self.n_averaged, self.avg_fn)
+            self.update_parameters(self._average_model, pl_module, self.n_averaged, self._avg_fn)
 
         # Note: No > here in case the callback is saved with the model and training continues
         if trainer.current_epoch == self.swa_end + 1:
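
The change passes self._avg_fn, the averaging callable the callback was configured with, instead of self.avg_fn, which appears to be the default implementation; using the latter would silently ignore a user-supplied averaging function. Below is a minimal sketch of the contract such a callable is expected to satisfy, assuming the default is the standard SWA running mean (the name running_avg_fn is illustrative, not part of the library):

    import torch

    def running_avg_fn(
        averaged_param: torch.Tensor, model_param: torch.Tensor, num_averaged: torch.Tensor
    ) -> torch.Tensor:
        # Incremental mean: avg_{n+1} = avg_n + (x_{n+1} - avg_n) / (n + 1)
        return averaged_param + (model_param - averaged_param) / (num_averaged + 1)

    # Usage sketch (assumption): a callable passed as
    # StochasticWeightAveraging(..., avg_fn=running_avg_fn) would be kept on the
    # instance as self._avg_fn, which is why the call site above must reference
    # _avg_fn rather than the default avg_fn.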