integrate metrics API with self.log (#3961)
* metrics integration into self.log
* ddp and regular test for self.log + metrics
* pep8
* fix log tests
* docs

Co-authored-by: Teddy Koker <teddy.koker@gmail.com>
parent aa95addff2
commit 6f1a2ce517
@@ -25,12 +25,6 @@ logic present in ``.compute()`` is applied to state information from all process

 The example below shows how to use a metric in your ``LightningModule``:

-.. note::
-
-    For v0.10.0 the user is expected to call ``.compute()`` on the metric at the end each epoch.
-    This has been shown in the example below. For v1.0 release, we will integrate metrics
-    with logging and ``.compute()`` will be called automatically by PyTorch Lightning.
-
 .. code-block:: python

     def __init__(self):
@@ -49,6 +43,41 @@ The example below shows how to use a metric in your ``LightningModule``:
         self.log('train_acc_epoch', self.accuracy.compute())

+
+``Metric`` objects can also be directly logged, in which case Lightning will log
+the metric based on ``on_step`` and ``on_epoch`` flags present in ``self.log(...)``.
+If ``on_epoch`` is True, the logger automatically logs the end of epoch metric value by calling
+``.compute()``.
+
+.. note::
+    ``sync_dist``, ``sync_dist_op``, ``sync_dist_group``, ``reduce_fx`` and ``tbptt_reduce_fx``
+    flags from ``self.log(...)`` don't affect the metric logging in any manner. The metric class
+    contains its own distributed synchronization logic.
+
+    This however is only true for metrics that inherit the base class ``Metric``,
+    and thus the functional metric API provides no support for in-built distributed synchronization
+    or reduction functions.
+
+
+.. code-block:: python
+
+    def __init__(self):
+        ...
+        self.train_acc = pl.metrics.Accuracy()
+        self.valid_acc = pl.metrics.Accuracy()
+
+    def training_step(self, batch, batch_idx):
+        logits = self(x)
+        ...
+        self.train_acc(logits, y)
+        self.log('train_acc', self.train_acc, on_step=True, on_epoch=False)
+
+    def validation_step(self, batch, batch_idx):
+        logits = self(x)
+        ...
+        self.valid_acc(logits, y)
+        self.log('valid_acc', self.valid_acc, on_step=True, on_epoch=True)
+
+
 This metrics API is independent of PyTorch Lightning. Metrics can directly be used in PyTorch as shown in the example:

 .. code-block:: python
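(Editorial note, not part of the diff: the plain-PyTorch example referenced just above is cut off by the hunk's context window. A minimal sketch of that kind of usage, assuming the class-based ``pl.metrics.Accuracy`` of this release; the stand-in ``model`` and ``dataloader`` below are illustrative only.)

.. code-block:: python

    import torch
    import pytorch_lightning as pl

    # stand-ins for a real model and dataloader (illustrative only)
    model = torch.nn.Linear(16, 4)
    dataloader = [(torch.randn(8, 16), torch.randint(0, 4, (8,))) for _ in range(5)]

    accuracy = pl.metrics.Accuracy()

    for x, y in dataloader:
        preds = model(x)
        # calling the metric updates its internal state and returns the batch-level value
        batch_acc = accuracy(preds, y)

    # value accumulated over every batch seen since the metric was last reset
    epoch_acc = accuracy.compute()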
@@ -21,7 +21,7 @@ from torch import Tensor
 import os

 from pytorch_lightning.utilities.distributed import sync_ddp_if_available
-
+from pytorch_lightning.metrics import Metric

 class Result(Dict):
     def __init__(
@@ -251,12 +255,16 @@ class Result(Dict):
                 continue

             if options['logger'] and options['on_step']:
-                result[k] = self[k]
+                if isinstance(self[k], Metric):
+                    result[k] = self[k]._forward_cache
+                else:
+                    result[k] = self[k]

         return result

     def get_epoch_log_metrics(self) -> dict:
         """
         Gets the metrics to log at the end of the batch step
+        Gets the metrics to log at the end of epoch
         """
         result = {}
@@ -264,13 +268,22 @@ class Result(Dict):
         for k, options in meta.items():
             if k == '_internal':
                 continue

             if options['logger'] and options['on_epoch']:
-                result[k] = self[k]
+                if isinstance(self[k], Metric):
+                    result[k] = self[k].compute()
+                else:
+                    result[k] = self[k]
+
+            if k in self and not options['on_epoch'] and isinstance(self[k], Metric):
+                # compute metric on epoch anyway so state does not accumulate
+                self[k].compute()

         return result

     def get_epoch_pbar_metrics(self):
         """
         Gets the metrics to log at the end of the batch step
+        Gets the metrics to log at the end of epoch
         """
         result = {}
@@ -278,8 +291,17 @@ class Result(Dict):
         for k, options in meta.items():
             if k == '_internal':
                 continue

             if options['prog_bar'] and options['on_epoch']:
-                result[k] = self[k]
+                if isinstance(self[k], Metric):
+                    result[k] = self[k].compute()
+                else:
+                    result[k] = self[k]
+
+            if k in self and not options['on_epoch'] and isinstance(self[k], Metric):
+                # compute metric on epoch anyway so state does not accumulate
+                self[k].compute()

         return result

     def get_batch_pbar_metrics(self, include_forked_originals=True):
@@ -292,11 +314,16 @@ class Result(Dict):
         for k, options in meta.items():
             if k == '_internal':
                 continue

             if options['forked'] and not include_forked_originals:
                 continue

             if options['prog_bar'] and options['on_step']:
-                result[k] = self[k]
+                if isinstance(self[k], Metric):
+                    result[k] = self[k]._forward_cache
+                else:
+                    result[k] = self[k]

         return result

     def detach(self):
@@ -405,7 +432,7 @@ class Result(Dict):
         recursive_stack(result)

         for k, option in meta.items():
-            if k == '_internal':
+            if k == '_internal' or isinstance(result[k], Metric):
                 continue

             if option['on_epoch']:
@@ -439,7 +466,7 @@ class Result(Dict):
         recursive_stack(result)

         for k, value in result.items():
-            if k in ['meta', 'extra']:
+            if k in ['meta', 'extra'] or isinstance(value, Metric):
                 continue

             # pick the reduce fx
@@ -459,10 +486,12 @@ class Result(Dict):

     def dp_reduce(self):
         for k, value in self.items():
-            if k == 'meta':
+            if k == 'meta' or isinstance(value, Metric):
                 continue

             if isinstance(value, list):
                 value = torch.tensor(value)

             self[k] = value.mean(dim=-1)

     @property
@@ -502,10 +531,14 @@ def recursive_gather(outputs: Sequence[dict], result: Optional[MutableMapping] =
                 v = recursive_gather([v], in_d)
                 result[k] = v
             else:
-                if k not in result:
-                    result[k] = []
-
-                result[k].append(v)
+                if isinstance(v, Metric):
+                    # if v is a metric, just keep one of them,
+                    # don't keep on adding a list of them
+                    result[k] = v
+                else:
+                    if k not in result:
+                        result[k] = []
+                    result[k].append(v)

     return result
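(Editorial note, not part of the diff: taken together, the ``Result`` hunks above implement a single rule — when a logged value is a ``Metric`` object, step-level logging reads the value cached by ``forward()`` while epoch-level logging calls ``.compute()``; plain tensors pass through unchanged. A condensed sketch of that rule, using a hypothetical helper name:)

.. code-block:: python

    from typing import Any

    from pytorch_lightning.metrics import Metric


    def resolve_logged_value(value: Any, on_epoch: bool) -> Any:
        """Hypothetical helper condensing the branching added above."""
        if isinstance(value, Metric):
            # epoch level: aggregate the state accumulated over all batches (and processes)
            # step level: reuse the batch value cached by Metric.forward()
            return value.compute() if on_epoch else value._forward_cache
        # plain tensors / numbers are logged as-is
        return value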
@@ -59,6 +59,7 @@ class Metric(nn.Module, ABC):
         self.update = self._wrap_update(self.update)
         self.compute = self._wrap_compute(self.compute)
         self._computed = None
+        self._forward_cache = None

         # initialize state
         self._reductions = {}
@@ -125,6 +126,7 @@ class Metric(nn.Module, ABC):
         """
         # add current step
         self.update(*args, **kwargs)
+        self._forward_cache = None

         if self.compute_on_step:
             self._to_sync = self.ddp_sync_on_step
@@ -135,7 +137,7 @@ class Metric(nn.Module, ABC):
             # call reset, update, compute, on single batch
             self.reset()
             self.update(*args, **kwargs)
-            result = self.compute()
+            self._forward_cache = self.compute()

             # restore context
             for attr, val in self._cache.items():
@@ -143,7 +145,7 @@ class Metric(nn.Module, ABC):
             self._to_sync = True
             self._computed = None

-            return result
+            return self._forward_cache

     def _sync_dist(self):
         input_dict = {attr: getattr(self, attr) for attr in self._reductions.keys()}
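(Editorial note, not part of the diff: the net effect of the ``Metric`` changes above is that ``forward()`` still returns the batch-level value, but now also stashes it in ``_forward_cache`` so step-level logging can reuse it without recomputing. A small illustration with a toy sum metric — assumed names; the new ``test_forward`` at the end of this diff exercises the same behaviour:)

.. code-block:: python

    import torch

    from pytorch_lightning.metrics import Metric


    class SumMetric(Metric):
        def __init__(self):
            super().__init__()
            self.add_state("total", torch.tensor(0), dist_reduce_fx="sum")

        def update(self, x):
            self.total += x

        def compute(self):
            return self.total


    m = SumMetric()
    m(2)                         # forward(): updates accumulated state, returns the batch value
    assert m._forward_cache == 2
    m(3)
    assert m._forward_cache == 3
    assert m.compute() == 5      # aggregated over both calls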
@@ -0,0 +1,131 @@
+import pytest
+import sys
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+from pytorch_lightning.core.step_result import Result
+from pytorch_lightning.metrics import Metric
+import tests.base.develop_utils as tutils
+
+
+class DummyMetric(Metric):
+    def __init__(self):
+        super().__init__()
+        self.add_state("x", torch.tensor(0), dist_reduce_fx="sum")
+
+    def update(self, x):
+        self.x += x
+
+    def compute(self):
+        return self.x
+
+
+def _setup_ddp(rank, worldsize):
+    import os
+
+    os.environ["MASTER_ADDR"] = "localhost"
+
+    # initialize the process group
+    dist.init_process_group("gloo", rank=rank, world_size=worldsize)
+
+
+def _ddp_test_fn(rank, worldsize):
+    _setup_ddp(rank, worldsize)
+    tensor = torch.tensor([1.0])
+
+    metric_a = DummyMetric()
+    metric_b = DummyMetric()
+    metric_c = DummyMetric()
+
+    # ddp_sync_on_step is False by default
+    result = Result()
+
+    for epoch in range(3):
+        cumulative_sum = 0
+
+        for i in range(5):
+            metric_a(i)
+            metric_b(i)
+            metric_c(i)
+
+            cumulative_sum += i
+
+            result.log('a', metric_a, on_step=True, on_epoch=True)
+            result.log('b', metric_b, on_step=False, on_epoch=True)
+            result.log('c', metric_c, on_step=True, on_epoch=False)
+
+            batch_log = result.get_batch_log_metrics()
+            batch_expected = {"a_step": i, "a": i, "c": i}
+            assert set(batch_log.keys()) == set(batch_expected.keys())
+            for k in batch_expected.keys():
+                assert batch_expected[k] == batch_log[k]
+
+        epoch_log = result.get_epoch_log_metrics()
+
+        # assert metric state reset to default values
+        assert metric_a.x == metric_a._defaults['x']
+        assert metric_b.x == metric_b._defaults['x']
+        assert metric_c.x == metric_c._defaults['x']
+
+        epoch_expected = {
+            "b": cumulative_sum * worldsize,
+            "a": cumulative_sum * worldsize,
+            "a_epoch": cumulative_sum * worldsize
+        }
+
+        assert set(epoch_log.keys()) == set(epoch_expected.keys())
+        for k in epoch_expected.keys():
+            assert epoch_expected[k] == epoch_log[k]
+
+
+@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
+def test_result_reduce_ddp():
+    """Make sure result logging works with DDP"""
+    tutils.reset_seed()
+    tutils.set_random_master_port()
+
+    worldsize = 2
+    mp.spawn(_ddp_test_fn, args=(worldsize,), nprocs=worldsize)
+
+
+def test_result_metric_integration():
+    metric_a = DummyMetric()
+    metric_b = DummyMetric()
+    metric_c = DummyMetric()
+
+    result = Result()
+
+    for epoch in range(3):
+        cumulative_sum = 0
+
+        for i in range(5):
+            metric_a(i)
+            metric_b(i)
+            metric_c(i)
+
+            cumulative_sum += i
+
+            result.log('a', metric_a, on_step=True, on_epoch=True)
+            result.log('b', metric_b, on_step=False, on_epoch=True)
+            result.log('c', metric_c, on_step=True, on_epoch=False)
+
+            batch_log = result.get_batch_log_metrics()
+            batch_expected = {"a_step": i, "a": i, "c": i}
+            assert set(batch_log.keys()) == set(batch_expected.keys())
+            for k in batch_expected.keys():
+                assert batch_expected[k] == batch_log[k]
+
+        epoch_log = result.get_epoch_log_metrics()
+
+        # assert metric state reset to default values
+        assert metric_a.x == metric_a._defaults['x']
+        assert metric_b.x == metric_b._defaults['x']
+        assert metric_c.x == metric_c._defaults['x']
+
+        epoch_expected = {"b": cumulative_sum, "a": cumulative_sum, "a_epoch": cumulative_sum}
+
+        assert set(epoch_log.keys()) == set(epoch_expected.keys())
+        for k in epoch_expected.keys():
+            assert epoch_expected[k] == epoch_log[k]
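(Editorial note, not part of the test file: the expected numbers follow directly from ``DummyMetric``'s summed state — each epoch accumulates ``0 + 1 + 2 + 3 + 4 = 10`` per process, and because the state is declared with ``dist_reduce_fx="sum"`` the DDP test sees that total summed across both processes at epoch end, while batch-level values are never synced since ``ddp_sync_on_step`` defaults to False:)

.. code-block:: python

    per_process_sum = sum(range(5))      # 0 + 1 + 2 + 3 + 4 = 10
    worldsize = 2

    assert per_process_sum == 10                  # epoch value in the single-process test
    assert per_process_sum * worldsize == 20      # epoch value in the DDP test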
@@ -111,6 +111,24 @@ def test_compute():
     assert a.compute() == 5


+def test_forward():
+    class A(Dummy):
+        def update(self, x):
+            self.x += x
+
+        def compute(self):
+            return self.x
+
+    a = A()
+    assert a(5) == 5
+    assert a._forward_cache == 5
+
+    assert a(8) == 8
+    assert a._forward_cache == 8
+
+    assert a.compute() == 13
+
+
 class ToPickle(Dummy):
     def update(self, x):
         self.x += x