diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index ca46178544..8bd200c2d9 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -25,12 +25,6 @@ logic present in ``.compute()`` is applied to state information from all process The example below shows how to use a metric in your ``LightningModule``: -.. note:: - - For v0.10.0 the user is expected to call ``.compute()`` on the metric at the end each epoch. - This has been shown in the example below. For v1.0 release, we will integrate metrics - with logging and ``.compute()`` will be called automatically by PyTorch Lightning. - .. code-block:: python def __init__(self): @@ -49,6 +43,41 @@ The example below shows how to use a metric in your ``LightningModule``: self.log('train_acc_epoch', self.accuracy.compute()) +``Metric`` objects can also be directly logged, in which case Lightning will log +the metric based on ``on_step`` and ``on_epoch`` flags present in ``self.log(...)``. +If ``on_epoch`` is True, the logger automatically logs the end of epoch metric value by calling +``.compute()``. + +.. note:: + ``sync_dist``, ``sync_dist_op``, ``sync_dist_group``, ``reduce_fx`` and ``tbptt_reduce_fx`` + flags from ``self.log(...)`` don't affect the metric logging in any manner. The metric class + contains its own distributed synchronization logic. + + This however is only true for metrics that inherit the base class ``Metric``, + and thus the functional metric API provides no support for in-built distributed synchronization + or reduction functions. + + +.. code-block:: python + + def __init__(self): + ... + self.train_acc = pl.metrics.Accuracy() + self.valid_acc = pl.metrics.Accuracy() + + def training_step(self, batch, batch_idx): + logits = self(x) + ... + self.train_acc(logits, y) + self.log('train_acc', self.train_acc, on_step=True, on_epoch=False) + + def validation_step(self, batch, batch_idx): + logits = self(x) + ... + self.valid_acc(logits, y) + self.log('valid_acc', self.valid_acc, on_step=True, on_epoch=True) + + This metrics API is independent of PyTorch Lightning. Metrics can directly be used in PyTorch as shown in the example: .. code-block:: python diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index 8e4d6bdf8b..ad34261f1e 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -21,7 +21,7 @@ from torch import Tensor import os from pytorch_lightning.utilities.distributed import sync_ddp_if_available - +from pytorch_lightning.metrics import Metric class Result(Dict): def __init__( @@ -251,12 +251,16 @@ class Result(Dict): continue if options['logger'] and options['on_step']: - result[k] = self[k] + if isinstance(self[k], Metric): + result[k] = self[k]._forward_cache + else: + result[k] = self[k] + return result def get_epoch_log_metrics(self) -> dict: """ - Gets the metrics to log at the end of the batch step + Gets the metrics to log at the end of epoch """ result = {} @@ -264,13 +268,22 @@ class Result(Dict): for k, options in meta.items(): if k == '_internal': continue + if options['logger'] and options['on_epoch']: - result[k] = self[k] + if isinstance(self[k], Metric): + result[k] = self[k].compute() + else: + result[k] = self[k] + + if k in self and not options['on_epoch'] and isinstance(self[k], Metric): + # compute metric on epoch anyway so state does not accumulate + self[k].compute() + return result def get_epoch_pbar_metrics(self): """ - Gets the metrics to log at the end of the batch step + Gets the metrics to log at the end of epoch """ result = {} @@ -278,8 +291,17 @@ class Result(Dict): for k, options in meta.items(): if k == '_internal': continue + if options['prog_bar'] and options['on_epoch']: - result[k] = self[k] + if isinstance(self[k], Metric): + result[k] = self[k].compute() + else: + result[k] = self[k] + + if k in self and not options['on_epoch'] and isinstance(self[k], Metric): + # compute metric on epoch anyway so state does not accumulate + self[k].compute() + return result def get_batch_pbar_metrics(self, include_forked_originals=True): @@ -292,11 +314,16 @@ class Result(Dict): for k, options in meta.items(): if k == '_internal': continue + if options['forked'] and not include_forked_originals: continue if options['prog_bar'] and options['on_step']: - result[k] = self[k] + if isinstance(self[k], Metric): + result[k] = self[k]._forward_cache + else: + result[k] = self[k] + return result def detach(self): @@ -405,7 +432,7 @@ class Result(Dict): recursive_stack(result) for k, option in meta.items(): - if k == '_internal': + if k == '_internal' or isinstance(result[k], Metric): continue if option['on_epoch']: @@ -439,7 +466,7 @@ class Result(Dict): recursive_stack(result) for k, value in result.items(): - if k in ['meta', 'extra']: + if k in ['meta', 'extra'] or isinstance(value, Metric): continue # pick the reduce fx @@ -459,10 +486,12 @@ class Result(Dict): def dp_reduce(self): for k, value in self.items(): - if k == 'meta': + if k == 'meta' or isinstance(value, Metric): continue + if isinstance(value, list): value = torch.tensor(value) + self[k] = value.mean(dim=-1) @property @@ -502,10 +531,14 @@ def recursive_gather(outputs: Sequence[dict], result: Optional[MutableMapping] = v = recursive_gather([v], in_d) result[k] = v else: - if k not in result: - result[k] = [] - - result[k].append(v) + if isinstance(v, Metric): + # if v is a metric, just keep one of them, + # don't keep on adding a list of them + result[k] = v + else: + if k not in result: + result[k] = [] + result[k].append(v) return result diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index 1723a65ac8..acd2b2d5e2 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -59,6 +59,7 @@ class Metric(nn.Module, ABC): self.update = self._wrap_update(self.update) self.compute = self._wrap_compute(self.compute) self._computed = None + self._forward_cache = None # initialize state self._reductions = {} @@ -125,6 +126,7 @@ class Metric(nn.Module, ABC): """ # add current step self.update(*args, **kwargs) + self._forward_cache = None if self.compute_on_step: self._to_sync = self.ddp_sync_on_step @@ -135,7 +137,7 @@ class Metric(nn.Module, ABC): # call reset, update, compute, on single batch self.reset() self.update(*args, **kwargs) - result = self.compute() + self._forward_cache = self.compute() # restore context for attr, val in self._cache.items(): @@ -143,7 +145,7 @@ class Metric(nn.Module, ABC): self._to_sync = True self._computed = None - return result + return self._forward_cache def _sync_dist(self): input_dict = {attr: getattr(self, attr) for attr in self._reductions.keys()} diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py new file mode 100644 index 0000000000..b9cad945bd --- /dev/null +++ b/tests/core/test_metric_result_integration.py @@ -0,0 +1,131 @@ + +import pytest +import sys +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +from pytorch_lightning.core.step_result import Result +from pytorch_lightning.metrics import Metric +import tests.base.develop_utils as tutils + + +class DummyMetric(Metric): + def __init__(self): + super().__init__() + self.add_state("x", torch.tensor(0), dist_reduce_fx="sum") + + def update(self, x): + self.x += x + + def compute(self): + return self.x + + +def _setup_ddp(rank, worldsize): + import os + + os.environ["MASTER_ADDR"] = "localhost" + + # initialize the process group + dist.init_process_group("gloo", rank=rank, world_size=worldsize) + + +def _ddp_test_fn(rank, worldsize): + _setup_ddp(rank, worldsize) + tensor = torch.tensor([1.0]) + + metric_a = DummyMetric() + metric_b = DummyMetric() + metric_c = DummyMetric() + + # ddp_sync_on_step is False by default + result = Result() + + for epoch in range(3): + cumulative_sum = 0 + + for i in range(5): + metric_a(i) + metric_b(i) + metric_c(i) + + cumulative_sum += i + + result.log('a', metric_a, on_step=True, on_epoch=True) + result.log('b', metric_b, on_step=False, on_epoch=True) + result.log('c', metric_c, on_step=True, on_epoch=False) + + batch_log = result.get_batch_log_metrics() + batch_expected = {"a_step": i, "a": i, "c": i} + assert set(batch_log.keys()) == set(batch_expected.keys()) + for k in batch_expected.keys(): + assert batch_expected[k] == batch_log[k] + + epoch_log = result.get_epoch_log_metrics() + + # assert metric state reset to default values + assert metric_a.x == metric_a._defaults['x'] + assert metric_b.x == metric_b._defaults['x'] + assert metric_c.x == metric_c._defaults['x'] + + epoch_expected = { + "b": cumulative_sum * worldsize, + "a": cumulative_sum * worldsize, + "a_epoch": cumulative_sum * worldsize + } + + assert set(epoch_log.keys()) == set(epoch_expected.keys()) + for k in epoch_expected.keys(): + assert epoch_expected[k] == epoch_log[k] + + +@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows") +def test_result_reduce_ddp(): + """Make sure result logging works with DDP""" + tutils.reset_seed() + tutils.set_random_master_port() + + worldsize = 2 + mp.spawn(_ddp_test_fn, args=(worldsize,), nprocs=worldsize) + + +def test_result_metric_integration(): + metric_a = DummyMetric() + metric_b = DummyMetric() + metric_c = DummyMetric() + + result = Result() + + for epoch in range(3): + cumulative_sum = 0 + + for i in range(5): + metric_a(i) + metric_b(i) + metric_c(i) + + cumulative_sum += i + + result.log('a', metric_a, on_step=True, on_epoch=True) + result.log('b', metric_b, on_step=False, on_epoch=True) + result.log('c', metric_c, on_step=True, on_epoch=False) + + batch_log = result.get_batch_log_metrics() + batch_expected = {"a_step": i, "a": i, "c": i} + assert set(batch_log.keys()) == set(batch_expected.keys()) + for k in batch_expected.keys(): + assert batch_expected[k] == batch_log[k] + + epoch_log = result.get_epoch_log_metrics() + + # assert metric state reset to default values + assert metric_a.x == metric_a._defaults['x'] + assert metric_b.x == metric_b._defaults['x'] + assert metric_c.x == metric_c._defaults['x'] + + epoch_expected = {"b": cumulative_sum, "a": cumulative_sum, "a_epoch": cumulative_sum} + + assert set(epoch_log.keys()) == set(epoch_expected.keys()) + for k in epoch_expected.keys(): + assert epoch_expected[k] == epoch_log[k] diff --git a/tests/metrics/test_metric.py b/tests/metrics/test_metric.py index 366c873127..a515e50593 100644 --- a/tests/metrics/test_metric.py +++ b/tests/metrics/test_metric.py @@ -111,6 +111,24 @@ def test_compute(): assert a.compute() == 5 +def test_forward(): + class A(Dummy): + def update(self, x): + self.x += x + + def compute(self): + return self.x + + a = A() + assert a(5) == 5 + assert a._forward_cache == 5 + + assert a(8) == 8 + assert a._forward_cache == 8 + + assert a.compute() == 13 + + class ToPickle(Dummy): def update(self, x): self.x += x