Examples: expose learning rate (#17513)

Co-authored-by: thomas <thomas@thomass-MacBook-Pro.local>
This commit is contained in:
thomas chaton 2023-04-28 10:54:09 +01:00 committed by GitHub
parent 3867045de4
commit cafdc6d308
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 2 deletions

View File

@ -114,8 +114,9 @@ class LitAutoEncoder(LightningModule):
)
"""
def __init__(self, hidden_dim: int = 64):
def __init__(self, hidden_dim: int = 64, learning_rate=10e-3):
super().__init__()
self.save_hyperparameters()
self.encoder = nn.Sequential(nn.Linear(28 * 28, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 3))
self.decoder = nn.Sequential(nn.Linear(3, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 28 * 28))
@ -138,7 +139,7 @@ class LitAutoEncoder(LightningModule):
return self(x)
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
return optimizer
def _prepare_batch(self, batch):