clarify forward (#3611)
* clarify forward * clarify forward * clarify forward * clarify forward
This commit is contained in:
parent
2a10cfaf3d
commit
f53e739637
11
README.md
11
README.md
|
@ -138,6 +138,11 @@ class LitAutoEncoder(pl.LightningModule):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
|
self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
|
||||||
self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))
|
self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# in lightning, forward defines the prediction/inference actions
|
||||||
|
embedding = self.encoder(x)
|
||||||
|
return embedding
|
||||||
|
|
||||||
def training_step(self, batch, batch_idx):
|
def training_step(self, batch, batch_idx):
|
||||||
x, y = batch
|
x, y = batch
|
||||||
|
@ -150,12 +155,10 @@ class LitAutoEncoder(pl.LightningModule):
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
|
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
|
||||||
return optimizer
|
return optimizer
|
||||||
|
|
||||||
# def forward(self, x):
|
|
||||||
# in lightning this is optional and mostly used to say
|
|
||||||
# how your LightningModule should work for inference/predictions
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
###### Note: Training_step defines the training loop. Forward defines how the LightningModule behaves during inference/prediction.
|
||||||
|
|
||||||
#### Step 2: Train!
|
#### Step 2: Train!
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
|
|
@ -83,7 +83,6 @@ Step 1: Define LightningModule
|
||||||
|
|
||||||
.. code-block::
|
.. code-block::
|
||||||
|
|
||||||
|
|
||||||
class LitAutoEncoder(pl.LightningModule):
|
class LitAutoEncoder(pl.LightningModule):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -99,6 +98,11 @@ Step 1: Define LightningModule
|
||||||
nn.Linear(64, 28*28)
|
nn.Linear(64, 28*28)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# in lightning, forward defines the prediction/inference actions
|
||||||
|
embedding = self.encoder(x)
|
||||||
|
return embedding
|
||||||
|
|
||||||
def training_step(self, batch, batch_idx):
|
def training_step(self, batch, batch_idx):
|
||||||
x, y = batch
|
x, y = batch
|
||||||
x = x.view(x.size(0), -1)
|
x = x.view(x.size(0), -1)
|
||||||
|
@ -111,10 +115,6 @@ Step 1: Define LightningModule
|
||||||
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
|
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
|
||||||
return optimizer
|
return optimizer
|
||||||
|
|
||||||
# def forward(self, x):
|
|
||||||
# in lightning this is optional and mostly used to say
|
|
||||||
# how your LightningModule should work for inference/predictions
|
|
||||||
|
|
||||||
A :class:`~pytorch_lightning.core.LightningModule` defines a *system* such as:
|
A :class:`~pytorch_lightning.core.LightningModule` defines a *system* such as:
|
||||||
|
|
||||||
- `Autoencoder <https://github.com/PyTorchLightning/pytorch-lightning-bolts/blob/master/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py>`_
|
- `Autoencoder <https://github.com/PyTorchLightning/pytorch-lightning-bolts/blob/master/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py>`_
|
||||||
|
@ -146,6 +146,9 @@ of the 20+ hooks found in :ref:`hooks`
|
||||||
|
|
||||||
More details in :ref:`lightning_module` docs.
|
More details in :ref:`lightning_module` docs.
|
||||||
|
|
||||||
|
.. note:: The training_step defines the training loop and forward defines the prediction/inference behavior.
|
||||||
|
In this case, we want to use our Autoencoder for generating image embeddings.
|
||||||
|
|
||||||
----------
|
----------
|
||||||
|
|
||||||
**********************************
|
**********************************
|
||||||
|
|
Loading…
Reference in New Issue