From 01a925d3331042201b4e0a6c2a0bcb4e154686aa Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Mon, 9 Nov 2020 11:30:28 +0100
Subject: [PATCH] [Docs] Note on running metric in dp (#4494)

* note

* Update docs/source/metrics.rst

Co-authored-by: chaton
Co-authored-by: Sean Naren
Co-authored-by: Jeff Yang
---
 docs/source/metrics.rst | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst
index 4fadfaa507..e41fdfa7d1 100644
--- a/docs/source/metrics.rst
+++ b/docs/source/metrics.rst
@@ -78,6 +78,26 @@ If ``on_epoch`` is True, the logger automatically logs the end of epoch metric value
         self.valid_acc(logits, y)
         self.log('valid_acc', self.valid_acc, on_step=True, on_epoch=True)
 
+.. note::
+    When using metrics in data parallel (dp) mode, the metric update/logging should be done
+    in the ``<mode>_step_end`` method (where ``<mode>`` is either ``training``, ``validation``
+    or ``test``). Otherwise, the metric states are destroyed after each forward pass,
+    leading to wrong accumulation. In practice, do the following:
+
+    .. code-block:: python
+
+        def training_step(self, batch, batch_idx):
+            data, target = batch
+            preds = self(data)
+            ...
+            return {'loss': loss, 'preds': preds, 'target': target}
+
+        def training_step_end(self, outputs):
+            # update and log
+            self.metric(outputs['preds'], outputs['target'])
+            self.log('metric', self.metric)
+
+
 This metrics API is independent of PyTorch Lightning. Metrics can directly be used in PyTorch as shown in the example:
 
 .. code-block:: python
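
For context, below is a minimal, self-contained sketch of the pattern the note describes, using the ``pytorch_lightning.metrics.Accuracy`` metric available in the Lightning releases of this era (late 2020). The ``LitClassifier`` module, the layer sizes, and the ``train_acc`` name are illustrative assumptions, not part of the patch:

.. code-block:: python

    import torch
    import torch.nn.functional as F
    import pytorch_lightning as pl

    class LitClassifier(pl.LightningModule):  # hypothetical module, for illustration only
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 10)
            # the metric lives on the module, so its state persists across steps
            self.train_acc = pl.metrics.Accuracy()

        def forward(self, x):
            return self.layer(x)

        def training_step(self, batch, batch_idx):
            data, target = batch
            preds = self(data)
            loss = F.cross_entropy(preds, target)
            # under dp, do NOT update the metric here: each GPU holds a replica
            # of the module whose state is thrown away after the forward pass
            return {'loss': loss, 'preds': preds, 'target': target}

        def training_step_end(self, outputs):
            # runs once per step with the outputs gathered from all replicas,
            # so the update accumulates into a single metric state
            self.train_acc(outputs['preds'], outputs['target'])
            self.log('train_acc', self.train_acc, on_step=True, on_epoch=True)
            return outputs['loss'].mean()  # reduce the per-replica losses

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)

Running this under dp (e.g. ``Trainer(gpus=2, accelerator='dp')`` in Lightning 1.0.x) exercises the ``training_step_end`` path; in single-device runs Lightning forwards the ``training_step`` output to ``training_step_end`` unchanged, so the same code works in both settings.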