From cafdc6d3089407abb764e4fd5fcf119af7f055c0 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Fri, 28 Apr 2023 10:54:09 +0100 Subject: [PATCH] Examples: expose learning rate (#17513) Co-authored-by: thomas --- examples/pytorch/basics/autoencoder.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/pytorch/basics/autoencoder.py b/examples/pytorch/basics/autoencoder.py index f3f37cddad..1d32c69b98 100644 --- a/examples/pytorch/basics/autoencoder.py +++ b/examples/pytorch/basics/autoencoder.py @@ -114,8 +114,9 @@ class LitAutoEncoder(LightningModule): ) """ - def __init__(self, hidden_dim: int = 64): + def __init__(self, hidden_dim: int = 64, learning_rate=10e-3): super().__init__() + self.save_hyperparameters() self.encoder = nn.Sequential(nn.Linear(28 * 28, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 3)) self.decoder = nn.Sequential(nn.Linear(3, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 28 * 28)) @@ -138,7 +139,7 @@ class LitAutoEncoder(LightningModule): return self(x) def configure_optimizers(self): - optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) + optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate) return optimizer def _prepare_batch(self, batch):