diff --git a/docs/source-pytorch/model/train_model_basic.rst b/docs/source-pytorch/model/train_model_basic.rst index 88799a1451..92f4a0a40f 100644 --- a/docs/source-pytorch/model/train_model_basic.rst +++ b/docs/source-pytorch/model/train_model_basic.rst @@ -116,11 +116,11 @@ Under the hood, the Lightning Trainer runs the following training loop on your b .. code:: python - autoencoder = LitAutoEncoder(encoder, decoder) + autoencoder = LitAutoEncoder(Encoder(), Decoder()) optimizer = autoencoder.configure_optimizers() for batch_idx, batch in enumerate(train_loader): - loss = autoencoder(batch, batch_idx) + loss = autoencoder.training_step(batch, batch_idx) loss.backward() optimizer.step()