diff --git a/examples/pytorch/basics/autoencoder.py b/examples/pytorch/basics/autoencoder.py index f3f37cddad..1d32c69b98 100644 --- a/examples/pytorch/basics/autoencoder.py +++ b/examples/pytorch/basics/autoencoder.py @@ -114,8 +114,9 @@ class LitAutoEncoder(LightningModule): ) """ - def __init__(self, hidden_dim: int = 64): + def __init__(self, hidden_dim: int = 64, learning_rate=10e-3): super().__init__() + self.save_hyperparameters() self.encoder = nn.Sequential(nn.Linear(28 * 28, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 3)) self.decoder = nn.Sequential(nn.Linear(3, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 28 * 28)) @@ -138,7 +139,7 @@ class LitAutoEncoder(LightningModule): return self(x) def configure_optimizers(self): - optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) + optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate) return optimizer def _prepare_batch(self, batch):