[Docs] Note on running metric in dp (#4494)
* note * Update docs/source/metrics.rst Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Sean Naren <sean.narenthiran@gmail.com> Co-authored-by: Jeff Yang <ydcjeff@outlook.com>
This commit is contained in:
parent
ee35907170
commit
01a925d333
|
@ -78,6 +78,26 @@ If ``on_epoch`` is True, the logger automatically logs the end of epoch metric v
|
|||
self.valid_acc(logits, y)
|
||||
self.log('valid_acc', self.valid_acc, on_step=True, on_epoch=True)
|
||||
|
||||
.. note::
|
||||
If using metrics in data parallel mode (dp), the metric update/logging should be done
|
||||
in the ``<mode>_step_end`` method (where ``<mode>`` is either ``training``, ``validation``
|
||||
or ``test``). This is due to metric states else being destroyed after each forward pass,
|
||||
leading to wrong accumulation. In practice do the following:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
data, target = batch
|
||||
pred = self(data)
|
||||
...
|
||||
return {'loss' : loss, 'preds' : preds, 'target' : target}
|
||||
|
||||
def training_step_end(self, outputs):
|
||||
#update and log
|
||||
self.metric(outputs['preds'], outputs['target'])
|
||||
self.log('metric', self.metric)
|
||||
|
||||
|
||||
This metrics API is independent of PyTorch Lightning. Metrics can directly be used in PyTorch as shown in the example:
|
||||
|
||||
.. code-block:: python
|
||||
|
|
Loading…
Reference in New Issue