From b1d8840fd86221d06751e49f3cf97efc2c8238c4 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Mon, 28 Jun 2021 21:17:43 +0100 Subject: [PATCH] Fix metric attribute lookup (#8181) * Fix metric attribute lookup * Update CHANGELOG.md * Split tests --- CHANGELOG.md | 2 + pytorch_lightning/core/lightning.py | 4 +- tests/metrics/test_metric_lightning.py | 93 ++++++++++++++++++++++++++ 3 files changed, 97 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 32320fdec6..aa47dfe3c8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -333,6 +333,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed `log_gpu_memory` metrics not being added to `logging` when nothing else is logged ([#8174](https://github.com/PyTorchLightning/pytorch-lightning/pull/8174)) +- Fixed a bug where calling `log` with a `Metric` instance would raise an error if it was a nested attribute of the model ([#8181](https://github.com/PyTorchLightning/pytorch-lightning/pull/8181)) + ## [1.3.7] - 2021-06-22 - Fixed a bug where skipping an optimizer while using amp causes amp to trigger an assertion error ([#7975](https://github.com/PyTorchLightning/pytorch-lightning/pull/7975)) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index bf05b1f077..ab8263bc8d 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -375,7 +375,7 @@ class LightningModule( # compute once self._metric_attributes = { id(module): name - for name, module in self.named_children() if isinstance(module, Metric) + for name, module in self.named_modules() if isinstance(module, Metric) } if not self._metric_attributes: raise MisconfigurationException( @@ -383,7 +383,7 @@ class LightningModule( " You can fix this by setting an attribute for the metric in your `LightningModule`." ) # try to find the passed metric in the LightningModule - metric_attribute = self._metric_attributes.get(id(value)) + metric_attribute = self._metric_attributes.get(id(value), None) if metric_attribute is None: raise MisconfigurationException( "Could not find the `LightningModule` attribute for the `torchmetrics.Metric` logged." diff --git a/tests/metrics/test_metric_lightning.py b/tests/metrics/test_metric_lightning.py index 8ce7f1050c..5a20bc2475 100644 --- a/tests/metrics/test_metric_lightning.py +++ b/tests/metrics/test_metric_lightning.py @@ -1,9 +1,27 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from re import escape + +import pytest import torch +from torch import nn from torchmetrics import Metric as TMetric from pytorch_lightning import Trainer from pytorch_lightning.metrics import Metric as PLMetric from pytorch_lightning.metrics import MetricCollection +from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel @@ -192,3 +210,78 @@ def test_metric_collection_lightning_log(tmpdir): logged = trainer.logged_metrics assert torch.allclose(torch.tensor(logged["SumMetric_epoch"]), model.sum) assert torch.allclose(torch.tensor(logged["DiffMetric_epoch"]), model.diff) + + +def test_log_metric_no_attributes_raises(tmpdir): + + class TestModel(BoringModel): + + def training_step(self, *args): + metric = SumMetric() + self.log("foo", metric) + + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1) + model = TestModel() + with pytest.raises(MisconfigurationException, match="Could not find the `LightningModule` attribute"): + trainer.fit(model) + + +def test_log_metrics_wrong_attributes_raises(tmpdir): + + class TestModel(BoringModel): + + def __init__(self): + super().__init__() + + self.a_metric = SumMetric() + + def training_step(self, *args): + metric = SumMetric() + self.log("foo", metric) + + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1) + model = TestModel() + with pytest.raises(MisconfigurationException, match=escape("where `name` is one of ['a_metric']")): + trainer.fit(model) + + +def test_log_metric_dict(tmpdir): + + class TestModel(BoringModel): + + def __init__(self): + super().__init__() + self.metrics = nn.ModuleDict({'sum': SumMetric(), 'diff': DiffMetric()}) + self.sum = 0.0 + self.diff = 0.0 + + def training_step(self, batch, batch_idx): + x = batch + self.metrics['sum'](x.sum()) + self.metrics['diff'](x.sum()) + self.sum += x.sum() + self.diff -= x.sum() + self.log_dict({f'{k}_step': v for k, v in self.metrics.items()}) + return self.step(x) + + def training_epoch_end(self, outputs): + self.metrics['sum'].compute() + self.metrics['diff'].compute() + self.log_dict({f'{k}_epoch': v for k, v in self.metrics.items()}) + + model = TestModel() + model.val_dataloader = None + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=1, + log_every_n_steps=1, + weights_summary=None, + ) + trainer.fit(model) + + logged = trainer.logged_metrics + assert torch.allclose(torch.tensor(logged["sum_epoch"]), model.sum) + assert torch.allclose(torch.tensor(logged["diff_epoch"]), model.diff)