clarify forward (#3611)

* clarify forward

* clarify forward

* clarify forward

* clarify forward
This commit is contained in:
William Falcon 2020-09-22 14:00:02 -04:00 committed by GitHub
parent 2a10cfaf3d
commit f53e739637
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 15 additions and 9 deletions

View File

@ -138,6 +138,11 @@ class LitAutoEncoder(pl.LightningModule):
super().__init__()
self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))
def forward(self, x):
# in lightning, forward defines the prediction/inference actions
embedding = self.encoder(x)
return embedding
def training_step(self, batch, batch_idx):
x, y = batch
@ -150,12 +155,10 @@ class LitAutoEncoder(pl.LightningModule):
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
return optimizer
# def forward(self, x):
# in lightning this is optional and mostly used to say
# how your LightningModule should work for inference/predictions
```
###### Note: Training_step defines the training loop. Forward defines how the LightningModule behaves during inference/prediction.
#### Step 2: Train!
```python

View File

@ -83,7 +83,6 @@ Step 1: Define LightningModule
.. code-block::
class LitAutoEncoder(pl.LightningModule):
def __init__(self):
@ -99,6 +98,11 @@ Step 1: Define LightningModule
nn.Linear(64, 28*28)
)
def forward(self, x):
# in lightning, forward defines the prediction/inference actions
embedding = self.encoder(x)
return embedding
def training_step(self, batch, batch_idx):
x, y = batch
x = x.view(x.size(0), -1)
@ -111,10 +115,6 @@ Step 1: Define LightningModule
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
return optimizer
# def forward(self, x):
# in lightning this is optional and mostly used to say
# how your LightningModule should work for inference/predictions
A :class:`~pytorch_lightning.core.LightningModule` defines a *system* such as:
- `Autoencoder <https://github.com/PyTorchLightning/pytorch-lightning-bolts/blob/master/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py>`_
@ -146,6 +146,9 @@ of the 20+ hooks found in :ref:`hooks`
More details in :ref:`lightning_module` docs.
.. note:: The training_step defines the training loop and forward defines the prediction/inference behavior.
In this case, we want to use our Autoencoder for generating image embeddings.
----------
**********************************