diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py index f5523210ef..3364c9d305 100644 --- a/pytorch_lightning/trainer/training_tricks.py +++ b/pytorch_lightning/trainer/training_tricks.py @@ -40,7 +40,7 @@ class TrainerTrainingTricksMixin(ABC): device = parameters[0].device total_norm = torch.zeros([], device=device if parameters else None) for p in parameters: - param_norm = p.grad.data.norm(norm_type) ** norm_type + param_norm = p.grad.data.pow(norm_type).sum() total_norm.add_(param_norm) total_norm = (total_norm ** (1. / norm_type)) eps = EPSILON_FP16 if self.precision == 16 else EPSILON