Improve callback documentation for `outputs` and `accumulate_grad_batches` (Resolves #15315) (#15327)

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: awaelchli <aedu.waelchli@gmail.com>
This commit is contained in:
Robert Bracco 2022-10-30 19:58:27 -04:00 committed by GitHub
parent c39c8eb2e4
commit a008801e25
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 10 additions and 1 deletions

View File

@ -82,7 +82,12 @@ class Callback:
def on_train_batch_end( def on_train_batch_end(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int
) -> None: ) -> None:
"""Called when the train batch ends.""" """Called when the train batch ends.
Note:
The value ``outputs["loss"]`` here will be the normalized value w.r.t ``accumulate_grad_batches`` of the
loss returned from ``training_step``.
"""
def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called when the train epoch begins.""" """Called when the train epoch begins."""

View File

@ -661,6 +661,10 @@ class LightningModule(
Note: Note:
The loss value shown in the progress bar is smoothed (averaged) over the last values, The loss value shown in the progress bar is smoothed (averaged) over the last values,
so it differs from the actual loss returned in train/validation step. so it differs from the actual loss returned in train/validation step.
Note:
When ``accumulate_grad_batches`` > 1, the loss returned here will be automatically
normalized by ``accumulate_grad_batches`` internally.
""" """
rank_zero_warn("`training_step` must be implemented to be used with the Lightning Trainer") rank_zero_warn("`training_step` must be implemented to be used with the Lightning Trainer")