integrate metrics API with self.log (#3961)

* metrics integration into self.log

Co-authored-by: Teddy Koker <teddy.koker@gmail.com>

* ddp and regualr test for self.log + metrics

Co-authored-by: Teddy Koker <teddy.koker@gmail.com>

* pep8

* fix log tests

Co-authored-by: Teddy Koker <teddy.koker@gmail.com>

* docs

Co-authored-by: Teddy Koker <teddy.koker@gmail.com>

Co-authored-by: Teddy Koker <teddy.koker@gmail.com>
This commit is contained in:
Ananya Harsh Jha 2020-10-07 22:54:32 -04:00 committed by GitHub
parent aa95addff2
commit 6f1a2ce517
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 235 additions and 22 deletions

View File

@ -25,12 +25,6 @@ logic present in ``.compute()`` is applied to state information from all process
The example below shows how to use a metric in your ``LightningModule``:
.. note::
For v0.10.0 the user is expected to call ``.compute()`` on the metric at the end each epoch.
This has been shown in the example below. For v1.0 release, we will integrate metrics
with logging and ``.compute()`` will be called automatically by PyTorch Lightning.
.. code-block:: python
def __init__(self):
@ -49,6 +43,41 @@ The example below shows how to use a metric in your ``LightningModule``:
self.log('train_acc_epoch', self.accuracy.compute())
``Metric`` objects can also be directly logged, in which case Lightning will log
the metric based on ``on_step`` and ``on_epoch`` flags present in ``self.log(...)``.
If ``on_epoch`` is True, the logger automatically logs the end of epoch metric value by calling
``.compute()``.
.. note::
``sync_dist``, ``sync_dist_op``, ``sync_dist_group``, ``reduce_fx`` and ``tbptt_reduce_fx``
flags from ``self.log(...)`` don't affect the metric logging in any manner. The metric class
contains its own distributed synchronization logic.
This however is only true for metrics that inherit the base class ``Metric``,
and thus the functional metric API provides no support for in-built distributed synchronization
or reduction functions.
.. code-block:: python
def __init__(self):
...
self.train_acc = pl.metrics.Accuracy()
self.valid_acc = pl.metrics.Accuracy()
def training_step(self, batch, batch_idx):
logits = self(x)
...
self.train_acc(logits, y)
self.log('train_acc', self.train_acc, on_step=True, on_epoch=False)
def validation_step(self, batch, batch_idx):
logits = self(x)
...
self.valid_acc(logits, y)
self.log('valid_acc', self.valid_acc, on_step=True, on_epoch=True)
This metrics API is independent of PyTorch Lightning. Metrics can directly be used in PyTorch as shown in the example:
.. code-block:: python

View File

@ -21,7 +21,7 @@ from torch import Tensor
import os
from pytorch_lightning.utilities.distributed import sync_ddp_if_available
from pytorch_lightning.metrics import Metric
class Result(Dict):
def __init__(
@ -251,12 +251,16 @@ class Result(Dict):
continue
if options['logger'] and options['on_step']:
result[k] = self[k]
if isinstance(self[k], Metric):
result[k] = self[k]._forward_cache
else:
result[k] = self[k]
return result
def get_epoch_log_metrics(self) -> dict:
"""
Gets the metrics to log at the end of the batch step
Gets the metrics to log at the end of epoch
"""
result = {}
@ -264,13 +268,22 @@ class Result(Dict):
for k, options in meta.items():
if k == '_internal':
continue
if options['logger'] and options['on_epoch']:
result[k] = self[k]
if isinstance(self[k], Metric):
result[k] = self[k].compute()
else:
result[k] = self[k]
if k in self and not options['on_epoch'] and isinstance(self[k], Metric):
# compute metric on epoch anyway so state does not accumulate
self[k].compute()
return result
def get_epoch_pbar_metrics(self):
"""
Gets the metrics to log at the end of the batch step
Gets the metrics to log at the end of epoch
"""
result = {}
@ -278,8 +291,17 @@ class Result(Dict):
for k, options in meta.items():
if k == '_internal':
continue
if options['prog_bar'] and options['on_epoch']:
result[k] = self[k]
if isinstance(self[k], Metric):
result[k] = self[k].compute()
else:
result[k] = self[k]
if k in self and not options['on_epoch'] and isinstance(self[k], Metric):
# compute metric on epoch anyway so state does not accumulate
self[k].compute()
return result
def get_batch_pbar_metrics(self, include_forked_originals=True):
@ -292,11 +314,16 @@ class Result(Dict):
for k, options in meta.items():
if k == '_internal':
continue
if options['forked'] and not include_forked_originals:
continue
if options['prog_bar'] and options['on_step']:
result[k] = self[k]
if isinstance(self[k], Metric):
result[k] = self[k]._forward_cache
else:
result[k] = self[k]
return result
def detach(self):
@ -405,7 +432,7 @@ class Result(Dict):
recursive_stack(result)
for k, option in meta.items():
if k == '_internal':
if k == '_internal' or isinstance(result[k], Metric):
continue
if option['on_epoch']:
@ -439,7 +466,7 @@ class Result(Dict):
recursive_stack(result)
for k, value in result.items():
if k in ['meta', 'extra']:
if k in ['meta', 'extra'] or isinstance(value, Metric):
continue
# pick the reduce fx
@ -459,10 +486,12 @@ class Result(Dict):
def dp_reduce(self):
for k, value in self.items():
if k == 'meta':
if k == 'meta' or isinstance(value, Metric):
continue
if isinstance(value, list):
value = torch.tensor(value)
self[k] = value.mean(dim=-1)
@property
@ -502,10 +531,14 @@ def recursive_gather(outputs: Sequence[dict], result: Optional[MutableMapping] =
v = recursive_gather([v], in_d)
result[k] = v
else:
if k not in result:
result[k] = []
result[k].append(v)
if isinstance(v, Metric):
# if v is a metric, just keep one of them,
# don't keep on adding a list of them
result[k] = v
else:
if k not in result:
result[k] = []
result[k].append(v)
return result

View File

@ -59,6 +59,7 @@ class Metric(nn.Module, ABC):
self.update = self._wrap_update(self.update)
self.compute = self._wrap_compute(self.compute)
self._computed = None
self._forward_cache = None
# initialize state
self._reductions = {}
@ -125,6 +126,7 @@ class Metric(nn.Module, ABC):
"""
# add current step
self.update(*args, **kwargs)
self._forward_cache = None
if self.compute_on_step:
self._to_sync = self.ddp_sync_on_step
@ -135,7 +137,7 @@ class Metric(nn.Module, ABC):
# call reset, update, compute, on single batch
self.reset()
self.update(*args, **kwargs)
result = self.compute()
self._forward_cache = self.compute()
# restore context
for attr, val in self._cache.items():
@ -143,7 +145,7 @@ class Metric(nn.Module, ABC):
self._to_sync = True
self._computed = None
return result
return self._forward_cache
def _sync_dist(self):
input_dict = {attr: getattr(self, attr) for attr in self._reductions.keys()}

View File

@ -0,0 +1,131 @@
import pytest
import sys
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.metrics import Metric
import tests.base.develop_utils as tutils
class DummyMetric(Metric):
def __init__(self):
super().__init__()
self.add_state("x", torch.tensor(0), dist_reduce_fx="sum")
def update(self, x):
self.x += x
def compute(self):
return self.x
def _setup_ddp(rank, worldsize):
import os
os.environ["MASTER_ADDR"] = "localhost"
# initialize the process group
dist.init_process_group("gloo", rank=rank, world_size=worldsize)
def _ddp_test_fn(rank, worldsize):
_setup_ddp(rank, worldsize)
tensor = torch.tensor([1.0])
metric_a = DummyMetric()
metric_b = DummyMetric()
metric_c = DummyMetric()
# ddp_sync_on_step is False by default
result = Result()
for epoch in range(3):
cumulative_sum = 0
for i in range(5):
metric_a(i)
metric_b(i)
metric_c(i)
cumulative_sum += i
result.log('a', metric_a, on_step=True, on_epoch=True)
result.log('b', metric_b, on_step=False, on_epoch=True)
result.log('c', metric_c, on_step=True, on_epoch=False)
batch_log = result.get_batch_log_metrics()
batch_expected = {"a_step": i, "a": i, "c": i}
assert set(batch_log.keys()) == set(batch_expected.keys())
for k in batch_expected.keys():
assert batch_expected[k] == batch_log[k]
epoch_log = result.get_epoch_log_metrics()
# assert metric state reset to default values
assert metric_a.x == metric_a._defaults['x']
assert metric_b.x == metric_b._defaults['x']
assert metric_c.x == metric_c._defaults['x']
epoch_expected = {
"b": cumulative_sum * worldsize,
"a": cumulative_sum * worldsize,
"a_epoch": cumulative_sum * worldsize
}
assert set(epoch_log.keys()) == set(epoch_expected.keys())
for k in epoch_expected.keys():
assert epoch_expected[k] == epoch_log[k]
@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
def test_result_reduce_ddp():
"""Make sure result logging works with DDP"""
tutils.reset_seed()
tutils.set_random_master_port()
worldsize = 2
mp.spawn(_ddp_test_fn, args=(worldsize,), nprocs=worldsize)
def test_result_metric_integration():
metric_a = DummyMetric()
metric_b = DummyMetric()
metric_c = DummyMetric()
result = Result()
for epoch in range(3):
cumulative_sum = 0
for i in range(5):
metric_a(i)
metric_b(i)
metric_c(i)
cumulative_sum += i
result.log('a', metric_a, on_step=True, on_epoch=True)
result.log('b', metric_b, on_step=False, on_epoch=True)
result.log('c', metric_c, on_step=True, on_epoch=False)
batch_log = result.get_batch_log_metrics()
batch_expected = {"a_step": i, "a": i, "c": i}
assert set(batch_log.keys()) == set(batch_expected.keys())
for k in batch_expected.keys():
assert batch_expected[k] == batch_log[k]
epoch_log = result.get_epoch_log_metrics()
# assert metric state reset to default values
assert metric_a.x == metric_a._defaults['x']
assert metric_b.x == metric_b._defaults['x']
assert metric_c.x == metric_c._defaults['x']
epoch_expected = {"b": cumulative_sum, "a": cumulative_sum, "a_epoch": cumulative_sum}
assert set(epoch_log.keys()) == set(epoch_expected.keys())
for k in epoch_expected.keys():
assert epoch_expected[k] == epoch_log[k]

View File

@ -111,6 +111,24 @@ def test_compute():
assert a.compute() == 5
def test_forward():
class A(Dummy):
def update(self, x):
self.x += x
def compute(self):
return self.x
a = A()
assert a(5) == 5
assert a._forward_cache == 5
assert a(8) == 8
assert a._forward_cache == 8
assert a.compute() == 13
class ToPickle(Dummy):
def update(self, x):
self.x += x