Refactor `grad_norm` function (#9742)

This commit is contained in:
Carlos Mocholí 2021-09-30 04:54:08 +02:00 committed by GitHub
parent 7f95fd04d7
commit fb81e738fa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 10 additions and 15 deletions

View File

@@ -35,18 +35,13 @@ def grad_norm(module: Module, norm_type: Union[float, int, str]) -> Dict[str, fl
as a single vector.
"""
norm_type = float(norm_type)
norms, all_norms = {}, []
for name, p in module.named_parameters():
if p.grad is None:
continue
param_norm = float(p.grad.data.norm(norm_type))
norms[f"grad_{norm_type}_norm_{name}"] = round(param_norm, 4)
all_norms.append(param_norm)
total_norm = float(torch.tensor(all_norms).norm(norm_type))
norms[f"grad_{norm_type}_norm_total"] = round(total_norm, 4)
norms = {
f"grad_{norm_type}_norm_{name}": p.grad.data.norm(norm_type).item()
for name, p in module.named_parameters()
if p.grad is not None
}
if norms:
total_norm = torch.tensor(list(norms.values())).norm(norm_type).item()
norms[f"grad_{norm_type}_norm_total"] = total_norm
norms = {k: round(v, 4) for k, v in norms.items()}
return norms

View File

@@ -105,6 +105,6 @@ def test_grad_tracking_interval(tmpdir, log_every_n_steps):
# logging on n steps + 1 epochs
assert len(grad_norm_dicts) == expected + 1
# check all metrics derived from steps have the same keys
assert all(grad_norm_dicts[0].keys() == g.keys() for g in grad_norm_dicts[:-1])
assert all(grad_norm_dicts[0].keys() == g.keys() for g in grad_norm_dicts[1:-1])
epoch_end_keys = [k.replace("step", "epoch") for k in grad_norm_dicts[0]]
assert epoch_end_keys == list(grad_norm_dicts[-1])