[Refactor] Improve auto-encoder example (#9402)

This commit is contained in:
thomas chaton 2021-09-10 17:08:11 +01:00 committed by GitHub
parent 81687aade1
commit ffd275f5b1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 96 additions and 26 deletions

View File

@ -15,6 +15,7 @@
To run: python autoencoder.py --trainer.max_epochs=50
"""
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
@ -24,11 +25,82 @@ from torch.utils.data import DataLoader, random_split
import pytorch_lightning as pl
from pl_examples import _DATASETS_PATH, cli_lightning_logo
from pl_examples.basic_examples.mnist_datamodule import MNIST
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.cli import LightningCLI
from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
if _TORCHVISION_AVAILABLE:
import torchvision
from torchvision import transforms
from torchvision.utils import save_image
class ImageSampler(pl.callbacks.Callback):
def __init__(
self,
num_samples: int = 3,
nrow: int = 8,
padding: int = 2,
normalize: bool = True,
norm_range: Optional[Tuple[int, int]] = None,
scale_each: bool = False,
pad_value: int = 0,
) -> None:
"""
Args:
num_samples: Number of images displayed in the grid. Default: ``3``.
nrow: Number of images displayed in each row of the grid.
The final grid size is ``(B / nrow, nrow)``. Default: ``8``.
padding: Amount of padding. Default: ``2``.
normalize: If ``True``, shift the image to the range (0, 1),
by the min and max values specified by :attr:`range`. Default: ``False``.
norm_range: Tuple (min, max) where min and max are numbers,
then these numbers are used to normalize the image. By default, min and max
are computed from the tensor.
scale_each: If ``True``, scale each image in the batch of
images separately rather than the (min, max) over all images. Default: ``False``.
pad_value: Value for the padded pixels. Default: ``0``.
"""
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError("You want to use `torchvision` which is not installed yet.")
super().__init__()
self.num_samples = num_samples
self.nrow = nrow
self.padding = padding
self.normalize = normalize
self.norm_range = norm_range
self.scale_each = scale_each
self.pad_value = pad_value
def _to_grid(self, images):
return torchvision.utils.make_grid(
tensor=images,
nrow=self.nrow,
padding=self.padding,
normalize=self.normalize,
range=self.norm_range,
scale_each=self.scale_each,
pad_value=self.pad_value,
)
@rank_zero_only
def on_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
if not _TORCHVISION_AVAILABLE:
return
images, _ = next(iter(DataLoader(trainer.datamodule.mnist_val, batch_size=self.num_samples)))
images_flattened = images.view(images.size(0), -1)
# generate images
with torch.no_grad():
pl_module.eval()
images_generated = pl_module(images_flattened.to(pl_module.device))
pl_module.train()
if trainer.current_epoch == 0:
save_image(self._to_grid(images), f"grid_ori_{trainer.current_epoch}.png")
save_image(self._to_grid(images_generated.reshape(images.shape)), f"grid_generated_{trainer.current_epoch}.png")
class LitAutoEncoder(pl.LightningModule):
@ -46,44 +118,37 @@ class LitAutoEncoder(pl.LightningModule):
self.decoder = nn.Sequential(nn.Linear(3, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 28 * 28))
def forward(self, x):
# in lightning, forward defines the prediction/inference actions
embedding = self.encoder(x)
return embedding
z = self.encoder(x)
x_hat = self.decoder(z)
return x_hat
def training_step(self, batch, batch_idx):
x, y = batch
x = x.view(x.size(0), -1)
z = self.encoder(x)
x_hat = self.decoder(z)
loss = F.mse_loss(x_hat, x)
return loss
return self._common_step(batch, batch_idx, "train")
def validation_step(self, batch, batch_idx):
x, y = batch
x = x.view(x.size(0), -1)
z = self.encoder(x)
x_hat = self.decoder(z)
loss = F.mse_loss(x_hat, x)
self.log("valid_loss", loss, on_step=True)
self._common_step(batch, batch_idx, "val")
def test_step(self, batch, batch_idx):
x, y = batch
x = x.view(x.size(0), -1)
z = self.encoder(x)
x_hat = self.decoder(z)
loss = F.mse_loss(x_hat, x)
self.log("test_loss", loss, on_step=True)
self._common_step(batch, batch_idx, "test")
def predict_step(self, batch, batch_idx, dataloader_idx=None):
x, y = batch
x = x.view(x.size(0), -1)
z = self.encoder(x)
return self.decoder(z)
x = self._prepare_batch(batch)
return self(x)
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
return optimizer
def _prepare_batch(self, batch):
x, _ = batch
return x.view(x.size(0), -1)
def _common_step(self, batch, batch_idx, stage: str):
x = self._prepare_batch(batch)
loss = F.mse_loss(x, self(x))
self.log(f"{stage}_loss", loss, on_step=True)
return loss
class MyDataModule(pl.LightningDataModule):
def __init__(self, batch_size: int = 32):
@ -108,7 +173,12 @@ class MyDataModule(pl.LightningDataModule):
def cli_main():
cli = LightningCLI(
LitAutoEncoder, MyDataModule, seed_everything_default=1234, save_config_overwrite=True, run=False
LitAutoEncoder,
MyDataModule,
seed_everything_default=1234,
save_config_overwrite=True,
run=False, # used to de-activate automatic fitting.
trainer_defaults={"callbacks": ImageSampler(), "max_epochs": 10},
)
cli.trainer.fit(cli.model, datamodule=cli.datamodule)
cli.trainer.test(ckpt_path="best")