Removed redundant computations in clip_gradients that slowed down the gradient clipping. (#1523)

Fixes #1522
Jonas-Jaeger 2020-04-19 05:07:15 +02:00 committed by GitHub
parent a22a8142ac
commit e02146943d
1 changed file with 1 addition and 1 deletion


@@ -40,7 +40,7 @@ class TrainerTrainingTricksMixin(ABC):
         device = parameters[0].device
         total_norm = torch.zeros([], device=device if parameters else None)
         for p in parameters:
-            param_norm = p.grad.data.norm(norm_type) ** norm_type
+            param_norm = p.grad.data.pow(norm_type).sum()
             total_norm.add_(param_norm)
         total_norm = (total_norm ** (1. / norm_type))
         eps = EPSILON_FP16 if self.precision == 16 else EPSILON
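
Why the change is equivalent: the old line computed each tensor's p-norm, (sum |g_i|^p)^(1/p), and then immediately raised it back to the p-th power, so the root and the power cancel. Summing g.pow(norm_type) directly yields the same per-tensor term without the redundant root. Below is a minimal standalone sketch of the equivalence, not the Lightning code itself; grads is a hypothetical list of gradient tensors, and norm_type = 2.0 is assumed (the even-order case, where g**p equals |g|**p).

import torch

norm_type = 2.0
# Hypothetical gradient tensors standing in for p.grad.data.
grads = [torch.randn(3, 4), torch.randn(5)]

# Old approach: take the per-tensor p-norm, then raise it back to the
# p-th power, undoing the 1/p root the norm call just computed.
old_total = sum(g.norm(norm_type) ** norm_type for g in grads) ** (1.0 / norm_type)

# New approach: sum g**p directly, skipping the redundant root/power round trip.
new_total = sum(g.pow(norm_type).sum() for g in grads) ** (1.0 / norm_type)

# Both paths produce the same total norm.
assert torch.allclose(old_total, new_total)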