Improve callback documentation for `outputs` and `accumulate_grad_batches` (Resolves #15315) (#15327)
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: awaelchli <aedu.waelchli@gmail.com>
This commit is contained in:
parent
c39c8eb2e4
commit
a008801e25
|
@ -82,7 +82,12 @@ class Callback:
|
|||
def on_train_batch_end(
|
||||
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int
|
||||
) -> None:
|
||||
"""Called when the train batch ends."""
|
||||
"""Called when the train batch ends.
|
||||
|
||||
Note:
|
||||
The value ``outputs["loss"]`` here will be the normalized value w.r.t ``accumulate_grad_batches`` of the
|
||||
loss returned from ``training_step``.
|
||||
"""
|
||||
|
||||
def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||
"""Called when the train epoch begins."""
|
||||
|
|
|
@ -661,6 +661,10 @@ class LightningModule(
|
|||
Note:
|
||||
The loss value shown in the progress bar is smoothed (averaged) over the last values,
|
||||
so it differs from the actual loss returned in train/validation step.
|
||||
|
||||
Note:
|
||||
When ``accumulate_grad_batches`` > 1, the loss returned here will be automatically
|
||||
normalized by ``accumulate_grad_batches`` internally.
|
||||
"""
|
||||
rank_zero_warn("`training_step` must be implemented to be used with the Lightning Trainer")
|
||||
|
||||
|
|
Loading…
Reference in New Issue