{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "02-datamodules.ipynb", "provenance": [], "collapsed_sections": [], "toc_visible": true, "include_colab_link": true }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "2O5r7QvP8-rt", "colab_type": "text" }, "source": [ "# PyTorch Lightning DataModules ⚡\n", "\n", "With the release of `pytorch-lightning` version 0.9.0, we have included a new class called `LightningDataModule` to help you decouple data-related hooks from your `LightningModule`.\n", "\n", "This notebook walks you through how to start using DataModules.\n", "\n", "The most up-to-date documentation on DataModules can be found [here](https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html).\n", "\n", "---\n", "\n", " - Give us a ⭐ [on GitHub](https://www.github.com/PytorchLightning/pytorch-lightning/)\n", " - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n", " - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)" ] }, { "cell_type": "markdown", "metadata": { "id": "6RYMhmfA9ATN", "colab_type": "text" }, "source": [ "### Setup\n", "Lightning is easy to install. Simply run `pip install pytorch-lightning`." ] }, { "cell_type": "code", "metadata": { "id": "lj2zD-wsbvGr", "colab_type": "code", "colab": {} }, "source": [ "! 
pip install pytorch-lightning --quiet" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "8g2mbvy-9xDI", "colab_type": "text" }, "source": [ "# Introduction\n", "\n", "First, we'll go over a regular `LightningModule` implementation that does not use a `LightningDataModule`." ] }, { "cell_type": "code", "metadata": { "id": "eg-xDlmDdAwy", "colab_type": "code", "colab": {} }, "source": [ "import pytorch_lightning as pl\n", "from pytorch_lightning.metrics.functional import accuracy\n", "import torch\n", "from torch import nn\n", "import torch.nn.functional as F\n", "from torch.utils.data import random_split, DataLoader\n", "\n", "# Note: torchvision must be installed for this example\n", "from torchvision.datasets import MNIST, CIFAR10\n", "from torchvision import transforms" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "DzgY7wi88UuG", "colab_type": "text" }, "source": [ "## Defining the `LitMNIST` Model\n", "\n", "Below, we reuse a `LightningModule` from our hello world tutorial that classifies MNIST handwritten digits.\n", "\n", "Unfortunately, we have hardcoded dataset-specific items (the data directory, input dimensions, number of classes, and transforms) within the model, which limits it to working with MNIST data. 😢\n", "\n", "This is fine if you don't plan on training or evaluating your model on other datasets. However, it quickly becomes a nuisance when you want to try out your architecture on different datasets." 
] }, { "cell_type": "code", "metadata": { "id": "IQkW8_FF5nU2", "colab_type": "code", "colab": {} }, "source": [ "class LitMNIST(pl.LightningModule):\n", "\n", "    def __init__(self, data_dir='./', hidden_size=64, learning_rate=2e-4):\n", "\n", "        super().__init__()\n", "\n", "        # We hardcode dataset-specific stuff here.\n", "        self.data_dir = data_dir\n", "        self.num_classes = 10\n", "        self.dims = (1, 28, 28)\n", "        channels, height, width = self.dims\n", "        self.transform = transforms.Compose([\n", "            transforms.ToTensor(),\n", "            transforms.Normalize((0.1307,), (0.3081,))\n", "        ])\n", "\n", "        self.hidden_size = hidden_size\n", "        self.learning_rate = learning_rate\n", "\n", "        # Build model\n", "        self.model = nn.Sequential(\n", "            nn.Flatten(),\n", "            nn.Linear(channels * width * height, hidden_size),\n", "            nn.ReLU(),\n", "            nn.Dropout(0.1),\n", "            nn.Linear(hidden_size, hidden_size),\n", "            nn.ReLU(),\n", "            nn.Dropout(0.1),\n", "            nn.Linear(hidden_size, self.num_classes)\n", "        )\n", "\n", "    def forward(self, x):\n", "        x = self.model(x)\n", "        return F.log_softmax(x, dim=1)\n", "\n", "    def training_step(self, batch, batch_idx):\n", "        x, y = batch\n", "        logits = self(x)\n", "        loss = F.nll_loss(logits, y)\n", "        return loss\n", "\n", "    def validation_step(self, batch, batch_idx):\n", "        x, y = batch\n", "        logits = self(x)\n", "        loss = F.nll_loss(logits, y)\n", "        preds = torch.argmax(logits, dim=1)\n", "        acc = accuracy(preds, y)\n", "        self.log('val_loss', loss, prog_bar=True)\n", "        self.log('val_acc', acc, prog_bar=True)\n", "        return loss\n", "\n", "    def configure_optimizers(self):\n", "        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n", "        return optimizer\n", "\n", "    ####################\n", "    # DATA RELATED HOOKS\n", "    ####################\n", "\n", "    def prepare_data(self):\n", "        # download\n", "        MNIST(self.data_dir, train=True, download=True)\n", "        MNIST(self.data_dir, train=False, download=True)\n", "\n", "    def setup(self, stage=None):\n", "\n", "        # Assign 
train/val datasets for use in dataloaders\n", "        if stage == 'fit' or stage is None:\n", "            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n", "            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n", "\n", "        # Assign test dataset for use in dataloader(s)\n", "        if stage == 'test' or stage is None:\n", "            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n", "\n", "    def train_dataloader(self):\n", "        return DataLoader(self.mnist_train, batch_size=32)\n", "\n", "    def val_dataloader(self):\n", "        return DataLoader(self.mnist_val, batch_size=32)\n", "\n", "    def test_dataloader(self):\n", "        return DataLoader(self.mnist_test, batch_size=32)" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "K7sg9KQd-QIO", "colab_type": "text" }, "source": [ "## Training the `LitMNIST` Model" ] }, { "cell_type": "code", "metadata": { "id": "QxDNDaus6byD", "colab_type": "code", "colab": {} }, "source": [ "model = LitMNIST()\n", "trainer = pl.Trainer(max_epochs=2, gpus=1, progress_bar_refresh_rate=20)\n", "trainer.fit(model)" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "dY8d6GxmB0YU", "colab_type": "text" }, "source": [ "# Using DataModules\n", "\n", "DataModules are a way of decoupling data-related hooks from the `LightningModule` so you can develop dataset-agnostic models." ] }, { "cell_type": "markdown", "metadata": { "id": "eJeT5bW081wn", "colab_type": "text" }, "source": [ "## Defining the MNISTDataModule\n", "\n", "Let's go over each method in the class below and what it does:\n", "\n", "1. 
```__init__```\n", "   - Takes in a `data_dir` arg that points to where you have downloaded (or wish to download) the MNIST dataset.\n", "   - Defines a transform that will be applied across the train, val, and test dataset splits.\n", "   - Defines a default `self.dims`, a tuple returned by `datamodule.size()` that can help you initialize models.\n", "\n", "\n", "2. ```prepare_data```\n", "   - This is where we download the dataset. We point torchvision's `MNIST` dataset class at `data_dir` and ask it to download the data if it isn't already there.\n", "   - **Note that we do not make any state assignments in this function** (i.e. `self.something = ...`), because Lightning calls `prepare_data` from only one process (per node), so state set here would not be available in the other processes.\n", "\n", "3. ```setup```\n", "   - Loads data from disk and prepares PyTorch datasets for each split (train, val, test).\n", "   - `setup` takes a `stage` argument, which is used to separate the logic for `'fit'` and `'test'`.\n", "   - If you don't mind loading all your datasets at once, you can write the conditions so that both the `'fit'` and the `'test'` setup run whenever `stage` is `None`.\n", "   - **Note that this runs on every GPU, and it *is* safe to make state assignments here.**\n", "\n", "\n", "4. 
```x_dataloader```\n", "   - `train_dataloader()`, `val_dataloader()`, and `test_dataloader()` each return a PyTorch `DataLoader` that wraps the corresponding dataset we prepared in `setup()`." ] }, { "cell_type": "code", "metadata": { "id": "DfGKyGwG_X9v", "colab_type": "code", "colab": {} }, "source": [ "class MNISTDataModule(pl.LightningDataModule):\n", "\n", "    def __init__(self, data_dir: str = './'):\n", "        super().__init__()\n", "        self.data_dir = data_dir\n", "        self.transform = transforms.Compose([\n", "            transforms.ToTensor(),\n", "            transforms.Normalize((0.1307,), (0.3081,))\n", "        ])\n", "\n", "        # self.dims is returned when you call dm.size()\n", "        # Setting default dims here because we know them.\n", "        # Could optionally be assigned dynamically in dm.setup()\n", "        self.dims = (1, 28, 28)\n", "        self.num_classes = 10\n", "\n", "    def prepare_data(self):\n", "        # download\n", "        MNIST(self.data_dir, train=True, download=True)\n", "        MNIST(self.data_dir, train=False, download=True)\n", "\n", "    def setup(self, stage=None):\n", "\n", "        # Assign train/val datasets for use in dataloaders\n", "        if stage == 'fit' or stage is None:\n", "            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n", "            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n", "\n", "        # Assign test dataset for use in dataloader(s)\n", "        if stage == 'test' or stage is None:\n", "            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n", "\n", "    def train_dataloader(self):\n", "        return DataLoader(self.mnist_train, batch_size=32)\n", "\n", "    def val_dataloader(self):\n", "        return DataLoader(self.mnist_val, batch_size=32)\n", "\n", "    def test_dataloader(self):\n", "        return DataLoader(self.mnist_test, batch_size=32)" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "H2Yoj-9M9dS7", "colab_type": "text" }, "source": [ "## Defining the dataset-agnostic `LitModel`\n", "\n", 
"Below, we define the same model as the `LitMNIST` model we made earlier.\n", "\n", "However, this time the model takes its input dimensions and number of classes as arguments, so it can be trained on any dataset we like 🔥." ] }, { "cell_type": "code", "metadata": { "id": "PM2IISuOBDIu", "colab_type": "code", "colab": {} }, "source": [ "class LitModel(pl.LightningModule):\n", "\n", "    def __init__(self, channels, height, width, num_classes, hidden_size=64, learning_rate=2e-4):\n", "\n", "        super().__init__()\n", "\n", "        # We take the input dimensions as parameters and use them to build the model dynamically.\n", "        self.channels = channels\n", "        self.height = height\n", "        self.width = width\n", "        self.num_classes = num_classes\n", "        self.hidden_size = hidden_size\n", "        self.learning_rate = learning_rate\n", "\n", "        self.model = nn.Sequential(\n", "            nn.Flatten(),\n", "            nn.Linear(channels * width * height, hidden_size),\n", "            nn.ReLU(),\n", "            nn.Dropout(0.1),\n", "            nn.Linear(hidden_size, hidden_size),\n", "            nn.ReLU(),\n", "            nn.Dropout(0.1),\n", "            nn.Linear(hidden_size, num_classes)\n", "        )\n", "\n", "    def forward(self, x):\n", "        x = self.model(x)\n", "        return F.log_softmax(x, dim=1)\n", "\n", "    def training_step(self, batch, batch_idx):\n", "        x, y = batch\n", "        logits = self(x)\n", "        loss = F.nll_loss(logits, y)\n", "        return loss\n", "\n", "    def validation_step(self, batch, batch_idx):\n", "        x, y = batch\n", "        logits = self(x)\n", "        loss = F.nll_loss(logits, y)\n", "        preds = torch.argmax(logits, dim=1)\n", "        acc = accuracy(preds, y)\n", "        self.log('val_loss', loss, prog_bar=True)\n", "        self.log('val_acc', acc, prog_bar=True)\n", "        return loss\n", "\n", "    def configure_optimizers(self):\n", "        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n", "        return optimizer" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "G4Z5olPe-xEo", "colab_type": "text" }, "source": [ "## Training the `LitModel` using the `MNISTDataModule`\n", "\n", "Now, we initialize and 
train the `LitModel` using the `MNISTDataModule`'s configuration settings and dataloaders." ] }, { "cell_type": "code", "metadata": { "id": "kV48vP_9mEli", "colab_type": "code", "colab": {} }, "source": [ "# Init DataModule\n", "dm = MNISTDataModule()\n", "# Init model from datamodule's attributes\n", "model = LitModel(*dm.size(), dm.num_classes)\n", "# Init trainer\n", "trainer = pl.Trainer(max_epochs=3, progress_bar_refresh_rate=20, gpus=1)\n", "# Pass the datamodule to trainer.fit to override the model's data hooks :)\n", "trainer.fit(model, datamodule=dm)" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "WNxrugIGRRv5", "colab_type": "text" }, "source": [ "## Defining the CIFAR10 DataModule\n", "\n", "Let's prove that the `LitModel` we made earlier is dataset-agnostic by defining a new datamodule for the CIFAR10 dataset." ] }, { "cell_type": "code", "metadata": { "id": "1tkaYLU7RT5P", "colab_type": "code", "colab": {} }, "source": [ "class CIFAR10DataModule(pl.LightningDataModule):\n", "\n", "    def __init__(self, data_dir: str = './'):\n", "        super().__init__()\n", "        self.data_dir = data_dir\n", "        self.transform = transforms.Compose([\n", "            transforms.ToTensor(),\n", "            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n", "        ])\n", "\n", "        self.dims = (3, 32, 32)\n", "        self.num_classes = 10\n", "\n", "    def prepare_data(self):\n", "        # download\n", "        CIFAR10(self.data_dir, train=True, download=True)\n", "        CIFAR10(self.data_dir, train=False, download=True)\n", "\n", "    def setup(self, stage=None):\n", "\n", "        # Assign train/val datasets for use in dataloaders\n", "        if stage == 'fit' or stage is None:\n", "            cifar_full = CIFAR10(self.data_dir, train=True, transform=self.transform)\n", "            self.cifar_train, self.cifar_val = random_split(cifar_full, [45000, 5000])\n", "\n", "        # Assign test dataset for use in dataloader(s)\n", "        if stage == 'test' or stage is None:\n", "            self.cifar_test = CIFAR10(self.data_dir, train=False, 
transform=self.transform)\n", "\n", "    def train_dataloader(self):\n", "        return DataLoader(self.cifar_train, batch_size=32)\n", "\n", "    def val_dataloader(self):\n", "        return DataLoader(self.cifar_val, batch_size=32)\n", "\n", "    def test_dataloader(self):\n", "        return DataLoader(self.cifar_test, batch_size=32)" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "BrXxf3oX_gsZ", "colab_type": "text" }, "source": [ "## Training the `LitModel` using the `CIFAR10DataModule`\n", "\n", "Our model is a small fully connected network, so it will perform poorly on the CIFAR10 dataset.\n", "\n", "The point here is that our `LitModel` works with a different datamodule without any changes to its code." ] }, { "cell_type": "code", "metadata": { "id": "sd-SbWi_krdj", "colab_type": "code", "colab": {} }, "source": [ "dm = CIFAR10DataModule()\n", "model = LitModel(*dm.size(), dm.num_classes, hidden_size=256)\n", "trainer = pl.Trainer(max_epochs=5, progress_bar_refresh_rate=20, gpus=1)\n", "trainer.fit(model, datamodule=dm)" ], "execution_count": null, "outputs": [] } ] }