[Docs] Note on running metric in dp (#4494)

* note

* Update docs/source/metrics.rst

Co-authored-by: chaton <thomas@grid.ai>
Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
Co-authored-by: Jeff Yang <ydcjeff@outlook.com>
Nicki Skafte 2020-11-09 11:30:28 +01:00 committed by GitHub
parent ee35907170
commit 01a925d333
1 changed file with 20 additions and 0 deletions

@@ -78,6 +78,26 @@ If ``on_epoch`` is True, the logger automatically logs the end of epoch metric value
        self.valid_acc(logits, y)
        self.log('valid_acc', self.valid_acc, on_step=True, on_epoch=True)

.. note::

    If using metrics in data parallel (dp) mode, the metric update/logging should be done
    in the ``<mode>_step_end`` method (where ``<mode>`` is either ``training``, ``validation``
    or ``test``). Otherwise, the metric states are destroyed after each forward pass,
    leading to wrong accumulation. In practice, do the following:

    .. code-block:: python

        def training_step(self, batch, batch_idx):
            data, target = batch
            preds = self(data)
            ...
            return {'loss': loss, 'preds': preds, 'target': target}

        def training_step_end(self, outputs):
            # update metric states and log
            self.metric(outputs['preds'], outputs['target'])
            self.log('metric', self.metric)
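
    The same pattern applies to the other modes. Below is a minimal sketch for validation,
    assuming a ``self.valid_acc`` metric attribute (name illustrative):

    .. code-block:: python

        def validation_step(self, batch, batch_idx):
            data, target = batch
            preds = self(data)
            ...
            return {'preds': preds, 'target': target}

        def validation_step_end(self, outputs):
            # update the metric on the gathered outputs and log it once per epoch
            self.valid_acc(outputs['preds'], outputs['target'])
            self.log('valid_acc', self.valid_acc, on_step=False, on_epoch=True)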

This metrics API is independent of PyTorch Lightning. Metrics can be used directly in plain
PyTorch, as shown in this example:

.. code-block:: python
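
    # A minimal sketch of using a metric outside of a LightningModule. It assumes the
    # ``pytorch_lightning.metrics.Accuracy`` class available in this release; adapt as needed.
    import torch
    from pytorch_lightning import metrics

    accuracy = metrics.Accuracy()

    for _ in range(10):
        preds = torch.randint(0, 2, (16,))
        target = torch.randint(0, 2, (16,))
        # calling the metric updates its internal state and returns the value for this batch
        batch_acc = accuracy(preds, target)

    # compute() returns the value accumulated over all batches seen so far
    total_acc = accuracy.compute()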