diff --git a/docs/LightningModule/RequiredTrainerInterface.md b/docs/LightningModule/RequiredTrainerInterface.md index 6569552406..f1f862e802 100644 --- a/docs/LightningModule/RequiredTrainerInterface.md +++ b/docs/LightningModule/RequiredTrainerInterface.md @@ -136,7 +136,17 @@ def training_step(self, data_batch, batch_nb): # return a dict return output -``` +``` + +If you define multiple optimizers, this step will also be called with an additional ```optimizer_idx``` param. +``` {.python} +# Multiple optimizers (ie: GANs) +def training_step(self, data_batch, batch_nb, optimizer_idx): + if optimizer_idx == 0: + # do training_step with encoder + if optimizer_idx == 1: + # do training_step with decoder +``` --- ### tng_dataloader