Fixed example implementation of AutoEncoder. (#3190)

The previous implementation trained an autoencoder but evaluated a classifier.
I try to fix this by replacing the evaluation metric with an autoencoder
(reconstruction) metric; hence, no classification is done.
I'm not 100% sure what the original author's intent was, since the example
extends a classification model (LitMNIST) but does not use it.
The following model is an AutoEncoder and does not do any classification.

 1. Small textual changes.
 2. forward() now implements encoding rather than decoding, matching what the
 text describes (see the sketch after this list).
 3. _shared_eval uses an MSE loss instead of a classification loss, since no
 classification weights are learned.
 4. MSE is now initialized in __init__, since calling MSE directly is not
 supported.
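
For reference, a minimal sketch of the module as it stands after these changes, assembled from the new side of the diff below. It assumes the Encoder, Decoder and MSE helpers and the LitMNIST module defined earlier on the docs page, plus `import pytorch_lightning as pl`; the 'val' prefix in validation_step is an assumption, since only the 'test' variant appears in the hunks.

    class AutoEncoder(LitMNIST):

        def __init__(self):
            super().__init__()
            self.encoder = Encoder()
            self.decoder = Decoder()
            self.metric = MSE()  # instantiated once, instead of calling MSE directly

        def forward(self, x):
            # forward() now encodes, since encoding is the practical use of this module
            return self.encoder(x)

        def training_step(self, batch, batch_idx):
            x, _ = batch                          # labels are ignored
            representation = self(x)              # encode
            x_hat = self.decoder(representation)  # reconstruct
            loss = self.metric(x, x_hat)          # reconstruction (MSE) loss
            return loss

        def validation_step(self, batch, batch_idx):
            return self._shared_eval(batch, batch_idx, 'val')  # prefix assumed

        def test_step(self, batch, batch_idx):
            return self._shared_eval(batch, batch_idx, 'test')

        def _shared_eval(self, batch, batch_idx, prefix):
            x, _ = batch
            representation = self(x)
            x_hat = self.decoder(representation)
            loss = self.metric(x, x_hat)
            result = pl.EvalResult()
            result.log(f'{prefix}_loss', loss)
            return result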
Lucas Steinmann 2020-08-26 13:33:04 +02:00 committed by GitHub
parent 17d8773106
commit ae3bf919c6
1 changed file with 16 additions and 13 deletions

@@ -16,14 +16,17 @@
         def val_dataloader():
             pass
 
+        def test_dataloader():
+            pass
+
 Child Modules
 -------------
 Research projects tend to test different approaches to the same dataset.
 This is very easy to do in Lightning with inheritance.
 For example, imagine we now want to train an Autoencoder to use as a feature extractor for MNIST images.
-Recall that `LitMNIST` already defines all the dataloading etc... The only things
-that change in the `Autoencoder` model are the init, forward, training, validation and test step.
+We are extending our Autoencoder from the `LitMNIST`-module which already defines all the dataloading.
+The only things that change in the `Autoencoder` model are the init, forward, training, validation and test step.
 
 .. testcode::
@@ -39,18 +42,18 @@ that change in the `Autoencoder` model are the init, forward, training, validation and test step.
             super().__init__()
             self.encoder = Encoder()
             self.decoder = Decoder()
+            self.metric = MSE()
 
         def forward(self, x):
-            generated = self.decoder(x)
-            return generated
+            return self.encoder(x)
 
         def training_step(self, batch, batch_idx):
             x, _ = batch
-            representation = self.encoder(x)
-            x_hat = self(representation)
-            loss = MSE(x, x_hat)
+            representation = self(x)
+            x_hat = self.decoder(representation)
+            loss = self.metric(x, x_hat)
             return loss
 
         def validation_step(self, batch, batch_idx):
@@ -60,11 +63,11 @@ that change in the `Autoencoder` model are the init, forward, training, validation and test step.
             return self._shared_eval(batch, batch_idx, 'test')
 
         def _shared_eval(self, batch, batch_idx, prefix):
-            x, y = batch
-            representation = self.encoder(x)
-            x_hat = self(representation)
-            loss = F.nll_loss(logits, y)
+            x, _ = batch
+            representation = self(x)
+            x_hat = self.decoder(representation)
+            loss = self.metric(x, x_hat)
             result = pl.EvalResult()
             result.log(f'{prefix}_loss', loss)
             return result
@@ -78,7 +81,7 @@ and we can train this using the same trainer
     trainer = Trainer()
     trainer.fit(autoencoder)
 
-And remember that the forward method is to define the practical use of a LightningModule.
+And remember that the forward method should define the practical use of a LightningModule.
 In this case, we want to use the `AutoEncoder` to extract image representations
 
 .. code-block:: python
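
The usage block introduced by this directive is not part of the hunk; a plausible sketch of it, assuming a trained `autoencoder` instance and an imported `torch`, is:

    some_images = torch.rand(32, 1, 28, 28)     # hypothetical batch of MNIST-sized images
    representations = autoencoder(some_images)  # forward() returns the encoder output (the features)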