"""
Module to describe gradients
"""

from typing import Dict

from torch import nn


class GradInformation(nn.Module):

    def grad_norm(self, norm_type: float) -> Dict[str, float]:
        """Compute the ``norm_type``-norm of the gradient of every parameter
        that requires grad, plus the total norm over all gradients combined.

        All values are rounded to three decimal places.
        """
        results = {}
        total_norm = 0
        for name, p in self.named_parameters():
            if p.requires_grad:
                try:
                    # norm() already returns the norm_type-norm of the
                    # gradient, so the per-parameter value needs no extra root
                    param_norm = p.grad.data.norm(norm_type)
                    # accumulate the norm_type-th power; a single root is
                    # taken at the end to reduce over all gradients
                    total_norm += param_norm ** norm_type

                    grad = round(float(param_norm), 3)
                    results['grad_{}_norm_{}'.format(norm_type, name)] = grad
                except Exception:
                    # this param had no grad
                    pass

        # take the root of the accumulated powers to finish the reduction
        total_norm = total_norm ** (1. / norm_type)
        grad = round(float(total_norm), 3)
        results['grad_{}_norm_total'.format(norm_type)] = grad
        return results
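

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how a model might mix in GradInformation and log
# gradient norms after a backward pass. The `LitModel` class, its layer
# sizes, and the dummy data below are assumptions made for illustration.
if __name__ == '__main__':
    import torch

    class LitModel(GradInformation):
        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(4, 2)

        def forward(self, x):
            return self.layer(x)

    model = LitModel()
    out = model(torch.randn(8, 4))
    out.sum().backward()

    # e.g. {'grad_2.0_norm_layer.weight': ...,
    #       'grad_2.0_norm_layer.bias': ...,
    #       'grad_2.0_norm_total': ...}
    print(model.grad_norm(norm_type=2.0))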