From fb81e738fada47068741c02ed5ca822681587a9e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Thu, 30 Sep 2021 04:54:08 +0200
Subject: [PATCH] Refactor `grad_norm` function (#9742)

---
 pytorch_lightning/utilities/grads.py | 23 +++++++++--------------
 tests/models/test_grad_norm.py       |  2 +-
 2 files changed, 10 insertions(+), 15 deletions(-)

diff --git a/pytorch_lightning/utilities/grads.py b/pytorch_lightning/utilities/grads.py
index c1dfb2277b..93e9e832f0 100644
--- a/pytorch_lightning/utilities/grads.py
+++ b/pytorch_lightning/utilities/grads.py
@@ -35,18 +35,13 @@ def grad_norm(module: Module, norm_type: Union[float, int, str]) -> Dict[str, float]:
         as a single vector.
     """
     norm_type = float(norm_type)
-
-    norms, all_norms = {}, []
-    for name, p in module.named_parameters():
-        if p.grad is None:
-            continue
-
-        param_norm = float(p.grad.data.norm(norm_type))
-        norms[f"grad_{norm_type}_norm_{name}"] = round(param_norm, 4)
-
-        all_norms.append(param_norm)
-
-    total_norm = float(torch.tensor(all_norms).norm(norm_type))
-    norms[f"grad_{norm_type}_norm_total"] = round(total_norm, 4)
-
+    norms = {
+        f"grad_{norm_type}_norm_{name}": p.grad.data.norm(norm_type).item()
+        for name, p in module.named_parameters()
+        if p.grad is not None
+    }
+    if norms:
+        total_norm = torch.tensor(list(norms.values())).norm(norm_type).item()
+        norms[f"grad_{norm_type}_norm_total"] = total_norm
+    norms = {k: round(v, 4) for k, v in norms.items()}
     return norms
diff --git a/tests/models/test_grad_norm.py b/tests/models/test_grad_norm.py
index fb0c0d4a8d..ae8494b6c6 100644
--- a/tests/models/test_grad_norm.py
+++ b/tests/models/test_grad_norm.py
@@ -105,6 +105,6 @@ def test_grad_tracking_interval(tmpdir, log_every_n_steps):
     # logging on n steps + 1 epochs
     assert len(grad_norm_dicts) == expected + 1
     # check all metrics derived from steps have the same keys
-    assert all(grad_norm_dicts[0].keys() == g.keys() for g in grad_norm_dicts[:-1])
+    assert all(grad_norm_dicts[0].keys() == g.keys() for g in grad_norm_dicts[1:-1])
     epoch_end_keys = [k.replace("step", "epoch") for k in grad_norm_dicts[0]]
     assert epoch_end_keys == list(grad_norm_dicts[-1])