[Metrics] MetricCollection (#4318)

* docs + precision + recall + f_beta + refactor

Co-authored-by: Teddy Koker <teddy.koker@gmail.com>

* rebase

Co-authored-by: Teddy Koker <teddy.koker@gmail.com>

* fixes

Co-authored-by: Teddy Koker <teddy.koker@gmail.com>

* added missing file

* docs

* docs

* extra import

* add metric collection

* add docs + integration with log_dict

* add test

* update

* update

* more test

* more test

* pep8

* fix doctest

* pep8

* add clone method

* add clone method

* merge-2

* changelog

* kwargs filtering and tests

* pep8

* fix test

* update docs

* Update docs/source/metrics.rst

Co-authored-by: Roger Shieh <sh.rog@protonmail.ch>

* fix docs

* fix tests

* Apply suggestions from code review

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* fix docs

* fix doctest

* fix doctest

* fix doctest

* fix doctest

Co-authored-by: ananyahjha93 <ananya@pytorchlightning.ai>
Co-authored-by: Teddy Koker <teddy.koker@gmail.com>
Co-authored-by: Nicki Skafte <nugginea@gmail.com>
Co-authored-by: Roger Shieh <sh.rog@protonmail.ch>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Nicki Skafte 2021-01-08 11:09:07 +01:00 committed by GitHub
parent 06f36092a4
commit 06668c0ddf
6 changed files with 384 additions and 14 deletions

CHANGELOG.md

@@ -33,6 +33,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `image_gradients` functional metric to compute the image gradients of a given input image. ([#5056](https://github.com/PyTorchLightning/pytorch-lightning/pull/5056))
- Added `MetricCollection` ([#4318](https://github.com/PyTorchLightning/pytorch-lightning/pull/4318))
- Added `.clone()` method to metrics ([#4318](https://github.com/PyTorchLightning/pytorch-lightning/pull/4318))

### Changed

- Changed `automatic casting` for LoggerConnector `metrics` ([#5218](https://github.com/PyTorchLightning/pytorch-lightning/pull/5218))

docs/source/metrics.rst

@@ -81,6 +81,7 @@ If ``on_epoch`` is True, the logger automatically logs the end of epoch metric v
        self.log('valid_acc', self.valid_acc, on_step=True, on_epoch=True)

.. note::

    If using metrics in data parallel mode (dp), the metric update/logging should be done
    in the ``<mode>_step_end`` method (where ``<mode>`` is either ``training``, ``validation``
    or ``test``). This is due to metric states else being destroyed after each forward pass,
@@ -99,7 +100,6 @@ If ``on_epoch`` is True, the logger automatically logs the end of epoch metric v
        self.metric(outputs['preds'], outputs['target'])
        self.log('metric', self.metric)

This metrics API is independent of PyTorch Lightning. Metrics can directly be used in PyTorch as shown in the example:

.. code-block:: python
@@ -131,7 +131,17 @@ This metrics API is independent of PyTorch Lightning. Metrics can directly be us
Metrics contain internal states that keep track of the data seen so far.
Do not mix metric states across training, validation and testing.
It is highly recommended to re-initialize the metric per mode as
shown in the examples above. To easily initialize the same metric multiple
times, the ``.clone()`` method can be used:

.. testcode::

    def __init__(self):
        ...
        metric = pl.metrics.Accuracy()
        self.train_acc = metric.clone()
        self.val_acc = metric.clone()
        self.test_acc = metric.clone()
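
The same pattern works for parameterized metrics, where cloning also avoids
repeating the constructor arguments; a minimal sketch (reusing the
``Precision`` configuration from the ``MetricCollection`` example below):

.. testcode::

    def __init__(self):
        ...
        metric = pl.metrics.Precision(num_classes=3, average='macro')
        self.train_precision = metric.clone()
        self.valid_precision = metric.clone()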
.. note::
@@ -240,6 +250,69 @@ In practise this means that:
    val = metric(pred, target)  # this value can be backpropagated
    val = metric.compute()  # this value cannot be backpropagated
****************
MetricCollection
****************
In many cases it is beneficial to evaluate the model output with multiple metrics.
In this case the `MetricCollection` class may come in handy. It accepts a sequence
of metrics and wraps these into a single callable metric class, with the same
interface as any other metric.
Example:

.. testcode::

    from pytorch_lightning.metrics import MetricCollection, Accuracy, Precision, Recall

    target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2])
    preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2])
    metric_collection = MetricCollection([
        Accuracy(),
        Precision(num_classes=3, average='macro'),
        Recall(num_classes=3, average='macro')
    ])
    print(metric_collection(preds, target))

.. testoutput::
    :options: +NORMALIZE_WHITESPACE

    {'Accuracy': tensor(0.1250),
     'Precision': tensor(0.0667),
     'Recall': tensor(0.1111)}
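
A ``dict`` can be passed instead of a list when you want several differently
configured instances of the same metric; the dict keys then become the keys of
the output dict. A minimal sketch, mirroring the ``MetricCollection`` docstring
added in this PR:

.. code-block:: python

    from pytorch_lightning.metrics import MetricCollection, Recall

    metric_collection = MetricCollection({
        'micro_recall': Recall(num_classes=3, average='micro'),
        'macro_recall': Recall(num_classes=3, average='macro')
    })
    print(metric_collection(preds, target))
    # {'micro_recall': tensor(0.1250), 'macro_recall': tensor(0.1111)}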
Similarly, it can reduce the amount of code required to log multiple metrics
inside your LightningModule:

.. code-block:: python

    def __init__(self):
        ...
        metrics = pl.metrics.MetricCollection(...)
        self.train_metrics = metrics.clone()
        self.valid_metrics = metrics.clone()

    def training_step(self, batch, batch_idx):
        logits = self(x)
        ...
        self.train_metrics(logits, y)
        # use log_dict instead of log
        self.log_dict(self.train_metrics, on_step=True, on_epoch=False, prefix='train')

    def validation_step(self, batch, batch_idx):
        logits = self(x)
        ...
        self.valid_metrics(logits, y)
        # use log_dict instead of log
        self.log_dict(self.valid_metrics, on_step=True, on_epoch=True, prefix='val')
.. note::

    By default, `MetricCollection` assumes that all the metrics in the collection
    have the same call signature. If this is not the case, input that should be
    given to different metrics can be passed as keyword arguments to the collection.
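
For illustration, here is a minimal sketch of that routing (the two metric
classes are hypothetical, modeled on the test helpers added in this PR): each
keyword argument is forwarded only to the metrics whose ``update`` signature
accepts a parameter of that name.

.. code-block:: python

    import torch
    from pytorch_lightning.metrics import Metric, MetricCollection

    class AddMetric(Metric):  # hypothetical helper, accumulates ``x``
        def __init__(self):
            super().__init__()
            self.add_state("x", torch.tensor(0.0), dist_reduce_fx="sum")

        def update(self, x):
            self.x += x

        def compute(self):
            return self.x

    class SubMetric(Metric):  # hypothetical helper, subtracts ``y``
        def __init__(self):
            super().__init__()
            self.add_state("x", torch.tensor(0.0), dist_reduce_fx="sum")

        def update(self, y):
            self.x -= y

        def compute(self):
            return self.x

    collection = MetricCollection([AddMetric(), SubMetric()])
    # ``x=10`` only matches AddMetric.update, ``y=20`` only SubMetric.update
    collection.update(x=10, y=20)
    assert collection['AddMetric'].x == 10
    assert collection['SubMetric'].x == -20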
.. autoclass:: pytorch_lightning.metrics.MetricCollection
    :noindex:
**********
Metric API
**********

pytorch_lightning/metrics/__init__.py

@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.metrics.metric import Metric, MetricCollection  # noqa: F401
from pytorch_lightning.metrics.classification import (  # noqa: F401
    Accuracy,

pytorch_lightning/metrics/metric.py

@@ -12,10 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import inspect
from abc import ABC, abstractmethod
from collections.abc import Sequence
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from torch import nn
@@ -57,6 +58,7 @@ class Metric(nn.Module, ABC):
        Callback that performs the allgather operation on the metric state. When `None`, DDP
        will be used to perform the allgather. default: None
    """

    def __init__(
        self,
        compute_on_step: bool = True,
@@ -72,6 +74,7 @@ class Metric(nn.Module, ABC):
        self.dist_sync_fn = dist_sync_fn
        self._to_sync = True

        self._update_signature = inspect.signature(self.update)
        self.update = self._wrap_update(self.update)
        self.compute = self._wrap_compute(self.compute)
        self._computed = None
@@ -120,7 +123,7 @@
        """
        if (
            not isinstance(default, torch.Tensor)
            and not isinstance(default, list)  # noqa: W503
            or (isinstance(default, list) and len(default) != 0)  # noqa: W503
        ):
            raise ValueError(
@@ -208,9 +211,11 @@
            return self._computed

        dist_sync_fn = self.dist_sync_fn
        if (
            dist_sync_fn is None
            and torch.distributed.is_available()
            and torch.distributed.is_initialized()
        ):
            # User provided a bool, so we assume DDP if available
            dist_sync_fn = gather_all_tensors
@@ -250,6 +255,10 @@
            else:
                setattr(self, attr, deepcopy(default))

    def clone(self):
        """ Make a copy of the metric """
        return deepcopy(self)

    def __getstate__(self):
        # ignore update and compute functions for pickling
        return {k: v for k, v in self.__dict__.items() if k not in ["update", "compute"]}
@@ -292,3 +301,101 @@
            current_val = getattr(self, key)
            state_dict.update({key: current_val})
        return state_dict


class MetricCollection(nn.ModuleDict):
    """
    MetricCollection class can be used to chain metrics that have the same
    call pattern into one single class.

    Args:
        metrics: One of the following

            * list or tuple: if metrics are passed in as a list, will use the
              metrics class name as key for output dict. Therefore, two metrics
              of the same class cannot be chained this way.

            * dict: if metrics are passed in as a dict, will use each key in the
              dict as key for output dict. Use this format if you want to chain
              together multiple of the same metric with different parameters.

    Example (input as list):

        >>> from pytorch_lightning.metrics import MetricCollection, Accuracy, Precision, Recall
        >>> target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2])
        >>> preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2])
        >>> metrics = MetricCollection([Accuracy(),
        ...                             Precision(num_classes=3, average='macro'),
        ...                             Recall(num_classes=3, average='macro')])
        >>> metrics(preds, target)
        {'Accuracy': tensor(0.1250), 'Precision': tensor(0.0667), 'Recall': tensor(0.1111)}

    Example (input as dict):

        >>> metrics = MetricCollection({'micro_recall': Recall(num_classes=3, average='micro'),
        ...                             'macro_recall': Recall(num_classes=3, average='macro')})
        >>> metrics(preds, target)
        {'micro_recall': tensor(0.1250), 'macro_recall': tensor(0.1111)}
    """

    def __init__(self, metrics: Union[List[Metric], Tuple[Metric], Dict[str, Metric]]):
        super().__init__()
        if isinstance(metrics, dict):
            # Check all values are metrics
            for name, metric in metrics.items():
                if not isinstance(metric, Metric):
                    raise ValueError(f'Value {metric} belonging to key {name}'
                                     ' is not an instance of `pl.metrics.Metric`')
                self[name] = metric
        elif isinstance(metrics, (tuple, list)):
            for metric in metrics:
                if not isinstance(metric, Metric):
                    raise ValueError(f'Input {metric} to `MetricCollection` is not an instance'
                                     ' of `pl.metrics.Metric`')
                name = metric.__class__.__name__
                if name in self:
                    raise ValueError(f'Encountered two metrics both named {name}')
                self[name] = metric
        else:
            raise ValueError('Unknown input to MetricCollection.')

    def _filter_kwargs(self, metric: Metric, **kwargs):
        """ filter kwargs such that they match the update signature of the metric """
        return {k: v for k, v in kwargs.items() if k in metric._update_signature.parameters.keys()}
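
    # Example: for a metric whose signature is ``update(self, preds, target)``,
    # kwargs ``{'preds': p, 'target': t, 'weights': w}`` are filtered down to
    # ``{'preds': p, 'target': t}``; unmatched keys are silently dropped.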

    def forward(self, *args, **kwargs) -> Dict[str, Any]:  # pylint: disable=E0202
        """
        Iteratively call forward for each metric. Positional arguments (args) will
        be passed to every metric in the collection, while keyword arguments (kwargs)
        will be filtered based on the signature of the individual metric.
        """
        return {k: m(*args, **self._filter_kwargs(m, **kwargs)) for k, m in self.items()}

    def update(self, *args, **kwargs):  # pylint: disable=E0202
        """
        Iteratively call update for each metric. Positional arguments (args) will
        be passed to every metric in the collection, while keyword arguments (kwargs)
        will be filtered based on the signature of the individual metric.
        """
        for _, m in self.items():
            m_kwargs = self._filter_kwargs(m, **kwargs)
            m.update(*args, **m_kwargs)

    def compute(self) -> Dict[str, Any]:
        return {k: m.compute() for k, m in self.items()}

    def reset(self):
        """ Iteratively call reset for each metric """
        for _, m in self.items():
            m.reset()

    def clone(self):
        """ Make a copy of the metric collection """
        return deepcopy(self)

    def persistent(self, mode: bool = True):
        """ Method for post-init to change if metric states should be saved to
        its state_dict
        """
        for _, m in self.items():
            m.persistent(mode)

tests/metrics/test_metric.py

@@ -6,8 +6,7 @@ import cloudpickle
import numpy as np
import pytest
import torch

from pytorch_lightning.metrics.metric import Metric, MetricCollection

torch.manual_seed(42)
@@ -17,7 +16,7 @@ class Dummy(Metric):
    def __init__(self):
        super().__init__()
        self.add_state("x", torch.tensor(0.0), dist_reduce_fx=None)

    def update(self):
        pass
@@ -166,7 +165,7 @@ def test_forward():
    assert a.compute() == 13


class DummyMetric1(Dummy):
    def update(self, x):
        self.x += x
@@ -174,9 +173,17 @@ class ToPickle(Dummy):
        return self.x


class DummyMetric2(Dummy):
    def update(self, y):
        self.x -= y

    def compute(self):
        return self.x


def test_pickle(tmpdir):
    # doesn't tests for DDP
    a = DummyMetric1()
    a.update(1)

    metric_pickled = pickle.dumps(a)
@@ -201,3 +208,130 @@ def test_state_dict(tmpdir):
    assert metric.state_dict() == OrderedDict(x=0)
    metric.persistent(False)
    assert metric.state_dict() == OrderedDict()


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires GPU.")
def test_device_and_dtype_transfer(tmpdir):
    metric = DummyMetric1()
    assert metric.x.is_cuda is False
    assert metric.x.dtype == torch.float32

    metric = metric.to(device='cuda')
    assert metric.x.is_cuda

    metric = metric.double()
    assert metric.x.dtype == torch.float64

    metric = metric.half()
    assert metric.x.dtype == torch.float16


def test_metric_collection(tmpdir):
    m1 = DummyMetric1()
    m2 = DummyMetric2()

    metric_collection = MetricCollection([m1, m2])

    # Test correct dict structure
    assert len(metric_collection) == 2
    assert metric_collection['DummyMetric1'] == m1
    assert metric_collection['DummyMetric2'] == m2

    # Test correct initialization
    for name, metric in metric_collection.items():
        assert metric.x == 0, f'Metric {name} not initialized correctly'

    # Test every metric gets updated
    metric_collection.update(5)
    for name, metric in metric_collection.items():
        assert metric.x.abs() == 5, f'Metric {name} not updated correctly'

    # Test compute on each metric
    metric_collection.update(-5)
    metric_vals = metric_collection.compute()
    assert len(metric_vals) == 2
    for name, metric_val in metric_vals.items():
        assert metric_val == 0, f'Metric {name}.compute not called correctly'

    # Test that everything is reset
    for name, metric in metric_collection.items():
        assert metric.x == 0, f'Metric {name} not reset correctly'

    # Test pickable
    metric_pickled = pickle.dumps(metric_collection)
    metric_loaded = pickle.loads(metric_pickled)
    assert isinstance(metric_loaded, MetricCollection)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires GPU.")
def test_device_and_dtype_transfer_metriccollection(tmpdir):
    m1 = DummyMetric1()
    m2 = DummyMetric2()

    metric_collection = MetricCollection([m1, m2])
    for _, metric in metric_collection.items():
        assert metric.x.is_cuda is False
        assert metric.x.dtype == torch.float32

    metric_collection = metric_collection.to(device='cuda')
    for _, metric in metric_collection.items():
        assert metric.x.is_cuda

    metric_collection = metric_collection.double()
    for _, metric in metric_collection.items():
        assert metric.x.dtype == torch.float64

    metric_collection = metric_collection.half()
    for _, metric in metric_collection.items():
        assert metric.x.dtype == torch.float16


def test_metric_collection_wrong_input(tmpdir):
    """ Check that errors are raised on wrong input """
    m1 = DummyMetric1()

    # Not all input are metrics (list)
    with pytest.raises(ValueError):
        _ = MetricCollection([m1, 5])

    # Not all input are metrics (dict)
    with pytest.raises(ValueError):
        _ = MetricCollection({'metric1': m1,
                              'metric2': 5})

    # Same metric passed in multiple times
    with pytest.raises(ValueError, match='Encountered two metrics both named *.'):
        _ = MetricCollection([m1, m1])

    # Not a list or dict passed in
    with pytest.raises(ValueError, match='Unknown input to MetricCollection.'):
        _ = MetricCollection(m1)

def test_metric_collection_args_kwargs(tmpdir):
    """ Check that args and kwargs get passed correctly in metric collection;
    checks both the update and forward method
    """
    m1 = DummyMetric1()
    m2 = DummyMetric2()

    metric_collection = MetricCollection([m1, m2])

    # args get passed to all metrics
    metric_collection.update(5)
    assert metric_collection['DummyMetric1'].x == 5
    assert metric_collection['DummyMetric2'].x == -5
    metric_collection.reset()
    _ = metric_collection(5)
    assert metric_collection['DummyMetric1'].x == 5
    assert metric_collection['DummyMetric2'].x == -5
    metric_collection.reset()

    # kwargs only get passed to the metrics they match
    metric_collection.update(x=10, y=20)
    assert metric_collection['DummyMetric1'].x == 10
    assert metric_collection['DummyMetric2'].x == -20
    metric_collection.reset()
    _ = metric_collection(x=10, y=20)
    assert metric_collection['DummyMetric1'].x == 10
    assert metric_collection['DummyMetric2'].x == -20

tests/metrics/test_metric_lightning.py

@@ -1,7 +1,7 @@
import torch

from pytorch_lightning import Trainer
from pytorch_lightning.metrics import Metric, MetricCollection
from tests.base.boring_model import BoringModel

@@ -17,6 +17,18 @@
        return self.x

class DiffMetric(Metric):
    def __init__(self):
        super().__init__()
        self.add_state("x", torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, x):
        self.x -= x

    def compute(self):
        return self.x


def test_metric_lightning(tmpdir):
    class TestModel(BoringModel):
        def __init__(self):

@@ -125,3 +137,41 @@ def test_scriptable(tmpdir):
    output = model(rand_input)
    script_output = script_model(rand_input)
    assert torch.allclose(output, script_output)


def test_metric_collection_lightning_log(tmpdir):

    class TestModel(BoringModel):
        def __init__(self):
            super().__init__()
            self.metric = MetricCollection([SumMetric(), DiffMetric()])
            self.sum = 0.0
            self.diff = 0.0

        def training_step(self, batch, batch_idx):
            x = batch
            metric_vals = self.metric(x.sum())
            self.sum += x.sum()
            self.diff -= x.sum()
            self.log_dict({f'{k}_step': v for k, v in metric_vals.items()})
            return self.step(x)

        def training_epoch_end(self, outputs):
            metric_vals = self.metric.compute()
            self.log_dict({f'{k}_epoch': v for k, v in metric_vals.items()})

    model = TestModel()
    model.val_dataloader = None

    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=2,
        limit_val_batches=2,
        max_epochs=1,
        log_every_n_steps=1,
        weights_summary=None,
    )
    trainer.fit(model)

    logged = trainer.logged_metrics
    assert torch.allclose(torch.tensor(logged["SumMetric_epoch"]), model.sum)
    assert torch.allclose(torch.tensor(logged["DiffMetric_epoch"]), model.diff)