{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "i7XbLCXGkll9" }, "source": [ "# Introduction to PyTorch Lightning ⚡\n", "\n", "In this notebook, we'll go over the basics of Lightning by preparing models to train on the [MNIST Handwritten Digits dataset](https://en.wikipedia.org/wiki/MNIST_database).\n", "\n", "---\n", " - Give us a ⭐ [on GitHub](https://www.github.com/PytorchLightning/pytorch-lightning/)\n", " - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n", " - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "2LODD6w9ixlT" }, "source": [ "### Setup\n", "Lightning is easy to install. Simply run `pip install pytorch-lightning`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "zK7-Gg69kMnG" }, "outputs": [], "source": [ "! 
pip install pytorch-lightning --quiet" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": {}, "colab_type": "code", "id": "w4_TYnt_keJi" }, "outputs": [], "source": [ "import os\n", "\n", "import torch\n", "from torch import nn\n", "from torch.nn import functional as F\n", "from torch.utils.data import DataLoader, random_split\n", "from torchvision.datasets import MNIST\n", "from torchvision import transforms\n", "import pytorch_lightning as pl\n", "from pytorch_lightning.metrics.functional import accuracy" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "EHpyMPKFkVbZ" }, "source": [ "## Simplest example\n", "\n", "Here's the simplest possible example, with just a training loop (no validation, no testing).\n", "\n", "**Keep in mind:** a `LightningModule` *is* a PyTorch `nn.Module`; it just has a few more helpful features." ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": {}, "colab_type": "code", "id": "V7ELesz1kVQo" }, "outputs": [], "source": [ "class MNISTModel(pl.LightningModule):\n", "\n", "    def __init__(self):\n", "        super(MNISTModel, self).__init__()\n", "        self.l1 = torch.nn.Linear(28 * 28, 10)\n", "\n", "    def forward(self, x):\n", "        return torch.relu(self.l1(x.view(x.size(0), -1)))\n", "\n", "    def training_step(self, batch, batch_nb):\n", "        x, y = batch\n", "        loss = F.cross_entropy(self(x), y)\n", "        return loss\n", "\n", "    def configure_optimizers(self):\n", "        return torch.optim.Adam(self.parameters(), lr=0.02)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "hIrtHg-Dv8TJ" }, "source": [ "By using the `Trainer`, you automatically get:\n", "1. TensorBoard logging\n", "2. Model checkpointing\n", "3. Training and validation loop\n", "4. 
Early stopping" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "4Dk6Ykv8lI7X" }, "outputs": [], "source": [ "# Init our model\n", "mnist_model = MNISTModel()\n", "\n", "# Init DataLoader from MNIST Dataset\n", "train_ds = MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())\n", "train_loader = DataLoader(train_ds, batch_size=32)\n", "\n", "# Initialize a trainer\n", "trainer = pl.Trainer(gpus=1, max_epochs=3, progress_bar_refresh_rate=20)\n", "\n", "# Train the model ⚡\n", "trainer.fit(mnist_model, train_loader)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "KNpOoBeIjscS" }, "source": [ "## A more complete MNIST Lightning Module Example\n", "\n", "That wasn't so hard, was it?\n", "\n", "Now that we've got our feet wet, let's dive in a bit deeper and write a more complete `LightningModule` for MNIST.\n", "\n", "This time, we'll bake all the dataset-specific pieces directly into the `LightningModule`. This way, we can avoid writing extra code at the beginning of our script every time we want to run it.\n", "\n", "---\n", "\n", "### Note what the following built-in functions are doing:\n", "\n", "1. [prepare_data()](https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.core.lightning.html#pytorch_lightning.core.lightning.LightningModule.prepare_data) 💾\n", "    - This is where we download the dataset. We point to our desired data directory and ask torchvision's `MNIST` dataset class to download the dataset if it isn't found there.\n", "    - **Note: we do not make any state assignments in this function** (i.e., `self.something = ...`). Lightning calls `prepare_data()` from a single process, so any state assigned here would not be available to the other processes.\n", "\n", "2. [setup(stage)](https://pytorch-lightning.readthedocs.io/en/latest/common/lightning-module.html#setup) ⚙️\n", "    - Loads in data from file and prepares PyTorch datasets for each split (train, val, test). 
\n", "    - `setup()` expects a `stage` argument, which is used to separate the logic for `'fit'` and `'test'`.\n", "    - If you don't mind loading all your datasets at once, you can write the conditions so that both the `'fit'` setup and the `'test'` setup run whenever `stage` is `None` (or drop the conditionals altogether).\n", "    - **Note: this runs on every process (e.g., once per GPU), so it *is* safe to make state assignments here.**\n", "\n", "3. [x_dataloader()](https://pytorch-lightning.readthedocs.io/en/latest/common/lightning-module.html#data-hooks) ♻️\n", "    - `train_dataloader()`, `val_dataloader()`, and `test_dataloader()` all return PyTorch `DataLoader` instances that wrap the respective datasets we prepared in `setup()`." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": {}, "colab_type": "code", "id": "4DNItffri95Q" }, "outputs": [], "source": [ "class LitMNIST(pl.LightningModule):\n", "\n", "    def __init__(self, data_dir='./', hidden_size=64, learning_rate=2e-4):\n", "\n", "        super().__init__()\n", "\n", "        # Set our init args as class attributes\n", "        self.data_dir = data_dir\n", "        self.hidden_size = hidden_size\n", "        self.learning_rate = learning_rate\n", "\n", "        # Hardcode some dataset-specific attributes\n", "        self.num_classes = 10\n", "        self.dims = (1, 28, 28)\n", "        channels, height, width = self.dims\n", "        self.transform = transforms.Compose([\n", "            transforms.ToTensor(),\n", "            transforms.Normalize((0.1307,), (0.3081,))\n", "        ])\n", "\n", "        # Define PyTorch model\n", "        self.model = nn.Sequential(\n", "            nn.Flatten(),\n", "            nn.Linear(channels * height * width, hidden_size),\n", "            nn.ReLU(),\n", "            nn.Dropout(0.1),\n", "            nn.Linear(hidden_size, hidden_size),\n", "            nn.ReLU(),\n", "            nn.Dropout(0.1),\n", "            nn.Linear(hidden_size, self.num_classes)\n", "        )\n", "\n", "    def forward(self, x):\n", "        x = self.model(x)\n", "        return F.log_softmax(x, dim=1)\n", "\n", "    def training_step(self, batch, batch_idx):\n", "        x, y = 
batch\n", "        logits = self(x)\n", "        loss = F.nll_loss(logits, y)\n", "        return loss\n", "\n", "    def validation_step(self, batch, batch_idx):\n", "        x, y = batch\n", "        logits = self(x)\n", "        loss = F.nll_loss(logits, y)\n", "        preds = torch.argmax(logits, dim=1)\n", "        acc = accuracy(preds, y)\n", "\n", "        # Calling self.log will surface up scalars for you in TensorBoard\n", "        self.log('val_loss', loss, prog_bar=True)\n", "        self.log('val_acc', acc, prog_bar=True)\n", "        return loss\n", "\n", "    def test_step(self, batch, batch_idx):\n", "        # Here we just reuse the validation_step for testing\n", "        return self.validation_step(batch, batch_idx)\n", "\n", "    def configure_optimizers(self):\n", "        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n", "        return optimizer\n", "\n", "    ####################\n", "    # DATA RELATED HOOKS\n", "    ####################\n", "\n", "    def prepare_data(self):\n", "        # download\n", "        MNIST(self.data_dir, train=True, download=True)\n", "        MNIST(self.data_dir, train=False, download=True)\n", "\n", "    def setup(self, stage=None):\n", "\n", "        # Assign train/val datasets for use in dataloaders\n", "        if stage == 'fit' or stage is None:\n", "            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n", "            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n", "\n", "        # Assign test dataset for use in dataloader(s)\n", "        if stage == 'test' or stage is None:\n", "            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n", "\n", "    def train_dataloader(self):\n", "        return DataLoader(self.mnist_train, batch_size=32)\n", "\n", "    def val_dataloader(self):\n", "        return DataLoader(self.mnist_val, batch_size=32)\n", "\n", "    def test_dataloader(self):\n", "        return DataLoader(self.mnist_test, batch_size=32)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "Mb0U5Rk2kLBy" }, "outputs": [], "source": [ "model = LitMNIST()\n", "trainer = 
pl.Trainer(gpus=1, max_epochs=3, progress_bar_refresh_rate=20)\n", "trainer.fit(model)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "nht8AvMptY6I" }, "source": [ "### Testing\n", "\n", "To test a model, call `trainer.test(model)`.\n", "\n", "Or, if you've just trained a model, you can simply call `trainer.test()`, and Lightning will automatically test using the best saved checkpoint (conditioned on `val_loss`)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "PA151FkLtprO" }, "outputs": [], "source": [ "trainer.test()" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "T3-3lbbNtr5T" }, "source": [ "### Bonus Tip\n", "\n", "You can keep calling `trainer.fit(model)` as many times as you'd like to continue training." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "IFBwCbLet2r6" }, "outputs": [], "source": [ "trainer.fit(model)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "8TRyS5CCt3n9" }, "source": [ "In Colab, you can use the TensorBoard magic function to view the logs that Lightning has created for you!" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "wizS-QiLuAYo" }, "outputs": [], "source": [ "# Start TensorBoard.\n", "%load_ext tensorboard\n", "%tensorboard --logdir lightning_logs/" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "

