Fix metric attribute lookup (#8181)
* Fix metric attribute lookup
* Update CHANGELOG.md
* Split tests
parent bf54ac1cad
commit b1d8840fd8
@@ -333,6 +333,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed `log_gpu_memory` metrics not being added to `logging` when nothing else is logged ([#8174](https://github.com/PyTorchLightning/pytorch-lightning/pull/8174))
+- Fixed a bug where calling `log` with a `Metric` instance would raise an error if it was a nested attribute of the model ([#8181](https://github.com/PyTorchLightning/pytorch-lightning/pull/8181))
+

## [1.3.7] - 2021-06-22

- Fixed a bug where skipping an optimizer while using amp causes amp to trigger an assertion error ([#7975](https://github.com/PyTorchLightning/pytorch-lightning/pull/7975))
@@ -375,7 +375,7 @@ class LightningModule(
            # compute once
            self._metric_attributes = {
                id(module): name
-               for name, module in self.named_children() if isinstance(module, Metric)
+               for name, module in self.named_modules() if isinstance(module, Metric)
            }
            if not self._metric_attributes:
                raise MisconfigurationException(
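Why the one-word change above fixes the lookup: `named_children()` yields only a module's direct submodules, while `named_modules()` walks the whole module tree, so a metric stored inside a container such as `nn.ModuleDict` is only visible to the latter. A minimal sketch of the difference (the `Holder` and `DummyMetric` names are illustrative, not from this PR):

```python
import torch
from torch import nn
from torchmetrics import Metric


class DummyMetric(Metric):
    """Tiny metric, defined only so the example is self-contained."""

    def __init__(self):
        super().__init__()
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, value):
        self.total += value

    def compute(self):
        return self.total


class Holder(nn.Module):
    def __init__(self):
        super().__init__()
        # the metric is a *nested* attribute: a child of the ModuleDict,
        # not a direct child of Holder itself
        self.metrics = nn.ModuleDict({"sum": DummyMetric()})


holder = Holder()
# direct children only: the ModuleDict is not a Metric, so nothing matches
print([n for n, m in holder.named_children() if isinstance(m, Metric)])  # []
# recursive walk: the nested metric is found under its dotted path
print([n for n, m in holder.named_modules() if isinstance(m, Metric)])  # ['metrics.sum']
```

With `named_children()` the comprehension above cached an empty dict for such models, which is what triggered the `MisconfigurationException` for nested metrics.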
@@ -383,7 +383,7 @@ class LightningModule(
                    " You can fix this by setting an attribute for the metric in your `LightningModule`."
                )
        # try to find the passed metric in the LightningModule
-       metric_attribute = self._metric_attributes.get(id(value))
+       metric_attribute = self._metric_attributes.get(id(value), None)
        if metric_attribute is None:
            raise MisconfigurationException(
                "Could not find the `LightningModule` attribute for the `torchmetrics.Metric` logged."
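Because the cached mapping is keyed by `id(...)`, the exact `Metric` instance passed to `self.log` must be reachable as a (possibly nested) attribute of the module; a freshly constructed metric fails the lookup, which is what the new `test_log_metric_no_attributes_raises` below checks. A sketch of the two ways to satisfy this code path (assuming the `metric_attribute` argument of `self.log`, which the full error message here points users to; `CountMetric` and `LitModel` are illustrative names):

```python
import torch
from torchmetrics import Metric

from pytorch_lightning import LightningModule


class CountMetric(Metric):
    """Tiny metric, defined only so the example is self-contained."""

    def __init__(self):
        super().__init__()
        self.add_state("count", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, value):
        self.count += 1

    def compute(self):
        return self.count


class LitModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.train_count = CountMetric()

    def training_step(self, batch, batch_idx):
        self.train_count(batch)
        # resolved automatically: id(self.train_count) is a key in the
        # cached mapping, so the attribute name "train_count" is found
        self.log("count", self.train_count)
        # explicit fallback: name the attribute yourself if the identity
        # lookup cannot match the instance being logged
        self.log("count", self.train_count, metric_attribute="train_count")
```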
@@ -1,9 +1,27 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from re import escape

import pytest
import torch
from torch import nn
from torchmetrics import Metric as TMetric

from pytorch_lightning import Trainer
from pytorch_lightning.metrics import Metric as PLMetric
from pytorch_lightning.metrics import MetricCollection
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel
@@ -192,3 +210,78 @@ def test_metric_collection_lightning_log(tmpdir):
    logged = trainer.logged_metrics
    assert torch.allclose(torch.tensor(logged["SumMetric_epoch"]), model.sum)
    assert torch.allclose(torch.tensor(logged["DiffMetric_epoch"]), model.diff)


def test_log_metric_no_attributes_raises(tmpdir):

    class TestModel(BoringModel):

        def training_step(self, *args):
            metric = SumMetric()
            self.log("foo", metric)

    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
    model = TestModel()
    with pytest.raises(MisconfigurationException, match="Could not find the `LightningModule` attribute"):
        trainer.fit(model)


def test_log_metrics_wrong_attributes_raises(tmpdir):

    class TestModel(BoringModel):

        def __init__(self):
            super().__init__()
            self.a_metric = SumMetric()

        def training_step(self, *args):
            metric = SumMetric()
            self.log("foo", metric)

    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
    model = TestModel()
    with pytest.raises(MisconfigurationException, match=escape("where `name` is one of ['a_metric']")):
        trainer.fit(model)


def test_log_metric_dict(tmpdir):

    class TestModel(BoringModel):

        def __init__(self):
            super().__init__()
            self.metrics = nn.ModuleDict({'sum': SumMetric(), 'diff': DiffMetric()})
            self.sum = 0.0
            self.diff = 0.0

        def training_step(self, batch, batch_idx):
            x = batch
            self.metrics['sum'](x.sum())
            self.metrics['diff'](x.sum())
            self.sum += x.sum()
            self.diff -= x.sum()
            self.log_dict({f'{k}_step': v for k, v in self.metrics.items()})
            return self.step(x)

        def training_epoch_end(self, outputs):
            self.metrics['sum'].compute()
            self.metrics['diff'].compute()
            self.log_dict({f'{k}_epoch': v for k, v in self.metrics.items()})

    model = TestModel()
    model.val_dataloader = None

    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=2,
        limit_val_batches=2,
        max_epochs=1,
        log_every_n_steps=1,
        weights_summary=None,
    )
    trainer.fit(model)

    logged = trainer.logged_metrics
    assert torch.allclose(torch.tensor(logged["sum_epoch"]), model.sum)
    assert torch.allclose(torch.tensor(logged["diff_epoch"]), model.diff)