lightning/pl_examples/basic_examples/autoencoder.py

122 lines
4.0 KiB
Python
Raw Normal View History

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MNIST autoencoder example.
To run: python autoencoder.py --trainer.max_epochs=50
"""
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, random_split
import pytorch_lightning as pl
from pl_examples import _DATASETS_PATH, cli_lightning_logo
from pl_examples.basic_examples.mnist_datamodule import MNIST
from pytorch_lightning.utilities.cli import LightningCLI
2021-04-13 16:33:32 +00:00
from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
2021-03-11 11:19:48 +00:00
if _TORCHVISION_AVAILABLE:
from torchvision import transforms
class LitAutoEncoder(pl.LightningModule):
    """Fully-connected autoencoder over flattened 28x28 inputs.

    >>> LitAutoEncoder()  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
    LitAutoEncoder(
      (encoder): ...
      (decoder): ...
    )

    Args:
        hidden_dim: Width of the hidden layer used by both the encoder and the decoder.
        latent_dim: Size of the bottleneck embedding produced by the encoder.
    """

    def __init__(self, hidden_dim: int = 64, latent_dim: int = 3):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, latent_dim)
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 28 * 28)
        )

    def forward(self, x):
        """Return the encoder embedding of ``x`` (no decoding is performed)."""
        # in lightning, forward defines the prediction/inference actions
        return self.encoder(x)

    def _reconstruct(self, batch):
        """Flatten the inputs of ``batch`` and return ``(flat_inputs, reconstruction)``."""
        # The second element of the batch is unpacked but never used here.
        x, _ = batch
        x = x.view(x.size(0), -1)
        return x, self.decoder(self.encoder(x))

    def _reconstruction_loss(self, batch):
        """Mean-squared error between the flattened inputs and their reconstruction."""
        x, x_hat = self._reconstruct(batch)
        return F.mse_loss(x_hat, x)

    def training_step(self, batch, batch_idx):
        """Return the reconstruction loss for one training batch."""
        return self._reconstruction_loss(batch)

    def validation_step(self, batch, batch_idx):
        """Log the reconstruction loss of one validation batch as ``valid_loss``."""
        self.log("valid_loss", self._reconstruction_loss(batch), on_step=True)

    def test_step(self, batch, batch_idx):
        """Log the reconstruction loss of one test batch as ``test_loss``."""
        self.log("test_loss", self._reconstruction_loss(batch), on_step=True)

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        """Return the decoder output (flattened reconstruction) for one batch."""
        _, x_hat = self._reconstruct(batch)
        return x_hat

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with a learning rate of 1e-3."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)
class MyDataModule(pl.LightningDataModule):
    """Data module serving a 55000/5000 train/validation split and a test set.

    Args:
        batch_size: Batch size used by every dataloader this module returns.
    """

    def __init__(self, batch_size: int = 32):
        super().__init__()
        self.batch_size = batch_size

        # Both datasets are built (and downloaded if needed) at construction time.
        full_train = MNIST(_DATASETS_PATH, train=True, download=True, transform=transforms.ToTensor())
        self.mnist_test = MNIST(_DATASETS_PATH, train=False, download=True, transform=transforms.ToTensor())

        # Carve the validation subset out of the training dataset.
        train_part, val_part = random_split(full_train, [55000, 5000])
        self.mnist_train = train_part
        self.mnist_val = val_part

    def _loader(self, dataset):
        """Wrap ``dataset`` in a ``DataLoader`` using the configured batch size."""
        return DataLoader(dataset, batch_size=self.batch_size)

    def train_dataloader(self):
        """Dataloader over the training subset."""
        return self._loader(self.mnist_train)

    def val_dataloader(self):
        """Dataloader over the validation subset."""
        return self._loader(self.mnist_val)

    def test_dataloader(self):
        """Dataloader over the test dataset."""
        return self._loader(self.mnist_test)

    def predict_dataloader(self):
        """Dataloader used for prediction; it serves the same test dataset."""
        return self._loader(self.mnist_test)
def cli_main():
    """Fit the autoencoder, then test and predict from the best checkpoint.

    The model, data module and trainer are built by ``LightningCLI`` from the
    command line. The CLI is created with ``run=False`` and the fit, test and
    predict stages are invoked explicitly below.
    """
    cli = LightningCLI(
        LitAutoEncoder, MyDataModule, seed_everything_default=1234, save_config_overwrite=True, run=False
    )
    cli.trainer.fit(cli.model, datamodule=cli.datamodule)
    # Pass the data module explicitly so test/predict do not depend on whatever
    # state the trainer happens to retain from the preceding fit() call.
    cli.trainer.test(ckpt_path="best", datamodule=cli.datamodule)
    predictions = cli.trainer.predict(ckpt_path="best", datamodule=cli.datamodule)
    print(predictions[0])
if __name__ == "__main__":
    # Script entry point: call the logo helper first (presumably prints a
    # banner -- confirm in pl_examples), then run the CLI-driven workflow.
    cli_lightning_logo()
    cli_main()