Improve callback documentation for `outputs` and `accumulate_grad_batches` (Resolves #15315) (#15327)

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: awaelchli <aedu.waelchli@gmail.com>
Robert Bracco authored on 2022-10-30 19:58:27 -04:00, committed by GitHub
parent c39c8eb2e4
commit a008801e25
2 changed files with 10 additions and 1 deletion


@@ -82,7 +82,12 @@ class Callback:
     def on_train_batch_end(
         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int
     ) -> None:
-        """Called when the train batch ends."""
+        """Called when the train batch ends.
+
+        Note:
+            The value ``outputs["loss"]`` here will be the normalized value w.r.t ``accumulate_grad_batches`` of the
+            loss returned from ``training_step``.
+        """
 
     def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """Called when the train epoch begins."""


@@ -661,6 +661,10 @@ class LightningModule(
         Note:
             The loss value shown in the progress bar is smoothed (averaged) over the last values,
             so it differs from the actual loss returned in train/validation step.
+
+        Note:
+            When ``accumulate_grad_batches`` > 1, the loss returned here will be automatically
+            normalized by ``accumulate_grad_batches`` internally.
         """
 
         rank_zero_warn("`training_step` must be implemented to be used with the Lightning Trainer")
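
To make the normalization concrete, here is a minimal sketch (the model, optimizer, and numbers are illustrative assumptions, not from this commit): with ``accumulate_grad_batches=4``, a ``training_step`` that returns a loss of 2.0 surfaces as 0.5 in ``outputs["loss"]`` inside ``on_train_batch_end``.

import torch
import torch.nn.functional as F
import pytorch_lightning as pl


class LitRegressor(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.mse_loss(self.layer(x), y)
        # Returned unscaled here; Lightning divides it by accumulate_grad_batches
        # before it reaches callbacks such as on_train_batch_end.
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)


# Gradients accumulate over 4 batches; each per-batch loss is scaled by 1/4 so
# the accumulated gradient matches that of one 4x-larger batch.
trainer = pl.Trainer(accumulate_grad_batches=4, max_epochs=1)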