[Docs] Note on running metric in dp (#4494)
* note

* Update docs/source/metrics.rst

Co-authored-by: chaton <thomas@grid.ai>
Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
Co-authored-by: Jeff Yang <ydcjeff@outlook.com>
This commit is contained in:
parent ee35907170
commit 01a925d333
@@ -78,6 +78,26 @@ If ``on_epoch`` is True, the logger automatically logs the end of epoch metric value

        self.valid_acc(logits, y)
        self.log('valid_acc', self.valid_acc, on_step=True, on_epoch=True)

.. note::

    If using metrics in data parallel mode (dp), the metric update/logging should be done
    in the ``<mode>_step_end`` method (where ``<mode>`` is one of ``training``, ``validation``
    or ``test``). This is because ``dp`` replicates the model across GPUs on every forward
    pass: metric states updated inside ``<mode>_step`` live on the replicas and are destroyed
    after each pass, leading to wrong accumulation. In practice, do the following:

    .. code-block:: python

        def training_step(self, batch, batch_idx):
            data, target = batch
            preds = self(data)
            ...
            return {'loss': loss, 'preds': preds, 'target': target}

        def training_step_end(self, outputs):
            # update and log
            self.metric(outputs['preds'], outputs['target'])
            self.log('metric', self.metric)
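
    The same pattern applies to the other modes; a minimal sketch for validation
    (the ``valid_acc`` attribute name is illustrative):

    .. code-block:: python

        def validation_step(self, batch, batch_idx):
            data, target = batch
            preds = self(data)
            return {'preds': preds, 'target': target}

        def validation_step_end(self, outputs):
            # update and log the running validation accuracy
            self.valid_acc(outputs['preds'], outputs['target'])
            self.log('valid_acc', self.valid_acc)
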
This metrics API is independent of PyTorch Lightning. Metrics can directly be used in PyTorch as shown in the example:

.. code-block:: python
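    # A minimal sketch of standalone usage, assuming the ``Accuracy`` metric
    # from ``pytorch_lightning.metrics``; the ``model`` and ``train_data``
    # variables are illustrative, not part of the original example.
    from pytorch_lightning import metrics

    train_accuracy = metrics.Accuracy()

    for x, y in train_data:
        y_hat = model(x)
        # updates the internal state and returns the accuracy on this batch
        batch_acc = train_accuracy(y_hat, y)

    # accuracy accumulated over all batches seen so far
    total_accuracy = train_accuracy.compute()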