lightning/pytorch_lightning/root_module/grads.py

30 lines
916 B
Python
Raw Normal View History

2019-03-31 01:45:16 +00:00
"""
Module to describe gradients
"""
from torch import nn
2019-03-31 01:45:16 +00:00
2019-08-05 21:57:39 +00:00
2019-03-31 01:45:16 +00:00
class GradInformation(nn.Module):
    """``nn.Module`` mixin that reports the gradient norms of its parameters."""

    def grad_norm(self, norm_type):
        """Compute the gradient norm of each parameter and of all gradients combined.

        Args:
            norm_type: order of the norm, passed to ``torch.Tensor.norm``
                (e.g. ``1``, ``2`` or ``float('inf')``).

        Returns:
            dict with one entry ``'grad_{norm_type}_norm_{i}'`` per parameter
            that has a gradient, where ``i`` is the parameter's position in
            ``self.parameters()``. It also has an entry
            ``'grad_{norm_type}_norm_total'`` for the norm of all gradients
            viewed as one concatenated vector. Values are floats rounded to
            3 decimals. Parameters that do not require grad, or whose ``.grad``
            is ``None``, are skipped. The total is ``0.0`` when no parameter
            has a gradient.
        """
        order = float(norm_type)
        results = {}
        param_norms = []
        for i, p in enumerate(self.parameters()):
            # frozen params, or params that have not received a gradient yet
            if not p.requires_grad or p.grad is None:
                continue
            # .norm() already returns the p-norm itself; no extra root is needed.
            # float() moves the scalar to host memory, so parameters living on
            # different devices can still be combined below.
            param_norm = float(p.grad.detach().norm(norm_type))
            param_norms.append(param_norm)
            results['grad_{}_norm_{}'.format(norm_type, i)] = round(param_norm, 3)

        # The p-norm of the concatenation of all gradients equals the p-norm
        # of the vector of per-parameter p-norms.
        if not param_norms:
            total_norm = 0.0
        elif order == float('inf'):
            total_norm = max(param_norms)
        elif order == float('-inf'):
            total_norm = min(param_norms)
        else:
            total_norm = sum(n ** order for n in param_norms) ** (1.0 / order)
        results['grad_{}_norm_total'.format(norm_type)] = round(total_norm, 3)
        return results