Refactor `grad_norm` function (#9742)
parent: 7f95fd04d7
commit: fb81e738fa
```diff
@@ -35,18 +35,13 @@ def grad_norm(module: Module, norm_type: Union[float, int, str]) -> Dict[str, float]:
         as a single vector.
     """
     norm_type = float(norm_type)
-
-    norms, all_norms = {}, []
-    for name, p in module.named_parameters():
-        if p.grad is None:
-            continue
-
-        param_norm = float(p.grad.data.norm(norm_type))
-        norms[f"grad_{norm_type}_norm_{name}"] = round(param_norm, 4)
-
-        all_norms.append(param_norm)
-
-    total_norm = float(torch.tensor(all_norms).norm(norm_type))
-    norms[f"grad_{norm_type}_norm_total"] = round(total_norm, 4)
-
+    norms = {
+        f"grad_{norm_type}_norm_{name}": p.grad.data.norm(norm_type).item()
+        for name, p in module.named_parameters()
+        if p.grad is not None
+    }
+    if norms:
+        total_norm = torch.tensor(list(norms.values())).norm(norm_type).item()
+        norms[f"grad_{norm_type}_norm_total"] = total_norm
+        norms = {k: round(v, 4) for k, v in norms.items()}
     return norms
```
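For context, a minimal usage sketch (not part of the commit) of how the refactored function behaves; it assumes `grad_norm` is importable from `pytorch_lightning.utilities.grads` as in Lightning 1.x, and the module here is made up. One behavioral difference visible in the diff: when no parameter has a gradient, the new code returns an empty dict, whereas the old loop always emitted a `grad_{norm_type}_norm_total` entry of 0.0.

```python
# Minimal usage sketch (assumption: grad_norm lives in
# pytorch_lightning.utilities.grads, as in Lightning 1.x).
import torch
from torch.nn import Linear

from pytorch_lightning.utilities.grads import grad_norm

module = Linear(2, 2)
module(torch.randn(4, 2)).sum().backward()  # populate .grad on weight and bias

norms = grad_norm(module, norm_type=2)
# -> {'grad_2.0_norm_weight': ..., 'grad_2.0_norm_bias': ..., 'grad_2.0_norm_total': ...},
#    each value rounded to 4 decimal places by the final dict comprehension
print(norms)

# Without a backward pass, no parameter has a .grad, so the refactored
# version returns {} (the old loop returned {'grad_2.0_norm_total': 0.0}).
assert grad_norm(Linear(2, 2), norm_type=2) == {}
```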
```diff
@@ -105,6 +105,6 @@ def test_grad_tracking_interval(tmpdir, log_every_n_steps):
     # logging on n steps + 1 epochs
     assert len(grad_norm_dicts) == expected + 1
     # check all metrics derived from steps have the same keys
-    assert all(grad_norm_dicts[0].keys() == g.keys() for g in grad_norm_dicts[:-1])
+    assert all(grad_norm_dicts[0].keys() == g.keys() for g in grad_norm_dicts[1:-1])
     epoch_end_keys = [k.replace("step", "epoch") for k in grad_norm_dicts[0]]
     assert epoch_end_keys == list(grad_norm_dicts[-1])
```
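The test tweak narrows the slice: `grad_norm_dicts[0]` is the reference dict, so `[:-1]` compared it against itself, while `[1:-1]` skips that redundant self-comparison and still excludes the final dict, which holds the epoch-level metrics. A hypothetical illustration of the data shape follows; the keys are made up, modeled on the `step`/`epoch` replacement in the test.

```python
# Hypothetical data shaped like grad_norm_dicts in the test: n step-level
# dicts followed by one epoch-level dict (keys are made up).
dicts = [{"grad_2.0_norm_total_step": 0.1}] * 3 + [{"grad_2.0_norm_total_epoch": 0.1}]

# dicts[0] is the reference, so start the comparison at index 1 and stop
# before the epoch-level dict at the end.
assert all(dicts[0].keys() == g.keys() for g in dicts[1:-1])
assert [k.replace("step", "epoch") for k in dicts[0]] == list(dicts[-1])
```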