diff --git a/README.md b/README.md index 87c39d0b03..6c7220775e 100644 --- a/README.md +++ b/README.md @@ -138,6 +138,11 @@ class LitAutoEncoder(pl.LightningModule): super().__init__() self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3)) self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28)) + + def forward(self, x): + # in lightning, forward defines the prediction/inference actions + embedding = self.encoder(x) + return embedding def training_step(self, batch, batch_idx): x, y = batch @@ -150,12 +155,10 @@ class LitAutoEncoder(pl.LightningModule): def configure_optimizers(self): optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) return optimizer - - # def forward(self, x): - # in lightning this is optional and mostly used to say - # how your LightningModule should work for inference/predictions ``` +###### Note: Training_step defines the training loop. Forward defines how the LightningModule behaves during inference/prediction. + #### Step 2: Train! ```python diff --git a/docs/source/new-project.rst b/docs/source/new-project.rst index 6769152824..d17d3af7e0 100644 --- a/docs/source/new-project.rst +++ b/docs/source/new-project.rst @@ -83,7 +83,6 @@ Step 1: Define LightningModule .. code-block:: - class LitAutoEncoder(pl.LightningModule): def __init__(self): @@ -99,6 +98,11 @@ Step 1: Define LightningModule nn.Linear(64, 28*28) ) + def forward(self, x): + # in lightning, forward defines the prediction/inference actions + embedding = self.encoder(x) + return embedding + def training_step(self, batch, batch_idx): x, y = batch x = x.view(x.size(0), -1) @@ -111,10 +115,6 @@ Step 1: Define LightningModule optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) return optimizer - # def forward(self, x): - # in lightning this is optional and mostly used to say - # how your LightningModule should work for inference/predictions - A :class:`~pytorch_lightning.core.LightningModule` defines a *system* such as: - `Autoencoder `_ @@ -146,6 +146,9 @@ of the 20+ hooks found in :ref:`hooks` More details in :ref:`lightning_module` docs. +.. note:: The training_step defines the training loop and forward defines the prediction/inference behavior. + In this case, we want to use our Autoencoder for generating image embeddings. + ---------- **********************************