clarify forward (#3611)
* clarify forward
* clarify forward
* clarify forward
* clarify forward
parent 2a10cfaf3d
commit f53e739637

11  README.md
@@ -138,6 +138,11 @@ class LitAutoEncoder(pl.LightningModule):
         super().__init__()
         self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
         self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))
 
+    def forward(self, x):
+        # in lightning, forward defines the prediction/inference actions
+        embedding = self.encoder(x)
+        return embedding
+
     def training_step(self, batch, batch_idx):
         x, y = batch
@@ -150,12 +155,10 @@ class LitAutoEncoder(pl.LightningModule):
     def configure_optimizers(self):
         optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
         return optimizer
-
-    # def forward(self, x):
-    # in lightning this is optional and mostly used to say
-    # how your LightningModule should work for inference/predictions
 ```
 
+###### Note: Training_step defines the training loop. Forward defines how the LightningModule behaves during inference/prediction.
+
 #### Step 2: Train!
 
 ```python
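For context on what the newly added `forward` buys you, here is a minimal inference sketch. It assumes the full `LitAutoEncoder` class from the README above is in scope; the random tensor is a stand-in for a batch of flattened 28x28 images and is not part of the diff.

```python
import torch

# assumes the LitAutoEncoder defined in the README snippet above is in scope
model = LitAutoEncoder()
model.eval()

# dummy batch: 4 flattened 28x28 "images" standing in for real MNIST data
x = torch.rand(4, 28 * 28)

with torch.no_grad():
    embedding = model(x)  # calling the module invokes forward -> encoder only

print(embedding.shape)  # torch.Size([4, 3])
```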
@@ -83,7 +83,6 @@ Step 1: Define LightningModule
 
 .. code-block::
 
-
     class LitAutoEncoder(pl.LightningModule):
 
         def __init__(self):
@@ -99,6 +98,11 @@ Step 1: Define LightningModule
                 nn.Linear(64, 28*28)
             )
 
+        def forward(self, x):
+            # in lightning, forward defines the prediction/inference actions
+            embedding = self.encoder(x)
+            return embedding
+
         def training_step(self, batch, batch_idx):
             x, y = batch
             x = x.view(x.size(0), -1)
@@ -111,10 +115,6 @@ Step 1: Define LightningModule
             optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
             return optimizer
 
-        # def forward(self, x):
-        # in lightning this is optional and mostly used to say
-        # how your LightningModule should work for inference/predictions
-
 A :class:`~pytorch_lightning.core.LightningModule` defines a *system* such as:
 
 - `Autoencoder <https://github.com/PyTorchLightning/pytorch-lightning-bolts/blob/master/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py>`_
@@ -146,6 +146,9 @@ of the 20+ hooks found in :ref:`hooks`
 
 More details in :ref:`lightning_module` docs.
 
+.. note:: The training_step defines the training loop and forward defines the prediction/inference behavior.
+   In this case, we want to use our Autoencoder for generating image embeddings.
+
 ----------
 
 **********************************
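To make the split described in that note concrete, here is a hedged sketch contrasting the two entry points: `Trainer.fit` drives the loop defined by `training_step`, while calling the model directly invokes `forward`. It assumes the complete `LitAutoEncoder` from the README (the hunks above show only part of it) and uses random tensors in place of a real MNIST DataLoader.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl

# stand-in data: 64 flattened 28x28 "images" with dummy labels (not real MNIST)
dataset = TensorDataset(torch.rand(64, 28 * 28), torch.zeros(64, dtype=torch.long))
train_loader = DataLoader(dataset, batch_size=32)

model = LitAutoEncoder()  # the class from the README diff above

# training: the Trainer runs the loop defined by training_step
trainer = pl.Trainer(max_epochs=1)
trainer.fit(model, train_loader)

# inference/prediction: calling the model invokes forward and returns embeddings
model.eval()
with torch.no_grad():
    embedding = model(torch.rand(1, 28 * 28))
```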