[Metrics] MetricCollection (#4318)
* docs + precision + recall + f_beta + refactor
* rebase
* fixes
* added missing file
* docs
* docs
* extra import
* add metric collection
* add docs + integration with log_dict
* add test
* update
* update
* more test
* more test
* pep8
* fix doctest
* pep8
* add clone method
* add clone method
* merge-2
* changelog
* kwargs filtering and tests
* pep8
* fix test
* update docs
* Update docs/source/metrics.rst
* fix docs
* fix tests
* Apply suggestions from code review
* fix docs
* fix doctest
* fix doctest
* fix doctest
* fix doctest

Co-authored-by: ananyahjha93 <ananya@pytorchlightning.ai>
Co-authored-by: Teddy Koker <teddy.koker@gmail.com>
Co-authored-by: Nicki Skafte <nugginea@gmail.com>
Co-authored-by: Roger Shieh <sh.rog@protonmail.ch>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
parent 06f36092a4
commit 06668c0ddf
@@ -33,6 +33,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added `image_gradients` functional metric to compute the image gradients of a given input image. ([#5056](https://github.com/PyTorchLightning/pytorch-lightning/pull/5056))
 
+- Added `MetricCollection` ([#4318](https://github.com/PyTorchLightning/pytorch-lightning/pull/4318))
+
+- Added `.clone()` method to metrics ([#4318](https://github.com/PyTorchLightning/pytorch-lightning/pull/4318))
+
 ### Changed
 
 - Changed `automatic casting` for LoggerConnector `metrics` ([#5218](https://github.com/PyTorchLightning/pytorch-lightning/pull/5218))

@@ -81,6 +81,7 @@ If ``on_epoch`` is True, the logger automatically logs the end of epoch metric v
         self.log('valid_acc', self.valid_acc, on_step=True, on_epoch=True)
 
 .. note::
+
     If using metrics in data parallel mode (dp), the metric update/logging should be done
     in the ``<mode>_step_end`` method (where ``<mode>`` is either ``training``, ``validation``
     or ``test``). This is due to metric states else being destroyed after each forward pass,

@@ -99,7 +100,6 @@ If ``on_epoch`` is True, the logger automatically logs the end of epoch metric v
         self.metric(outputs['preds'], outputs['target'])
         self.log('metric', self.metric)
 
-
 This metrics API is independent of PyTorch Lightning. Metrics can directly be used in PyTorch as shown in the example:
 
 .. code-block:: python

@@ -131,7 +131,17 @@ This metrics API is independent of PyTorch Lightning. Metrics can directly be us
 Metrics contain internal states that keep track of the data seen so far.
 Do not mix metric states across training, validation and testing.
 It is highly recommended to re-initialize the metric per mode as
-shown in the examples above.
+shown in the examples above. To easily initialize the same metric multiple
+times, the ``.clone()`` method can be used:
+
+.. testcode::
+
+    def __init__(self):
+        ...
+        metric = pl.metrics.Accuracy()
+        self.train_acc = metric.clone()
+        self.val_acc = metric.clone()
+        self.test_acc = metric.clone()
 
 .. note::

@@ -240,6 +250,69 @@ In practise this means that:
     val = metric(pred, target) # this value can be backpropagated
     val = metric.compute() # this value cannot be backpropagated
 
+****************
+MetricCollection
+****************
+
+In many cases it is beneficial to evaluate the model output by multiple metrics.
+In this case the `MetricCollection` class may come in handy. It accepts a sequence
+of metrics and wraps these into a single callable metric class, with the same
+interface as any other metric.
+
+Example:
+
+.. testcode::
+
+    from pytorch_lightning.metrics import MetricCollection, Accuracy, Precision, Recall
+
+    target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2])
+    preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2])
+    metric_collection = MetricCollection([
+        Accuracy(),
+        Precision(num_classes=3, average='macro'),
+        Recall(num_classes=3, average='macro')
+    ])
+    print(metric_collection(preds, target))
+
+.. testoutput::
+    :options: +NORMALIZE_WHITESPACE
+
+    {'Accuracy': tensor(0.1250),
+     'Precision': tensor(0.0667),
+     'Recall': tensor(0.1111)}
+
+Similarly it can also reduce the amount of code required to log multiple metrics
+inside your LightningModule:
+
+.. code-block:: python
+
+    def __init__(self):
+        ...
+        metrics = pl.metrics.MetricCollection(...)
+        self.train_metrics = metrics.clone()
+        self.valid_metrics = metrics.clone()
+
+    def training_step(self, batch, batch_idx):
+        logits = self(x)
+        ...
+        self.train_metrics(logits, y)
+        # use log_dict instead of log
+        self.log_dict(self.train_metrics, on_step=True, on_epoch=False, prefix='train')
+
+    def validation_step(self, batch, batch_idx):
+        logits = self(x)
+        ...
+        self.valid_metrics(logits, y)
+        # use log_dict instead of log
+        self.log_dict(self.valid_metrics, on_step=True, on_epoch=True, prefix='val')
+
+.. note::
+
+    `MetricCollection` by default assumes that all the metrics in the collection
+    have the same call signature. If this is not the case, input that should be
+    given to different metrics can be given as keyword arguments to the collection.
+
+.. autoclass:: pytorch_lightning.metrics.MetricCollection
+    :noindex:
+
 **********
 Metric API

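To make the note above concrete, here is a minimal sketch of the keyword-argument routing, modeled on the ``DummyMetric1``/``DummyMetric2`` pattern used in the tests further down; the ``SumX`` and ``SubY`` classes are hypothetical stand-ins, not part of this commit:

.. code-block:: python

    import torch
    from pytorch_lightning.metrics import Metric, MetricCollection

    class SumX(Metric):
        # Hypothetical metric whose update() takes an argument named `x`.
        def __init__(self):
            super().__init__()
            self.add_state("total", torch.tensor(0.0), dist_reduce_fx="sum")

        def update(self, x):
            self.total += x

        def compute(self):
            return self.total

    class SubY(Metric):
        # Hypothetical metric whose update() takes an argument named `y`.
        def __init__(self):
            super().__init__()
            self.add_state("total", torch.tensor(0.0), dist_reduce_fx="sum")

        def update(self, y):
            self.total -= y

        def compute(self):
            return self.total

    collection = MetricCollection([SumX(), SubY()])
    # `x` reaches only SumX.update and `y` only SubY.update, because the
    # collection filters kwargs against each metric's update signature.
    collection.update(x=torch.tensor(10.0), y=torch.tensor(20.0))
    print(collection.compute())  # {'SumX': tensor(10.), 'SubY': tensor(-20.)}
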
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from pytorch_lightning.metrics.metric import Metric # noqa: F401
+from pytorch_lightning.metrics.metric import Metric, MetricCollection # noqa: F401
 
 from pytorch_lightning.metrics.classification import ( # noqa: F401
     Accuracy,

@@ -12,10 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import functools
+import inspect
 from abc import ABC, abstractmethod
 from collections.abc import Sequence
 from copy import deepcopy
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 from torch import nn

@@ -57,6 +58,7 @@ class Metric(nn.Module, ABC):
             Callback that performs the allgather operation on the metric state. When `None`, DDP
             will be used to perform the allgather. default: None
     """
 
     def __init__(
         self,
         compute_on_step: bool = True,

@@ -72,6 +74,7 @@ class Metric(nn.Module, ABC):
         self.dist_sync_fn = dist_sync_fn
         self._to_sync = True
 
+        self._update_signature = inspect.signature(self.update)
         self.update = self._wrap_update(self.update)
         self.compute = self._wrap_compute(self.compute)
         self._computed = None

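Capturing ``inspect.signature(self.update)`` before ``update`` is wrapped is what later lets ``MetricCollection._filter_kwargs`` decide which keyword arguments a given metric can accept. A minimal sketch of that mechanism (the ``filter_kwargs_for`` helper is illustrative only, not part of the commit):

.. code-block:: python

    import inspect

    def filter_kwargs_for(update_fn, **kwargs):
        # Keep only the keyword arguments that name a parameter of update_fn.
        params = inspect.signature(update_fn).parameters
        return {k: v for k, v in kwargs.items() if k in params}

    def update(preds, target):
        pass

    print(filter_kwargs_for(update, preds=1, target=2, extra=3))
    # {'preds': 1, 'target': 2}
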
@@ -120,7 +123,7 @@ class Metric(nn.Module, ABC):
         """
         if (
             not isinstance(default, torch.Tensor)
             and not isinstance(default, list) # noqa: W503
             or (isinstance(default, list) and len(default) != 0) # noqa: W503
         ):
             raise ValueError(

@@ -208,9 +211,11 @@ class Metric(nn.Module, ABC):
             return self._computed
 
         dist_sync_fn = self.dist_sync_fn
-        if (dist_sync_fn is None
-                and torch.distributed.is_available()
-                and torch.distributed.is_initialized()):
+        if (
+            dist_sync_fn is None
+            and torch.distributed.is_available()
+            and torch.distributed.is_initialized()
+        ):
             # User provided a bool, so we assume DDP if available
             dist_sync_fn = gather_all_tensors

@@ -250,6 +255,10 @@ class Metric(nn.Module, ABC):
         else:
             setattr(self, attr, deepcopy(default))
 
+    def clone(self):
+        """ Make a copy of the metric """
+        return deepcopy(self)
+
     def __getstate__(self):
         # ignore update and compute functions for pickling
         return {k: v for k, v in self.__dict__.items() if k not in ["update", "compute"]}

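``clone`` is a plain ``deepcopy``, so the copy's states never alias the original's. A small sketch with a hypothetical ``Counter`` metric (modeled on the ``Dummy`` metric in the tests):

.. code-block:: python

    import torch
    from pytorch_lightning.metrics import Metric

    class Counter(Metric):
        # Hypothetical minimal metric, used only to show that clone()
        # yields an independent copy of the registered state.
        def __init__(self):
            super().__init__()
            self.add_state("n", torch.tensor(0.0), dist_reduce_fx="sum")

        def update(self, x):
            self.n += x

        def compute(self):
            return self.n

    a = Counter()
    b = a.clone()
    b.update(torch.tensor(3.0))
    assert a.n == 0 and b.n == 3  # updating the clone leaves the original untouched
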
@@ -292,3 +301,101 @@ class Metric(nn.Module, ABC):
             current_val = getattr(self, key)
             state_dict.update({key: current_val})
         return state_dict
+
+
+class MetricCollection(nn.ModuleDict):
+    """
+    MetricCollection class can be used to chain metrics that have the same
+    call pattern into one single class.
+
+    Args:
+        metrics: One of the following
+
+            * list or tuple: if metrics are passed in as a list, will use the
+              metrics class name as key for output dict. Therefore, two metrics
+              of the same class cannot be chained this way.
+
+            * dict: if metrics are passed in as a dict, will use each key in the
+              dict as key for output dict. Use this format if you want to chain
+              together multiple of the same metric with different parameters.
+
+    Example (input as list):
+
+        >>> from pytorch_lightning.metrics import MetricCollection, Accuracy, Precision, Recall
+        >>> target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2])
+        >>> preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2])
+        >>> metrics = MetricCollection([Accuracy(),
+        ...                             Precision(num_classes=3, average='macro'),
+        ...                             Recall(num_classes=3, average='macro')])
+        >>> metrics(preds, target)
+        {'Accuracy': tensor(0.1250), 'Precision': tensor(0.0667), 'Recall': tensor(0.1111)}
+
+    Example (input as dict):
+
+        >>> metrics = MetricCollection({'micro_recall': Recall(num_classes=3, average='micro'),
+        ...                             'macro_recall': Recall(num_classes=3, average='macro')})
+        >>> metrics(preds, target)
+        {'micro_recall': tensor(0.1250), 'macro_recall': tensor(0.1111)}
+
+    """
+    def __init__(self, metrics: Union[List[Metric], Tuple[Metric], Dict[str, Metric]]):
+        super().__init__()
+        if isinstance(metrics, dict):
+            # Check all values are metrics
+            for name, metric in metrics.items():
+                if not isinstance(metric, Metric):
+                    raise ValueError(f'Value {metric} belonging to key {name}'
+                                     ' is not an instance of `pl.metrics.Metric`')
+                self[name] = metric
+        elif isinstance(metrics, (tuple, list)):
+            for metric in metrics:
+                if not isinstance(metric, Metric):
+                    raise ValueError(f'Input {metric} to `MetricCollection` is not an instance'
+                                     ' of `pl.metrics.Metric`')
+                name = metric.__class__.__name__
+                if name in self:
+                    raise ValueError(f'Encountered two metrics both named {name}')
+                self[name] = metric
+        else:
+            raise ValueError('Unknown input to MetricCollection.')
+
+    def _filter_kwargs(self, metric: Metric, **kwargs):
+        """ filter kwargs such that they match the update signature of the metric """
+        return {k: v for k, v in kwargs.items() if k in metric._update_signature.parameters.keys()}
+
+    def forward(self, *args, **kwargs) -> Dict[str, Any]:  # pylint: disable=E0202
+        """
+        Iteratively call forward for each metric. Positional arguments (args) will
+        be passed to every metric in the collection, while keyword arguments (kwargs)
+        will be filtered based on the signature of the individual metric.
+        """
+        return {k: m(*args, **self._filter_kwargs(m, **kwargs)) for k, m in self.items()}
+
+    def update(self, *args, **kwargs):  # pylint: disable=E0202
+        """
+        Iteratively call update for each metric. Positional arguments (args) will
+        be passed to every metric in the collection, while keyword arguments (kwargs)
+        will be filtered based on the signature of the individual metric.
+        """
+        for _, m in self.items():
+            m_kwargs = self._filter_kwargs(m, **kwargs)
+            m.update(*args, **m_kwargs)
+
+    def compute(self) -> Dict[str, Any]:
+        return {k: m.compute() for k, m in self.items()}
+
+    def reset(self):
+        """ Iteratively call reset for each metric """
+        for _, m in self.items():
+            m.reset()
+
+    def clone(self):
+        """ Make a copy of the metric collection """
+        return deepcopy(self)
+
+    def persistent(self, mode: bool = True):
+        """ Method for post-init to change whether metric states should be saved to
+        its state_dict
+        """
+        for _, m in self.items():
+            m.persistent(mode)

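Because ``MetricCollection`` subclasses ``nn.ModuleDict`` and every metric state is registered on the module, device and dtype methods recurse into all contained metrics, which is what the GPU tests below exercise. A brief sketch:

.. code-block:: python

    import torch
    from pytorch_lightning.metrics import Accuracy, Recall, MetricCollection

    metrics = MetricCollection([Accuracy(), Recall(num_classes=3, average='macro')])

    # nn.ModuleDict semantics: these calls move/cast the registered state
    # tensors of every metric in the collection at once.
    metrics = metrics.double()
    if torch.cuda.is_available():
        metrics = metrics.to('cuda')
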
@@ -6,8 +6,7 @@ import cloudpickle
 import numpy as np
 import pytest
 import torch
 
-from pytorch_lightning.metrics.metric import Metric
-
+from pytorch_lightning.metrics.metric import Metric, MetricCollection
 
 torch.manual_seed(42)

@@ -17,7 +16,7 @@ class Dummy(Metric):
 
     def __init__(self):
         super().__init__()
-        self.add_state("x", torch.tensor(0), dist_reduce_fx=None)
+        self.add_state("x", torch.tensor(0.0), dist_reduce_fx=None)
 
     def update(self):
         pass

@@ -166,7 +165,7 @@ def test_forward():
     assert a.compute() == 13
 
 
-class ToPickle(Dummy):
+class DummyMetric1(Dummy):
     def update(self, x):
         self.x += x

@@ -174,9 +173,17 @@ class ToPickle(Dummy):
         return self.x
 
 
+class DummyMetric2(Dummy):
+    def update(self, y):
+        self.x -= y
+
+    def compute(self):
+        return self.x
+
+
 def test_pickle(tmpdir):
     # doesn't tests for DDP
-    a = ToPickle()
+    a = DummyMetric1()
     a.update(1)
 
     metric_pickled = pickle.dumps(a)

@@ -201,3 +208,130 @@ def test_state_dict(tmpdir):
     assert metric.state_dict() == OrderedDict(x=0)
     metric.persistent(False)
     assert metric.state_dict() == OrderedDict()
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires GPU.")
+def test_device_and_dtype_transfer(tmpdir):
+    metric = DummyMetric1()
+    assert metric.x.is_cuda is False
+    assert metric.x.dtype == torch.float32
+
+    metric = metric.to(device='cuda')
+    assert metric.x.is_cuda
+
+    metric = metric.double()
+    assert metric.x.dtype == torch.float64
+
+    metric = metric.half()
+    assert metric.x.dtype == torch.float16
+
+
+def test_metric_collection(tmpdir):
+    m1 = DummyMetric1()
+    m2 = DummyMetric2()
+
+    metric_collection = MetricCollection([m1, m2])
+
+    # Test correct dict structure
+    assert len(metric_collection) == 2
+    assert metric_collection['DummyMetric1'] == m1
+    assert metric_collection['DummyMetric2'] == m2
+
+    # Test correct initialization
+    for name, metric in metric_collection.items():
+        assert metric.x == 0, f'Metric {name} not initialized correctly'
+
+    # Test every metric gets updated
+    metric_collection.update(5)
+    for name, metric in metric_collection.items():
+        assert metric.x.abs() == 5, f'Metric {name} not updated correctly'
+
+    # Test compute on each metric
+    metric_collection.update(-5)
+    metric_vals = metric_collection.compute()
+    assert len(metric_vals) == 2
+    for name, metric_val in metric_vals.items():
+        assert metric_val == 0, f'Metric {name}.compute not called correctly'
+
+    # Test that everything is reset
+    for name, metric in metric_collection.items():
+        assert metric.x == 0, f'Metric {name} not reset correctly'
+
+    # Test picklable
+    metric_pickled = pickle.dumps(metric_collection)
+    metric_loaded = pickle.loads(metric_pickled)
+    assert isinstance(metric_loaded, MetricCollection)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires GPU.")
+def test_device_and_dtype_transfer_metriccollection(tmpdir):
+    m1 = DummyMetric1()
+    m2 = DummyMetric2()
+
+    metric_collection = MetricCollection([m1, m2])
+    for _, metric in metric_collection.items():
+        assert metric.x.is_cuda is False
+        assert metric.x.dtype == torch.float32
+
+    metric_collection = metric_collection.to(device='cuda')
+    for _, metric in metric_collection.items():
+        assert metric.x.is_cuda
+
+    metric_collection = metric_collection.double()
+    for _, metric in metric_collection.items():
+        assert metric.x.dtype == torch.float64
+
+    metric_collection = metric_collection.half()
+    for _, metric in metric_collection.items():
+        assert metric.x.dtype == torch.float16
+
+
+def test_metric_collection_wrong_input(tmpdir):
+    """ Check that errors are raised on wrong input """
+    m1 = DummyMetric1()
+
+    # Not all input are metrics (list)
+    with pytest.raises(ValueError):
+        _ = MetricCollection([m1, 5])
+
+    # Not all input are metrics (dict)
+    with pytest.raises(ValueError):
+        _ = MetricCollection({'metric1': m1,
+                              'metric2': 5})
+
+    # Same metric passed in multiple times
+    with pytest.raises(ValueError, match='Encountered two metrics both named *.'):
+        _ = MetricCollection([m1, m1])
+
+    # Not a list or dict passed in
+    with pytest.raises(ValueError, match='Unknown input to MetricCollection.'):
+        _ = MetricCollection(m1)
+
+
+def test_metric_collection_args_kwargs(tmpdir):
+    """ Check that args and kwargs get passed correctly in metric collection,
+    checks both update and forward method
+    """
+    m1 = DummyMetric1()
+    m2 = DummyMetric2()
+
+    metric_collection = MetricCollection([m1, m2])
+
+    # args gets passed to all metrics
+    metric_collection.update(5)
+    assert metric_collection['DummyMetric1'].x == 5
+    assert metric_collection['DummyMetric2'].x == -5
+    metric_collection.reset()
+    _ = metric_collection(5)
+    assert metric_collection['DummyMetric1'].x == 5
+    assert metric_collection['DummyMetric2'].x == -5
+    metric_collection.reset()
+
+    # kwargs gets only passed to metrics that it matches
+    metric_collection.update(x=10, y=20)
+    assert metric_collection['DummyMetric1'].x == 10
+    assert metric_collection['DummyMetric2'].x == -20
+    metric_collection.reset()
+    _ = metric_collection(x=10, y=20)
+    assert metric_collection['DummyMetric1'].x == 10
+    assert metric_collection['DummyMetric2'].x == -20

@@ -1,7 +1,7 @@
 import torch
 
 from pytorch_lightning import Trainer
-from pytorch_lightning.metrics import Metric
+from pytorch_lightning.metrics import Metric, MetricCollection
 from tests.base.boring_model import BoringModel

@@ -17,6 +17,18 @@ class SumMetric(Metric):
         return self.x
 
 
+class DiffMetric(Metric):
+    def __init__(self):
+        super().__init__()
+        self.add_state("x", torch.tensor(0.0), dist_reduce_fx="sum")
+
+    def update(self, x):
+        self.x -= x
+
+    def compute(self):
+        return self.x
+
+
 def test_metric_lightning(tmpdir):
     class TestModel(BoringModel):
         def __init__(self):

@@ -125,3 +137,41 @@ def test_scriptable(tmpdir):
     output = model(rand_input)
     script_output = script_model(rand_input)
     assert torch.allclose(output, script_output)
+
+
+def test_metric_collection_lightning_log(tmpdir):
+    class TestModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.metric = MetricCollection([SumMetric(), DiffMetric()])
+            self.sum = 0.0
+            self.diff = 0.0
+
+        def training_step(self, batch, batch_idx):
+            x = batch
+            metric_vals = self.metric(x.sum())
+            self.sum += x.sum()
+            self.diff -= x.sum()
+            self.log_dict({f'{k}_step': v for k, v in metric_vals.items()})
+            return self.step(x)
+
+        def training_epoch_end(self, outputs):
+            metric_vals = self.metric.compute()
+            self.log_dict({f'{k}_epoch': v for k, v in metric_vals.items()})
+
+    model = TestModel()
+    model.val_dataloader = None
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        max_epochs=1,
+        log_every_n_steps=1,
+        weights_summary=None,
+    )
+    trainer.fit(model)
+
+    logged = trainer.logged_metrics
+    assert torch.allclose(torch.tensor(logged["SumMetric_epoch"]), model.sum)
+    assert torch.allclose(torch.tensor(logged["DiffMetric_epoch"]), model.diff)