Fixed example implementation of AutoEncoder. (#3190)
The previous implementation trained a auto encoder and evaluated classificator. I try to fix this by replacing the evaluation metric with an auto encoder metric. Hence, no classification is done. I'm not 100% sure what the original authors intent was, since he extends a classification model (LitMNIST) but does not use it. The following model is an AutoEncoder and does not do any classification. 1. Small textual changes. 2. forward() now implements encoding and not decoding (as it was described in the text.) 3. _shared_eval uses MSE loss instead of class loss, since no classification weights are learned. 4. initialized MSE in __init__, since calling MSE directly is not supported.
This commit is contained in:
parent
17d8773106
commit
ae3bf919c6
|
@ -16,14 +16,17 @@
|
||||||
def val_dataloader():
|
def val_dataloader():
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def test_dataloader():
|
||||||
|
pass
|
||||||
|
|
||||||
Child Modules
|
Child Modules
|
||||||
-------------
|
-------------
|
||||||
Research projects tend to test different approaches to the same dataset.
|
Research projects tend to test different approaches to the same dataset.
|
||||||
This is very easy to do in Lightning with inheritance.
|
This is very easy to do in Lightning with inheritance.
|
||||||
|
|
||||||
For example, imagine we now want to train an Autoencoder to use as a feature extractor for MNIST images.
|
For example, imagine we now want to train an Autoencoder to use as a feature extractor for MNIST images.
|
||||||
Recall that `LitMNIST` already defines all the dataloading etc... The only things
|
We are extending our Autoencoder from the `LitMNIST`-module which already defines all the dataloading.
|
||||||
that change in the `Autoencoder` model are the init, forward, training, validation and test step.
|
The only things that change in the `Autoencoder` model are the init, forward, training, validation and test step.
|
||||||
|
|
||||||
.. testcode::
|
.. testcode::
|
||||||
|
|
||||||
|
@ -39,18 +42,18 @@ that change in the `Autoencoder` model are the init, forward, training, validati
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.encoder = Encoder()
|
self.encoder = Encoder()
|
||||||
self.decoder = Decoder()
|
self.decoder = Decoder()
|
||||||
|
self.metric = MSE()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
generated = self.decoder(x)
|
return self.encoder(x)
|
||||||
return generated
|
|
||||||
|
|
||||||
def training_step(self, batch, batch_idx):
|
def training_step(self, batch, batch_idx):
|
||||||
x, _ = batch
|
x, _ = batch
|
||||||
|
|
||||||
representation = self.encoder(x)
|
representation = self(x)
|
||||||
x_hat = self(representation)
|
x_hat = self.decoder(representation)
|
||||||
|
|
||||||
loss = MSE(x, x_hat)
|
loss = self.metric(x, x_hat)
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
def validation_step(self, batch, batch_idx):
|
def validation_step(self, batch, batch_idx):
|
||||||
|
@ -60,11 +63,11 @@ that change in the `Autoencoder` model are the init, forward, training, validati
|
||||||
return self._shared_eval(batch, batch_idx, 'test')
|
return self._shared_eval(batch, batch_idx, 'test')
|
||||||
|
|
||||||
def _shared_eval(self, batch, batch_idx, prefix):
|
def _shared_eval(self, batch, batch_idx, prefix):
|
||||||
x, y = batch
|
x, _ = batch
|
||||||
representation = self.encoder(x)
|
representation = self(x)
|
||||||
x_hat = self(representation)
|
x_hat = self.decoder(representation)
|
||||||
|
|
||||||
loss = F.nll_loss(logits, y)
|
loss = self.metric(x, x_hat)
|
||||||
result = pl.EvalResult()
|
result = pl.EvalResult()
|
||||||
result.log(f'{prefix}_loss', loss)
|
result.log(f'{prefix}_loss', loss)
|
||||||
return result
|
return result
|
||||||
|
@ -78,7 +81,7 @@ and we can train this using the same trainer
|
||||||
trainer = Trainer()
|
trainer = Trainer()
|
||||||
trainer.fit(autoencoder)
|
trainer.fit(autoencoder)
|
||||||
|
|
||||||
And remember that the forward method is to define the practical use of a LightningModule.
|
And remember that the forward method should define the practical use of a LightningModule.
|
||||||
In this case, we want to use the `AutoEncoder` to extract image representations
|
In this case, we want to use the `AutoEncoder` to extract image representations
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
Loading…
Reference in New Issue