From ffd275f5b192c98acc25a6203a101fb0bf0abf04 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Fri, 10 Sep 2021 17:08:11 +0100 Subject: [PATCH] [Refactor] Improve auto-encoder example (#9402) --- pl_examples/basic_examples/autoencoder.py | 122 +++++++++++++++++----- 1 file changed, 96 insertions(+), 26 deletions(-) diff --git a/pl_examples/basic_examples/autoencoder.py b/pl_examples/basic_examples/autoencoder.py index a78e728af1..e4890c0867 100644 --- a/pl_examples/basic_examples/autoencoder.py +++ b/pl_examples/basic_examples/autoencoder.py @@ -15,6 +15,7 @@ To run: python autoencoder.py --trainer.max_epochs=50 """ +from typing import Optional, Tuple import torch import torch.nn.functional as F @@ -24,11 +25,82 @@ from torch.utils.data import DataLoader, random_split import pytorch_lightning as pl from pl_examples import _DATASETS_PATH, cli_lightning_logo from pl_examples.basic_examples.mnist_datamodule import MNIST +from pytorch_lightning.utilities import rank_zero_only from pytorch_lightning.utilities.cli import LightningCLI from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE if _TORCHVISION_AVAILABLE: + import torchvision from torchvision import transforms + from torchvision.utils import save_image + + +class ImageSampler(pl.callbacks.Callback): + def __init__( + self, + num_samples: int = 3, + nrow: int = 8, + padding: int = 2, + normalize: bool = True, + norm_range: Optional[Tuple[int, int]] = None, + scale_each: bool = False, + pad_value: int = 0, + ) -> None: + """ + Args: + num_samples: Number of images displayed in the grid. Default: ``3``. + nrow: Number of images displayed in each row of the grid. + The final grid size is ``(B / nrow, nrow)``. Default: ``8``. + padding: Amount of padding. Default: ``2``. + normalize: If ``True``, shift the image to the range (0, 1), + by the min and max values specified by :attr:`range`. Default: ``False``. + norm_range: Tuple (min, max) where min and max are numbers, + then these numbers are used to normalize the image. By default, min and max + are computed from the tensor. + scale_each: If ``True``, scale each image in the batch of + images separately rather than the (min, max) over all images. Default: ``False``. + pad_value: Value for the padded pixels. Default: ``0``. + """ + if not _TORCHVISION_AVAILABLE: # pragma: no cover + raise ModuleNotFoundError("You want to use `torchvision` which is not installed yet.") + + super().__init__() + self.num_samples = num_samples + self.nrow = nrow + self.padding = padding + self.normalize = normalize + self.norm_range = norm_range + self.scale_each = scale_each + self.pad_value = pad_value + + def _to_grid(self, images): + return torchvision.utils.make_grid( + tensor=images, + nrow=self.nrow, + padding=self.padding, + normalize=self.normalize, + range=self.norm_range, + scale_each=self.scale_each, + pad_value=self.pad_value, + ) + + @rank_zero_only + def on_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + if not _TORCHVISION_AVAILABLE: + return + + images, _ = next(iter(DataLoader(trainer.datamodule.mnist_val, batch_size=self.num_samples))) + images_flattened = images.view(images.size(0), -1) + + # generate images + with torch.no_grad(): + pl_module.eval() + images_generated = pl_module(images_flattened.to(pl_module.device)) + pl_module.train() + + if trainer.current_epoch == 0: + save_image(self._to_grid(images), f"grid_ori_{trainer.current_epoch}.png") + save_image(self._to_grid(images_generated.reshape(images.shape)), f"grid_generated_{trainer.current_epoch}.png") class LitAutoEncoder(pl.LightningModule): @@ -46,44 +118,37 @@ class LitAutoEncoder(pl.LightningModule): self.decoder = nn.Sequential(nn.Linear(3, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 28 * 28)) def forward(self, x): - # in lightning, forward defines the prediction/inference actions - embedding = self.encoder(x) - return embedding + z = self.encoder(x) + x_hat = self.decoder(z) + return x_hat def training_step(self, batch, batch_idx): - x, y = batch - x = x.view(x.size(0), -1) - z = self.encoder(x) - x_hat = self.decoder(z) - loss = F.mse_loss(x_hat, x) - return loss + return self._common_step(batch, batch_idx, "train") def validation_step(self, batch, batch_idx): - x, y = batch - x = x.view(x.size(0), -1) - z = self.encoder(x) - x_hat = self.decoder(z) - loss = F.mse_loss(x_hat, x) - self.log("valid_loss", loss, on_step=True) + self._common_step(batch, batch_idx, "val") def test_step(self, batch, batch_idx): - x, y = batch - x = x.view(x.size(0), -1) - z = self.encoder(x) - x_hat = self.decoder(z) - loss = F.mse_loss(x_hat, x) - self.log("test_loss", loss, on_step=True) + self._common_step(batch, batch_idx, "test") def predict_step(self, batch, batch_idx, dataloader_idx=None): - x, y = batch - x = x.view(x.size(0), -1) - z = self.encoder(x) - return self.decoder(z) + x = self._prepare_batch(batch) + return self(x) def configure_optimizers(self): optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) return optimizer + def _prepare_batch(self, batch): + x, _ = batch + return x.view(x.size(0), -1) + + def _common_step(self, batch, batch_idx, stage: str): + x = self._prepare_batch(batch) + loss = F.mse_loss(x, self(x)) + self.log(f"{stage}_loss", loss, on_step=True) + return loss + class MyDataModule(pl.LightningDataModule): def __init__(self, batch_size: int = 32): @@ -108,7 +173,12 @@ class MyDataModule(pl.LightningDataModule): def cli_main(): cli = LightningCLI( - LitAutoEncoder, MyDataModule, seed_everything_default=1234, save_config_overwrite=True, run=False + LitAutoEncoder, + MyDataModule, + seed_everything_default=1234, + save_config_overwrite=True, + run=False, # used to de-activate automatic fitting. + trainer_defaults={"callbacks": ImageSampler(), "max_epochs": 10}, ) cli.trainer.fit(cli.model, datamodule=cli.datamodule) cli.trainer.test(ckpt_path="best")