diff --git a/docs/source/conf.py b/docs/source/conf.py
index 7a44acd730..58258af1fc 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -304,12 +304,12 @@ def setup(app):
# copy all notebooks to local folder
-path_nbs = os.path.join(PATH_HERE, 'notebooks')
-if not os.path.isdir(path_nbs):
- os.mkdir(path_nbs)
-for path_ipynb in glob.glob(os.path.join(PATH_ROOT, 'notebooks', '*.ipynb')):
- path_ipynb2 = os.path.join(path_nbs, os.path.basename(path_ipynb))
- shutil.copy(path_ipynb, path_ipynb2)
+# path_nbs = os.path.join(PATH_HERE, 'notebooks')
+# if not os.path.isdir(path_nbs):
+# os.mkdir(path_nbs)
+# for path_ipynb in glob.glob(os.path.join(PATH_ROOT, 'notebooks', '*.ipynb')):
+# path_ipynb2 = os.path.join(path_nbs, os.path.basename(path_ipynb))
+# shutil.copy(path_ipynb, path_ipynb2)
# Ignoring Third-party packages
diff --git a/notebooks/01-mnist-hello-world.ipynb b/notebooks/01-mnist-hello-world.ipynb
new file mode 100644
index 0000000000..c9e81cc990
--- /dev/null
+++ b/notebooks/01-mnist-hello-world.ipynb
@@ -0,0 +1,401 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "name": "01-mnist-hello-world.ipynb",
+ "provenance": [],
+ "collapsed_sections": [],
+ "authorship_tag": "ABX9TyOtAKVa5POQ6Xg3UcTQqXDJ",
+ "include_colab_link": true
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "accelerator": "GPU"
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "i7XbLCXGkll9",
+ "colab_type": "text"
+ },
+ "source": [
+ "# Introduction to Pytorch Lightning ⚡\n",
+ "\n",
+ "In this notebook, we'll go over the basics of lightning by preparing models to train on the [MNIST Handwritten Digits dataset](https://en.wikipedia.org/wiki/MNIST_database).\n",
+ "\n",
+ "---\n",
+ " - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n",
+ " - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n",
+ " - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "2LODD6w9ixlT",
+ "colab_type": "text"
+ },
+ "source": [
+ "### Setup \n",
+ "Lightning is easy to install. Simply ```pip install pytorch-lightning```"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "zK7-Gg69kMnG",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "! pip install pytorch-lightning --quiet"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "w4_TYnt_keJi",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "import os\n",
+ "\n",
+ "import torch\n",
+ "from torch import nn\n",
+ "from torch.nn import functional as F\n",
+ "from torch.utils.data import DataLoader, random_split\n",
+ "from torchvision.datasets import MNIST\n",
+ "from torchvision import transforms\n",
+ "import pytorch_lightning as pl\n",
+ "from pytorch_lightning.metrics.functional import accuracy"
+ ],
+ "execution_count": 2,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "EHpyMPKFkVbZ",
+ "colab_type": "text"
+ },
+ "source": [
+ "## Simplest example\n",
+ "\n",
+ "Here's the simplest most minimal example with just a training loop (no validation, no testing).\n",
+ "\n",
+ "**Keep in Mind** - A `LightningModule` *is* a PyTorch `nn.Module` - it just has a few more helpful features."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "V7ELesz1kVQo",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "class MNISTModel(pl.LightningModule):\n",
+ "\n",
+ " def __init__(self):\n",
+ " super(MNISTModel, self).__init__()\n",
+ " self.l1 = torch.nn.Linear(28 * 28, 10)\n",
+ "\n",
+ " def forward(self, x):\n",
+ " return torch.relu(self.l1(x.view(x.size(0), -1)))\n",
+ "\n",
+ " def training_step(self, batch, batch_nb):\n",
+ " x, y = batch\n",
+ " loss = F.cross_entropy(self(x), y)\n",
+ " return pl.TrainResult(loss)\n",
+ "\n",
+ " def configure_optimizers(self):\n",
+ " return torch.optim.Adam(self.parameters(), lr=0.02)"
+ ],
+ "execution_count": 3,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "hIrtHg-Dv8TJ",
+ "colab_type": "text"
+ },
+ "source": [
+ "By using the `Trainer` you automatically get:\n",
+ "1. Tensorboard logging\n",
+ "2. Model checkpointing\n",
+ "3. Training and validation loop\n",
+ "4. early-stopping"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "4Dk6Ykv8lI7X",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "# Init our model\n",
+ "mnist_model = MNISTModel()\n",
+ "\n",
+ "# Init DataLoader from MNIST Dataset\n",
+ "train_ds = MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())\n",
+ "train_loader = DataLoader(train_ds, batch_size=32)\n",
+ "\n",
+ "# Initialize a trainer\n",
+ "trainer = pl.Trainer(gpus=1, max_epochs=3, progress_bar_refresh_rate=20)\n",
+ "\n",
+ "# Train the model ⚡\n",
+ "trainer.fit(mnist_model, train_loader)"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "KNpOoBeIjscS",
+ "colab_type": "text"
+ },
+ "source": [
+ "## A more complete MNIST Lightning Module Example\n",
+ "\n",
+ "That wasn't so hard was it?\n",
+ "\n",
+ "Now that we've got our feet wet, let's dive in a bit deeper and write a more complete `LightningModule` for MNIST...\n",
+ "\n",
+ "This time, we'll bake in all the dataset specific pieces directly in the `LightningModule`. This way, we can avoid writing extra code at the beginning of our script every time we want to run it.\n",
+ "\n",
+ "---\n",
+ "\n",
+ "### Note what the following built-in functions are doing:\n",
+ "\n",
+ "1. [prepare_data()](https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.core.lightning.html#pytorch_lightning.core.lightning.LightningModule.prepare_data) 💾\n",
+ " - This is where we can download the dataset. We point to our desired dataset and ask torchvision's `MNIST` dataset class to download if the dataset isn't found there.\n",
+ " - **Note we do not make any state assignments in this function** (i.e. `self.something = ...`)\n",
+ "\n",
+ "2. [setup(stage)](https://pytorch-lightning.readthedocs.io/en/latest/lightning-module.html#setup) ⚙️\n",
+ " - Loads in data from file and prepares PyTorch tensor datasets for each split (train, val, test). \n",
+ " - Setup expects a 'stage' arg which is used to separate logic for 'fit' and 'test'.\n",
+ " - If you don't mind loading all your datasets at once, you can set up a condition to allow for both 'fit' related setup and 'test' related setup to run whenever `None` is passed to `stage` (or ignore it altogether and exclude any conditionals).\n",
+ " - **Note this runs across all GPUs and it *is* safe to make state assignments here**\n",
+ "\n",
+ "3. [x_dataloader()](https://pytorch-lightning.readthedocs.io/en/latest/lightning-module.html#data-hooks) ♻️\n",
+ " - `train_dataloader()`, `val_dataloader()`, and `test_dataloader()` all return PyTorch `DataLoader` instances that are created by wrapping their respective datasets that we prepared in `setup()`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "4DNItffri95Q",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "class LitMNIST(pl.LightningModule):\n",
+ " \n",
+ " def __init__(self, data_dir='./', hidden_size=64, learning_rate=2e-4):\n",
+ "\n",
+ " super().__init__()\n",
+ "\n",
+ " # Set our init args as class attributes\n",
+ " self.data_dir = data_dir\n",
+ " self.hidden_size = hidden_size\n",
+ " self.learning_rate = learning_rate\n",
+ "\n",
+ " # Hardcode some dataset specific attributes\n",
+ " self.num_classes = 10\n",
+ " self.dims = (1, 28, 28)\n",
+ " channels, width, height = self.dims\n",
+ " self.transform = transforms.Compose([\n",
+ " transforms.ToTensor(),\n",
+ " transforms.Normalize((0.1307,), (0.3081,))\n",
+ " ])\n",
+ "\n",
+ " # Define PyTorch model\n",
+ " self.model = nn.Sequential(\n",
+ " nn.Flatten(),\n",
+ " nn.Linear(channels * width * height, hidden_size),\n",
+ " nn.ReLU(),\n",
+ " nn.Dropout(0.1),\n",
+ " nn.Linear(hidden_size, hidden_size),\n",
+ " nn.ReLU(),\n",
+ " nn.Dropout(0.1),\n",
+ " nn.Linear(hidden_size, self.num_classes)\n",
+ " )\n",
+ "\n",
+ " def forward(self, x):\n",
+ " x = self.model(x)\n",
+ " return F.log_softmax(x, dim=1)\n",
+ "\n",
+ " def training_step(self, batch, batch_idx):\n",
+ " x, y = batch\n",
+ " logits = self(x)\n",
+ " loss = F.nll_loss(logits, y)\n",
+ " return pl.TrainResult(loss)\n",
+ "\n",
+ " def validation_step(self, batch, batch_idx):\n",
+ " x, y = batch\n",
+ " logits = self(x)\n",
+ " loss = F.nll_loss(logits, y)\n",
+ " preds = torch.argmax(logits, dim=1)\n",
+ " acc = accuracy(preds, y)\n",
+ " result = pl.EvalResult(checkpoint_on=loss)\n",
+ "\n",
+ " # Calling result.log will surface up scalars for you in TensorBoard\n",
+ " result.log('val_loss', loss, prog_bar=True)\n",
+ " result.log('val_acc', acc, prog_bar=True)\n",
+ " return result\n",
+ "\n",
+ " def test_step(self, batch, batch_idx):\n",
+ " # Here we just reuse the validation_step for testing\n",
+ " return self.validation_step(batch, batch_idx)\n",
+ "\n",
+ " def configure_optimizers(self):\n",
+ " optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
+ " return optimizer\n",
+ "\n",
+ " ####################\n",
+ " # DATA RELATED HOOKS\n",
+ " ####################\n",
+ "\n",
+ " def prepare_data(self):\n",
+ " # download\n",
+ " MNIST(self.data_dir, train=True, download=True)\n",
+ " MNIST(self.data_dir, train=False, download=True)\n",
+ "\n",
+ " def setup(self, stage=None):\n",
+ "\n",
+ " # Assign train/val datasets for use in dataloaders\n",
+ " if stage == 'fit' or stage is None:\n",
+ " mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n",
+ " self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n",
+ "\n",
+ " # Assign test dataset for use in dataloader(s)\n",
+ " if stage == 'test' or stage is None:\n",
+ " self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n",
+ "\n",
+ " def train_dataloader(self):\n",
+ " return DataLoader(self.mnist_train, batch_size=32)\n",
+ "\n",
+ " def val_dataloader(self):\n",
+ " return DataLoader(self.mnist_val, batch_size=32)\n",
+ "\n",
+ " def test_dataloader(self):\n",
+ " return DataLoader(self.mnist_test, batch_size=32)"
+ ],
+ "execution_count": 5,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "Mb0U5Rk2kLBy",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "model = LitMNIST()\n",
+ "trainer = pl.Trainer(gpus=1, max_epochs=3, progress_bar_refresh_rate=20)\n",
+ "trainer.fit(model)"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "nht8AvMptY6I",
+ "colab_type": "text"
+ },
+ "source": [
+ "### Testing\n",
+ "\n",
+ "To test a model, call `trainer.test(model)`.\n",
+ "\n",
+ "Or, if you've just trained a model, you can just call `trainer.test()` and Lightning will automatically test using the best saved checkpoint (conditioned on val_loss)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "PA151FkLtprO",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "trainer.test()"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "T3-3lbbNtr5T",
+ "colab_type": "text"
+ },
+ "source": [
+ "### Bonus Tip\n",
+ "\n",
+ "You can keep calling `trainer.fit(model)` as many times as you'd like to continue training"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "IFBwCbLet2r6",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "trainer.fit(model)"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "8TRyS5CCt3n9",
+ "colab_type": "text"
+ },
+ "source": [
+ "In Colab, you can use the TensorBoard magic function to view the logs that Lightning has created for you!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "wizS-QiLuAYo",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "# Start tensorboard.\n",
+ "%load_ext tensorboard\n",
+ "%tensorboard --logdir lightning_logs/"
+ ],
+ "execution_count": null,
+ "outputs": []
+ }
+ ]
+}
diff --git a/notebooks/02-datamodules.ipynb b/notebooks/02-datamodules.ipynb
new file mode 100644
index 0000000000..53468d2c72
--- /dev/null
+++ b/notebooks/02-datamodules.ipynb
@@ -0,0 +1,542 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "name": "02-datamodules.ipynb",
+ "provenance": [],
+ "collapsed_sections": [],
+ "toc_visible": true,
+ "include_colab_link": true
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "accelerator": "GPU"
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "2O5r7QvP8-rt",
+ "colab_type": "text"
+ },
+ "source": [
+ "# PyTorch Lightning DataModules ⚡\n",
+ "\n",
+ "With the release of `pytorch-lightning` version 0.9.0, we have included a new class called `LightningDataModule` to help you decouple data related hooks from your `LightningModule`.\n",
+ "\n",
+ "This notebook will walk you through how to start using Datamodules.\n",
+ "\n",
+ "The most up to date documentation on datamodules can be found [here](https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html).\n",
+ "\n",
+ "---\n",
+ "\n",
+ " - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n",
+ " - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n",
+ " - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "6RYMhmfA9ATN",
+ "colab_type": "text"
+ },
+ "source": [
+ "### Setup\n",
+ "Lightning is easy to install. Simply ```pip install pytorch-lightning```"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "lj2zD-wsbvGr",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "! pip install pytorch-lightning --quiet"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "8g2mbvy-9xDI",
+ "colab_type": "text"
+ },
+ "source": [
+ "# Introduction\n",
+ "\n",
+ "First, we'll go over a regular `LightningModule` implementation without the use of a `LightningDataModule`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "eg-xDlmDdAwy",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "import pytorch_lightning as pl\n",
+ "from pytorch_lightning.metrics.functional import accuracy\n",
+ "import torch\n",
+ "from torch import nn\n",
+ "import torch.nn.functional as F\n",
+ "from torch.utils.data import random_split, DataLoader\n",
+ "\n",
+ "# Note - you must have torchvision installed for this example\n",
+ "from torchvision.datasets import MNIST, CIFAR10\n",
+ "from torchvision import transforms"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "DzgY7wi88UuG",
+ "colab_type": "text"
+ },
+ "source": [
+ "## Defining the LitMNISTModel\n",
+ "\n",
+ "Below, we reuse a `LightningModule` from our hello world tutorial that classifies MNIST Handwritten Digits.\n",
+ "\n",
+ "Unfortunately, we have hardcoded dataset-specific items within the model, forever limiting it to working with MNIST Data. 😢\n",
+ "\n",
+ "This is fine if you don't plan on training/evaluating your model on different datasets. However, in many cases, this can become bothersome when you want to try out your architecture with different datasets."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "IQkW8_FF5nU2",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "class LitMNIST(pl.LightningModule):\n",
+ " \n",
+ " def __init__(self, data_dir='./', hidden_size=64, learning_rate=2e-4):\n",
+ "\n",
+ " super().__init__()\n",
+ "\n",
+ " # We hardcode dataset specific stuff here.\n",
+ " self.data_dir = data_dir\n",
+ " self.num_classes = 10\n",
+ " self.dims = (1, 28, 28)\n",
+ " channels, width, height = self.dims\n",
+ " self.transform = transforms.Compose([\n",
+ " transforms.ToTensor(),\n",
+ " transforms.Normalize((0.1307,), (0.3081,))\n",
+ " ])\n",
+ "\n",
+ " self.hidden_size = hidden_size\n",
+ " self.learning_rate = learning_rate\n",
+ "\n",
+ " # Build model\n",
+ " self.model = nn.Sequential(\n",
+ " nn.Flatten(),\n",
+ " nn.Linear(channels * width * height, hidden_size),\n",
+ " nn.ReLU(),\n",
+ " nn.Dropout(0.1),\n",
+ " nn.Linear(hidden_size, hidden_size),\n",
+ " nn.ReLU(),\n",
+ " nn.Dropout(0.1),\n",
+ " nn.Linear(hidden_size, self.num_classes)\n",
+ " )\n",
+ "\n",
+ " def forward(self, x):\n",
+ " x = self.model(x)\n",
+ " return F.log_softmax(x, dim=1)\n",
+ "\n",
+ " def training_step(self, batch, batch_idx):\n",
+ " x, y = batch\n",
+ " logits = self(x)\n",
+ " loss = F.nll_loss(logits, y)\n",
+ " return pl.TrainResult(loss)\n",
+ "\n",
+ " def validation_step(self, batch, batch_idx):\n",
+ " x, y = batch\n",
+ " logits = self(x)\n",
+ " loss = F.nll_loss(logits, y)\n",
+ " preds = torch.argmax(logits, dim=1)\n",
+ " acc = accuracy(preds, y)\n",
+ " result = pl.EvalResult(checkpoint_on=loss)\n",
+ " result.log('val_loss', loss, prog_bar=True)\n",
+ " result.log('val_acc', acc, prog_bar=True)\n",
+ " return result\n",
+ "\n",
+ " def configure_optimizers(self):\n",
+ " optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
+ " return optimizer\n",
+ "\n",
+ " ####################\n",
+ " # DATA RELATED HOOKS\n",
+ " ####################\n",
+ "\n",
+ " def prepare_data(self):\n",
+ " # download\n",
+ " MNIST(self.data_dir, train=True, download=True)\n",
+ " MNIST(self.data_dir, train=False, download=True)\n",
+ "\n",
+ " def setup(self, stage=None):\n",
+ "\n",
+ " # Assign train/val datasets for use in dataloaders\n",
+ " if stage == 'fit' or stage is None:\n",
+ " mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n",
+ " self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n",
+ "\n",
+ " # Assign test dataset for use in dataloader(s)\n",
+ " if stage == 'test' or stage is None:\n",
+ " self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n",
+ "\n",
+ " def train_dataloader(self):\n",
+ " return DataLoader(self.mnist_train, batch_size=32)\n",
+ "\n",
+ " def val_dataloader(self):\n",
+ " return DataLoader(self.mnist_val, batch_size=32)\n",
+ "\n",
+ " def test_dataloader(self):\n",
+ " return DataLoader(self.mnist_test, batch_size=32)"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "K7sg9KQd-QIO",
+ "colab_type": "text"
+ },
+ "source": [
+ "## Training the ListMNIST Model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "QxDNDaus6byD",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "model = LitMNIST()\n",
+ "trainer = pl.Trainer(max_epochs=2, gpus=1, progress_bar_refresh_rate=20)\n",
+ "trainer.fit(model)"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "dY8d6GxmB0YU",
+ "colab_type": "text"
+ },
+ "source": [
+ "# Using DataModules\n",
+ "\n",
+ "DataModules are a way of decoupling data-related hooks from the `LightningModule` so you can develop dataset agnostic models."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "eJeT5bW081wn",
+ "colab_type": "text"
+ },
+ "source": [
+ "## Defining The MNISTDataModule\n",
+ "\n",
+ "Let's go over each function in the class below and talk about what they're doing:\n",
+ "\n",
+ "1. ```__init__```\n",
+ " - Takes in a `data_dir` arg that points to where you have downloaded/wish to download the MNIST dataset.\n",
+ " - Defines a transform that will be applied across train, val, and test dataset splits.\n",
+ " - Defines default `self.dims`, which is a tuple returned from `datamodule.size()` that can help you initialize models.\n",
+ "\n",
+ "\n",
+ "2. ```prepare_data```\n",
+ " - This is where we can download the dataset. We point to our desired dataset and ask torchvision's `MNIST` dataset class to download if the dataset isn't found there.\n",
+ " - **Note we do not make any state assignments in this function** (i.e. `self.something = ...`)\n",
+ "\n",
+ "3. ```setup```\n",
+ " - Loads in data from file and prepares PyTorch tensor datasets for each split (train, val, test). \n",
+ " - Setup expects a 'stage' arg which is used to separate logic for 'fit' and 'test'.\n",
+ " - If you don't mind loading all your datasets at once, you can set up a condition to allow for both 'fit' related setup and 'test' related setup to run whenever `None` is passed to `stage`.\n",
+ " - **Note this runs across all GPUs and it *is* safe to make state assignments here**\n",
+ "\n",
+ "\n",
+ "4. ```x_dataloader```\n",
+ " - `train_dataloader()`, `val_dataloader()`, and `test_dataloader()` all return PyTorch `DataLoader` instances that are created by wrapping their respective datasets that we prepared in `setup()`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "DfGKyGwG_X9v",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "class MNISTDataModule(pl.LightningDataModule):\n",
+ "\n",
+ " def __init__(self, data_dir: str = './'):\n",
+ " super().__init__()\n",
+ " self.data_dir = data_dir\n",
+ " self.transform = transforms.Compose([\n",
+ " transforms.ToTensor(),\n",
+ " transforms.Normalize((0.1307,), (0.3081,))\n",
+ " ])\n",
+ "\n",
+ " # self.dims is returned when you call dm.size()\n",
+ " # Setting default dims here because we know them.\n",
+ " # Could optionally be assigned dynamically in dm.setup()\n",
+ " self.dims = (1, 28, 28)\n",
+ " self.num_classes = 10\n",
+ "\n",
+ " def prepare_data(self):\n",
+ " # download\n",
+ " MNIST(self.data_dir, train=True, download=True)\n",
+ " MNIST(self.data_dir, train=False, download=True)\n",
+ "\n",
+ " def setup(self, stage=None):\n",
+ "\n",
+ " # Assign train/val datasets for use in dataloaders\n",
+ " if stage == 'fit' or stage is None:\n",
+ " mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n",
+ " self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n",
+ "\n",
+ " # Assign test dataset for use in dataloader(s)\n",
+ " if stage == 'test' or stage is None:\n",
+ " self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n",
+ "\n",
+ " def train_dataloader(self):\n",
+ " return DataLoader(self.mnist_train, batch_size=32)\n",
+ "\n",
+ " def val_dataloader(self):\n",
+ " return DataLoader(self.mnist_val, batch_size=32)\n",
+ "\n",
+ " def test_dataloader(self):\n",
+ " return DataLoader(self.mnist_test, batch_size=32)"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "H2Yoj-9M9dS7",
+ "colab_type": "text"
+ },
+ "source": [
+ "## Defining the dataset agnostic `LitModel`\n",
+ "\n",
+ "Below, we define the same model as the `LitMNIST` model we made earlier. \n",
+ "\n",
+ "However, this time our model has the freedom to use any input data that we'd like 🔥."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "PM2IISuOBDIu",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "class LitModel(pl.LightningModule):\n",
+ " \n",
+ " def __init__(self, channels, width, height, num_classes, hidden_size=64, learning_rate=2e-4):\n",
+ "\n",
+ " super().__init__()\n",
+ "\n",
+ " # We take in input dimensions as parameters and use those to dynamically build model.\n",
+ " self.channels = channels\n",
+ " self.width = width\n",
+ " self.height = height\n",
+ " self.num_classes = num_classes\n",
+ " self.hidden_size = hidden_size\n",
+ " self.learning_rate = learning_rate\n",
+ "\n",
+ " self.model = nn.Sequential(\n",
+ " nn.Flatten(),\n",
+ " nn.Linear(channels * width * height, hidden_size),\n",
+ " nn.ReLU(),\n",
+ " nn.Dropout(0.1),\n",
+ " nn.Linear(hidden_size, hidden_size),\n",
+ " nn.ReLU(),\n",
+ " nn.Dropout(0.1),\n",
+ " nn.Linear(hidden_size, num_classes)\n",
+ " )\n",
+ "\n",
+ " def forward(self, x):\n",
+ " x = self.model(x)\n",
+ " return F.log_softmax(x, dim=1)\n",
+ "\n",
+ " def training_step(self, batch, batch_idx):\n",
+ " x, y = batch\n",
+ " logits = self(x)\n",
+ " loss = F.nll_loss(logits, y)\n",
+ " return pl.TrainResult(loss)\n",
+ "\n",
+ " def validation_step(self, batch, batch_idx):\n",
+ "\n",
+ " x, y = batch\n",
+ " logits = self(x)\n",
+ " loss = F.nll_loss(logits, y)\n",
+ " preds = torch.argmax(logits, dim=1)\n",
+ " acc = accuracy(preds, y)\n",
+ " result = pl.EvalResult(checkpoint_on=loss)\n",
+ " result.log('val_loss', loss, prog_bar=True)\n",
+ " result.log('val_acc', acc, prog_bar=True)\n",
+ " return result\n",
+ "\n",
+ " def configure_optimizers(self):\n",
+ " optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
+ " return optimizer"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "G4Z5olPe-xEo",
+ "colab_type": "text"
+ },
+ "source": [
+ "## Training the `LitModel` using the `MNISTDataModule`\n",
+ "\n",
+ "Now, we initialize and train the `LitModel` using the `MNISTDataModule`'s configuration settings and dataloaders."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "kV48vP_9mEli",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "# Init DataModule\n",
+ "dm = MNISTDataModule()\n",
+ "# Init model from datamodule's attributes\n",
+ "model = LitModel(*dm.size(), dm.num_classes)\n",
+ "# Init trainer\n",
+ "trainer = pl.Trainer(max_epochs=3, progress_bar_refresh_rate=20, gpus=1)\n",
+ "# Pass the datamodule as arg to trainer.fit to override model hooks :)\n",
+ "trainer.fit(model, dm)"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "WNxrugIGRRv5",
+ "colab_type": "text"
+ },
+ "source": [
+ "## Defining the CIFAR10 DataModule\n",
+ "\n",
+ "Lets prove the `LitModel` we made earlier is dataset agnostic by defining a new datamodule for the CIFAR10 dataset."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "1tkaYLU7RT5P",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "class CIFAR10DataModule(pl.LightningDataModule):\n",
+ "\n",
+ " def __init__(self, data_dir: str = './'):\n",
+ " super().__init__()\n",
+ " self.data_dir = data_dir\n",
+ " self.transform = transforms.Compose([\n",
+ " transforms.ToTensor(),\n",
+ " transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n",
+ " ])\n",
+ "\n",
+ " self.dims = (3, 32, 32)\n",
+ " self.num_classes = 10\n",
+ "\n",
+ " def prepare_data(self):\n",
+ " # download\n",
+ " CIFAR10(self.data_dir, train=True, download=True)\n",
+ " CIFAR10(self.data_dir, train=False, download=True)\n",
+ "\n",
+ " def setup(self, stage=None):\n",
+ "\n",
+ " # Assign train/val datasets for use in dataloaders\n",
+ " if stage == 'fit' or stage is None:\n",
+ " cifar_full = CIFAR10(self.data_dir, train=True, transform=self.transform)\n",
+ " self.cifar_train, self.cifar_val = random_split(cifar_full, [45000, 5000])\n",
+ "\n",
+ " # Assign test dataset for use in dataloader(s)\n",
+ " if stage == 'test' or stage is None:\n",
+ " self.cifar_test = CIFAR10(self.data_dir, train=False, transform=self.transform)\n",
+ "\n",
+ " def train_dataloader(self):\n",
+ " return DataLoader(self.cifar_train, batch_size=32)\n",
+ "\n",
+ " def val_dataloader(self):\n",
+ " return DataLoader(self.cifar_val, batch_size=32)\n",
+ "\n",
+ " def test_dataloader(self):\n",
+ " return DataLoader(self.cifar_test, batch_size=32)"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "BrXxf3oX_gsZ",
+ "colab_type": "text"
+ },
+ "source": [
+ "## Training the `LitModel` using the `CIFAR10DataModule`\n",
+ "\n",
+ "Our model isn't very good, so it will perform pretty badly on the CIFAR10 dataset.\n",
+ "\n",
+ "The point here is that we can see that our `LitModel` has no problem using a different datamodule as its input data."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "sd-SbWi_krdj",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "dm = CIFAR10DataModule()\n",
+ "model = LitModel(*dm.size(), dm.num_classes, hidden_size=256)\n",
+ "trainer = pl.Trainer(max_epochs=5, progress_bar_refresh_rate=20, gpus=1)\n",
+ "trainer.fit(model, dm)"
+ ],
+ "execution_count": null,
+ "outputs": []
+ }
+ ]
+}
diff --git a/notebooks/03-basic-gan.ipynb b/notebooks/03-basic-gan.ipynb
new file mode 100644
index 0000000000..d88a524285
--- /dev/null
+++ b/notebooks/03-basic-gan.ipynb
@@ -0,0 +1,424 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "name": "03-basic-gan.ipynb",
+ "provenance": [],
+ "collapsed_sections": [],
+ "include_colab_link": true
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "accelerator": "GPU"
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "J37PBnE_x7IW",
+ "colab_type": "text"
+ },
+ "source": [
+ "# PyTorch Lightning Basic GAN Tutorial ⚡\n",
+ "\n",
+ "How to train a GAN!\n",
+ "\n",
+ "Main takeaways:\n",
+ "1. Generator and discriminator are arbitraty PyTorch modules.\n",
+ "2. training_step does both the generator and discriminator training.\n",
+ "\n",
+ "---\n",
+ "\n",
+ " - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n",
+ " - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n",
+ " - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "kg2MKpRmybht",
+ "colab_type": "text"
+ },
+ "source": [
+ "### Setup\n",
+ "Lightning is easy to install. Simply `pip install pytorch-lightning`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "LfrJLKPFyhsK",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "! pip install pytorch-lightning --quiet"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "BjEPuiVLyanw",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "import os\n",
+ "from argparse import ArgumentParser\n",
+ "from collections import OrderedDict\n",
+ "\n",
+ "import numpy as np\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import torch.nn.functional as F\n",
+ "import torchvision\n",
+ "import torchvision.transforms as transforms\n",
+ "from torch.utils.data import DataLoader, random_split\n",
+ "from torchvision.datasets import MNIST\n",
+ "\n",
+ "import pytorch_lightning as pl"
+ ],
+ "execution_count": 2,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "OuXJzr4G2uHV",
+ "colab_type": "text"
+ },
+ "source": [
+ "### MNIST DataModule\n",
+ "\n",
+ "Below, we define a DataModule for the MNIST Dataset. To learn more about DataModules, check out our tutorial on them or see the [latest docs](https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "DOY_nHu328g7",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "class MNISTDataModule(pl.LightningDataModule):\n",
+ "\n",
+ " def __init__(self, data_dir: str = './', batch_size: int = 64, num_workers: int = 8):\n",
+ " super().__init__()\n",
+ " self.data_dir = data_dir\n",
+ " self.batch_size = batch_size\n",
+ " self.num_workers = num_workers\n",
+ "\n",
+ " self.transform = transforms.Compose([\n",
+ " transforms.ToTensor(),\n",
+ " transforms.Normalize((0.1307,), (0.3081,))\n",
+ " ])\n",
+ "\n",
+ " # self.dims is returned when you call dm.size()\n",
+ " # Setting default dims here because we know them.\n",
+ " # Could optionally be assigned dynamically in dm.setup()\n",
+ " self.dims = (1, 28, 28)\n",
+ " self.num_classes = 10\n",
+ "\n",
+ " def prepare_data(self):\n",
+ " # download\n",
+ " MNIST(self.data_dir, train=True, download=True)\n",
+ " MNIST(self.data_dir, train=False, download=True)\n",
+ "\n",
+ " def setup(self, stage=None):\n",
+ "\n",
+ " # Assign train/val datasets for use in dataloaders\n",
+ " if stage == 'fit' or stage is None:\n",
+ " mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n",
+ " self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n",
+ "\n",
+ " # Assign test dataset for use in dataloader(s)\n",
+ " if stage == 'test' or stage is None:\n",
+ " self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n",
+ "\n",
+ " def train_dataloader(self):\n",
+ " return DataLoader(self.mnist_train, batch_size=self.batch_size, num_workers=self.num_workers)\n",
+ "\n",
+ " def val_dataloader(self):\n",
+ " return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=self.num_workers)\n",
+ "\n",
+ " def test_dataloader(self):\n",
+ " return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=self.num_workers)"
+ ],
+ "execution_count": 3,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "tW3c0QrQyF9P",
+ "colab_type": "text"
+ },
+ "source": [
+ "### A. Generator"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "0E2QDjl5yWtz",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "class Generator(nn.Module):\n",
+ " def __init__(self, latent_dim, img_shape):\n",
+ " super().__init__()\n",
+ " self.img_shape = img_shape\n",
+ "\n",
+ " def block(in_feat, out_feat, normalize=True):\n",
+ " layers = [nn.Linear(in_feat, out_feat)]\n",
+ " if normalize:\n",
+ " layers.append(nn.BatchNorm1d(out_feat, 0.8))\n",
+ " layers.append(nn.LeakyReLU(0.2, inplace=True))\n",
+ " return layers\n",
+ "\n",
+ " self.model = nn.Sequential(\n",
+ " *block(latent_dim, 128, normalize=False),\n",
+ " *block(128, 256),\n",
+ " *block(256, 512),\n",
+ " *block(512, 1024),\n",
+ " nn.Linear(1024, int(np.prod(img_shape))),\n",
+ " nn.Tanh()\n",
+ " )\n",
+ "\n",
+ " def forward(self, z):\n",
+ " img = self.model(z)\n",
+ " img = img.view(img.size(0), *self.img_shape)\n",
+ " return img"
+ ],
+ "execution_count": 4,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "uyrltsGvyaI3",
+ "colab_type": "text"
+ },
+ "source": [
+ "### B. Discriminator"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "Ed3MR3vnyxyW",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "class Discriminator(nn.Module):\n",
+ " def __init__(self, img_shape):\n",
+ " super().__init__()\n",
+ "\n",
+ " self.model = nn.Sequential(\n",
+ " nn.Linear(int(np.prod(img_shape)), 512),\n",
+ " nn.LeakyReLU(0.2, inplace=True),\n",
+ " nn.Linear(512, 256),\n",
+ " nn.LeakyReLU(0.2, inplace=True),\n",
+ " nn.Linear(256, 1),\n",
+ " nn.Sigmoid(),\n",
+ " )\n",
+ "\n",
+ " def forward(self, img):\n",
+ " img_flat = img.view(img.size(0), -1)\n",
+ " validity = self.model(img_flat)\n",
+ "\n",
+ " return validity"
+ ],
+ "execution_count": 5,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "BwUMom3ryySK",
+ "colab_type": "text"
+ },
+ "source": [
+ "### C. GAN\n",
+ "\n",
+ "#### A couple of cool features to check out in this example...\n",
+ "\n",
+ " - We use `some_tensor.type_as(another_tensor)` to make sure we initialize new tensors on the right device (i.e. GPU, CPU).\n",
+ " - Lightning will put your dataloader data on the right device automatically\n",
+ " - In this example, we pull from latent dim on the fly, so we need to dynamically add tensors to the right device.\n",
+ " - `type_as` is the way we recommend to do this.\n",
+ " - This example shows how to use multiple dataloaders in your `LightningModule`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "3vKszYf6y1Vv",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ " class GAN(pl.LightningModule):\n",
+ "\n",
+ " def __init__(\n",
+ " self,\n",
+ " channels,\n",
+ " width,\n",
+ " height,\n",
+ " latent_dim: int = 100,\n",
+ " lr: float = 0.0002,\n",
+ " b1: float = 0.5,\n",
+ " b2: float = 0.999,\n",
+ " batch_size: int = 64,\n",
+ " **kwargs\n",
+ " ):\n",
+ " super().__init__()\n",
+ " self.save_hyperparameters()\n",
+ "\n",
+ " # networks\n",
+ " data_shape = (channels, width, height)\n",
+ " self.generator = Generator(latent_dim=self.hparams.latent_dim, img_shape=data_shape)\n",
+ " self.discriminator = Discriminator(img_shape=data_shape)\n",
+ "\n",
+ " self.validation_z = torch.randn(8, self.hparams.latent_dim)\n",
+ "\n",
+ " self.example_input_array = torch.zeros(2, self.hparams.latent_dim)\n",
+ "\n",
+ " def forward(self, z):\n",
+ " return self.generator(z)\n",
+ "\n",
+ " def adversarial_loss(self, y_hat, y):\n",
+ " return F.binary_cross_entropy(y_hat, y)\n",
+ "\n",
+ " def training_step(self, batch, batch_idx, optimizer_idx):\n",
+ " imgs, _ = batch\n",
+ "\n",
+ " # sample noise\n",
+ " z = torch.randn(imgs.shape[0], self.hparams.latent_dim)\n",
+ " z = z.type_as(imgs)\n",
+ "\n",
+ " # train generator\n",
+ " if optimizer_idx == 0:\n",
+ "\n",
+ " # generate images\n",
+ " self.generated_imgs = self(z)\n",
+ "\n",
+ " # log sampled images\n",
+ " sample_imgs = self.generated_imgs[:6]\n",
+ " grid = torchvision.utils.make_grid(sample_imgs)\n",
+ " self.logger.experiment.add_image('generated_images', grid, 0)\n",
+ "\n",
+ " # ground truth result (ie: all fake)\n",
+ " # put on GPU because we created this tensor inside training_loop\n",
+ " valid = torch.ones(imgs.size(0), 1)\n",
+ " valid = valid.type_as(imgs)\n",
+ "\n",
+ " # adversarial loss is binary cross-entropy\n",
+ " g_loss = self.adversarial_loss(self.discriminator(self(z)), valid)\n",
+ " tqdm_dict = {'g_loss': g_loss}\n",
+ " output = OrderedDict({\n",
+ " 'loss': g_loss,\n",
+ " 'progress_bar': tqdm_dict,\n",
+ " 'log': tqdm_dict\n",
+ " })\n",
+ " return output\n",
+ "\n",
+ " # train discriminator\n",
+ " if optimizer_idx == 1:\n",
+ " # Measure discriminator's ability to classify real from generated samples\n",
+ "\n",
+ " # how well can it label as real?\n",
+ " valid = torch.ones(imgs.size(0), 1)\n",
+ " valid = valid.type_as(imgs)\n",
+ "\n",
+ " real_loss = self.adversarial_loss(self.discriminator(imgs), valid)\n",
+ "\n",
+ " # how well can it label as fake?\n",
+ " fake = torch.zeros(imgs.size(0), 1)\n",
+ " fake = fake.type_as(imgs)\n",
+ "\n",
+ " fake_loss = self.adversarial_loss(\n",
+ " self.discriminator(self(z).detach()), fake)\n",
+ "\n",
+ " # discriminator loss is the average of these\n",
+ " d_loss = (real_loss + fake_loss) / 2\n",
+ " tqdm_dict = {'d_loss': d_loss}\n",
+ " output = OrderedDict({\n",
+ " 'loss': d_loss,\n",
+ " 'progress_bar': tqdm_dict,\n",
+ " 'log': tqdm_dict\n",
+ " })\n",
+ " return output\n",
+ "\n",
+ " def configure_optimizers(self):\n",
+ " lr = self.hparams.lr\n",
+ " b1 = self.hparams.b1\n",
+ " b2 = self.hparams.b2\n",
+ "\n",
+ " opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))\n",
+ " opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))\n",
+ " return [opt_g, opt_d], []\n",
+ "\n",
+ " def on_epoch_end(self):\n",
+ " z = self.validation_z.type_as(self.generator.model[0].weight)\n",
+ "\n",
+ " # log sampled images\n",
+ " sample_imgs = self(z)\n",
+ " grid = torchvision.utils.make_grid(sample_imgs)\n",
+ " self.logger.experiment.add_image('generated_images', grid, self.current_epoch)"
+ ],
+ "execution_count": 6,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "Ey5FmJPnzm_E",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "dm = MNISTDataModule()\n",
+ "model = GAN(*dm.size())\n",
+ "trainer = pl.Trainer(gpus=1, max_epochs=5, progress_bar_refresh_rate=20)\n",
+ "trainer.fit(model, dm)"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "MlECc7cHzolp",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "# Start tensorboard.\n",
+ "%load_ext tensorboard\n",
+ "%tensorboard --logdir lightning_logs/"
+ ],
+ "execution_count": null,
+ "outputs": []
+ }
+ ]
+}
diff --git a/notebooks/04-transformers-text-classification.ipynb b/notebooks/04-transformers-text-classification.ipynb
new file mode 100644
index 0000000000..d2649c1a8d
--- /dev/null
+++ b/notebooks/04-transformers-text-classification.ipynb
@@ -0,0 +1,556 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "name": "04-transformers-text-classification.ipynb",
+ "provenance": [],
+ "collapsed_sections": [],
+ "toc_visible": true,
+ "include_colab_link": true
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "accelerator": "GPU"
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "8ag5ANQPJ_j9",
+ "colab_type": "text"
+ },
+ "source": [
+ "# Finetune 🤗 Transformers Models with PyTorch Lightning ⚡\n",
+ "\n",
+ "This notebook will use HuggingFace's `datasets` library to get data, which will be wrapped in a `LightningDataModule`. Then, we write a class to perform text classification on any dataset from the[ GLUE Benchmark](https://gluebenchmark.com/). (We just show CoLA and MRPC due to constraint on compute/disk)\n",
+ "\n",
+ "[HuggingFace's NLP Viewer](https://huggingface.co/nlp/viewer/?dataset=glue&config=cola) can help you get a feel for the two datasets we will use and what tasks they are solving for.\n",
+ "\n",
+ "---\n",
+ " - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n",
+ " - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n",
+ " - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)\n",
+ "\n",
+ " - [HuggingFace nlp](https://github.com/huggingface/nlp)\n",
+ " - [HuggingFace transformers](https://github.com/huggingface/transformers)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "fqlsVTj7McZ3",
+ "colab_type": "text"
+ },
+ "source": [
+ "### Setup"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "OIhHrRL-MnKK",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "!pip install pytorch-lightning datasets transformers"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "6yuQT_ZQMpCg",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "from argparse import ArgumentParser\n",
+ "from datetime import datetime\n",
+ "from typing import Optional\n",
+ "\n",
+ "import nlp\n",
+ "import numpy as np\n",
+ "import pytorch_lightning as pl\n",
+ "import torch\n",
+ "from torch.utils.data import DataLoader\n",
+ "from transformers import (\n",
+ " AdamW,\n",
+ " AutoModelForSequenceClassification,\n",
+ " AutoConfig,\n",
+ " AutoTokenizer,\n",
+ " get_linear_schedule_with_warmup,\n",
+ " glue_compute_metrics\n",
+ ")"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "9ORJfiuiNZ_N",
+ "colab_type": "text"
+ },
+ "source": [
+ "## GLUE DataModule"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "jW9xQhZxMz1G",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "class GLUEDataModule(pl.LightningDataModule):\n",
+ "\n",
+ " task_text_field_map = {\n",
+ " 'cola': ['sentence'],\n",
+ " 'sst2': ['sentence'],\n",
+ " 'mrpc': ['sentence1', 'sentence2'],\n",
+ " 'qqp': ['question1', 'question2'],\n",
+ " 'stsb': ['sentence1', 'sentence2'],\n",
+ " 'mnli': ['premise', 'hypothesis'],\n",
+ " 'qnli': ['question', 'sentence'],\n",
+ " 'rte': ['sentence1', 'sentence2'],\n",
+ " 'wnli': ['sentence1', 'sentence2'],\n",
+ " 'ax': ['premise', 'hypothesis']\n",
+ " }\n",
+ "\n",
+ " glue_task_num_labels = {\n",
+ " 'cola': 2,\n",
+ " 'sst2': 2,\n",
+ " 'mrpc': 2,\n",
+ " 'qqp': 2,\n",
+ " 'stsb': 1,\n",
+ " 'mnli': 3,\n",
+ " 'qnli': 2,\n",
+ " 'rte': 2,\n",
+ " 'wnli': 2,\n",
+ " 'ax': 3\n",
+ " }\n",
+ "\n",
+ " loader_columns = [\n",
+ " 'nlp_idx',\n",
+ " 'input_ids',\n",
+ " 'token_type_ids',\n",
+ " 'attention_mask',\n",
+ " 'start_positions',\n",
+ " 'end_positions',\n",
+ " 'labels'\n",
+ " ]\n",
+ "\n",
+ " def __init__(\n",
+ " self,\n",
+ " model_name_or_path: str,\n",
+ " task_name: str ='mrpc',\n",
+ " max_seq_length: int = 128,\n",
+ " train_batch_size: int = 32,\n",
+ " eval_batch_size: int = 32,\n",
+ " **kwargs\n",
+ " ):\n",
+ " super().__init__()\n",
+ " self.model_name_or_path = model_name_or_path\n",
+ " self.task_name = task_name\n",
+ " self.max_seq_length = max_seq_length\n",
+ " self.train_batch_size = train_batch_size\n",
+ " self.eval_batch_size = eval_batch_size\n",
+ "\n",
+ " self.text_fields = self.task_text_field_map[task_name]\n",
+ " self.num_labels = self.glue_task_num_labels[task_name]\n",
+ " self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)\n",
+ "\n",
+ " def setup(self, stage):\n",
+ " self.dataset = nlp.load_dataset('glue', self.task_name)\n",
+ "\n",
+ " for split in self.dataset.keys():\n",
+ " self.dataset[split] = self.dataset[split].map(\n",
+ " self.convert_to_features,\n",
+ " batched=True,\n",
+ " remove_columns=['label'],\n",
+ " )\n",
+ " self.columns = [c for c in self.dataset[split].column_names if c in self.loader_columns]\n",
+ " self.dataset[split].set_format(type=\"torch\", columns=self.columns)\n",
+ "\n",
+ " self.eval_splits = [x for x in self.dataset.keys() if 'validation' in x]\n",
+ "\n",
+ " def prepare_data(self):\n",
+ " nlp.load_dataset('glue', self.task_name)\n",
+ " AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)\n",
+ " \n",
+ " def train_dataloader(self):\n",
+ " return DataLoader(self.dataset['train'], batch_size=self.train_batch_size)\n",
+ " \n",
+ " def val_dataloader(self):\n",
+ " if len(self.eval_splits) == 1:\n",
+ " return DataLoader(self.dataset['validation'], batch_size=self.eval_batch_size)\n",
+ " elif len(self.eval_splits) > 1:\n",
+ " return [DataLoader(self.dataset[x], batch_size=self.eval_batch_size) for x in self.eval_splits]\n",
+ "\n",
+ " def test_dataloader(self):\n",
+ " if len(self.eval_splits) == 1:\n",
+ " return DataLoader(self.dataset['test'], batch_size=self.eval_batch_size)\n",
+ " elif len(self.eval_splits) > 1:\n",
+ " return [DataLoader(self.dataset[x], batch_size=self.eval_batch_size) for x in self.eval_splits]\n",
+ "\n",
+ " def convert_to_features(self, example_batch, indices=None):\n",
+ "\n",
+ " # Either encode single sentence or sentence pairs\n",
+ " if len(self.text_fields) > 1:\n",
+ " texts_or_text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]]))\n",
+ " else:\n",
+ " texts_or_text_pairs = example_batch[self.text_fields[0]]\n",
+ "\n",
+ " # Tokenize the text/text pairs\n",
+ " features = self.tokenizer.batch_encode_plus(\n",
+ " texts_or_text_pairs,\n",
+ " max_length=self.max_seq_length,\n",
+ " pad_to_max_length=True,\n",
+ " truncation=True\n",
+ " )\n",
+ "\n",
+ " # Rename label to labels to make it easier to pass to model forward\n",
+ " features['labels'] = example_batch['label']\n",
+ "\n",
+ " return features"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "jQC3a6KuOpX3",
+ "colab_type": "text"
+ },
+ "source": [
+ "#### You could use this datamodule with standalone PyTorch if you wanted..."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "JCMH3IAsNffF",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "dm = GLUEDataModule('distilbert-base-uncased')\n",
+ "dm.prepare_data()\n",
+ "dm.setup('fit')\n",
+ "next(iter(dm.train_dataloader()))"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "l9fQ_67BO2Lj",
+ "colab_type": "text"
+ },
+ "source": [
+ "## GLUE Model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "gtn5YGKYO65B",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "class GLUETransformer(pl.LightningModule):\n",
+ " def __init__(\n",
+ " self,\n",
+ " model_name_or_path: str,\n",
+ " num_labels: int,\n",
+ " learning_rate: float = 2e-5,\n",
+ " adam_epsilon: float = 1e-8,\n",
+ " warmup_steps: int = 0,\n",
+ " weight_decay: float = 0.0,\n",
+ " train_batch_size: int = 32,\n",
+ " eval_batch_size: int = 32,\n",
+ " eval_splits: Optional[list] = None,\n",
+ " **kwargs\n",
+ " ):\n",
+ " super().__init__()\n",
+ "\n",
+ " self.save_hyperparameters()\n",
+ "\n",
+ " self.config = AutoConfig.from_pretrained(model_name_or_path, num_labels=num_labels)\n",
+ " self.model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, config=self.config)\n",
+ " self.metric = nlp.load_metric(\n",
+ " 'glue',\n",
+ " self.hparams.task_name,\n",
+ " experiment_id=datetime.now().strftime(\"%d-%m-%Y_%H-%M-%S\")\n",
+ " )\n",
+ "\n",
+ " def forward(self, **inputs):\n",
+ " return self.model(**inputs)\n",
+ "\n",
+ " def training_step(self, batch, batch_idx):\n",
+ " outputs = self(**batch)\n",
+ " loss = outputs[0]\n",
+ " return pl.TrainResult(loss)\n",
+ "\n",
+ " def validation_step(self, batch, batch_idx, dataloader_idx=0):\n",
+ " outputs = self(**batch)\n",
+ " val_loss, logits = outputs[:2]\n",
+ "\n",
+ " if self.hparams.num_labels >= 1:\n",
+ " preds = torch.argmax(logits, axis=1)\n",
+ " elif self.hparams.num_labels == 1:\n",
+ " preds = logits.squeeze()\n",
+ "\n",
+ " labels = batch[\"labels\"]\n",
+ "\n",
+ " return {'loss': val_loss, \"preds\": preds, \"labels\": labels}\n",
+ "\n",
+ " def validation_epoch_end(self, outputs):\n",
+ " if self.hparams.task_name == 'mnli':\n",
+ " for i, output in enumerate(outputs):\n",
+ " # matched or mismatched\n",
+ " split = self.hparams.eval_splits[i].split('_')[-1]\n",
+ " preds = torch.cat([x['preds'] for x in output]).detach().cpu().numpy()\n",
+ " labels = torch.cat([x['labels'] for x in output]).detach().cpu().numpy()\n",
+ " loss = torch.stack([x['loss'] for x in output]).mean()\n",
+ " if i == 0:\n",
+ " result = pl.EvalResult(checkpoint_on=loss)\n",
+ " result.log(f'val_loss_{split}', loss, prog_bar=True)\n",
+ " split_metrics = {f\"{k}_{split}\": v for k, v in self.metric.compute(preds, labels).items()}\n",
+ " result.log_dict(split_metrics, prog_bar=True)\n",
+ " return result\n",
+ "\n",
+ " preds = torch.cat([x['preds'] for x in outputs]).detach().cpu().numpy()\n",
+ " labels = torch.cat([x['labels'] for x in outputs]).detach().cpu().numpy()\n",
+ " loss = torch.stack([x['loss'] for x in outputs]).mean()\n",
+ " result = pl.EvalResult(checkpoint_on=loss)\n",
+ " result.log('val_loss', loss, prog_bar=True)\n",
+ " result.log_dict(self.metric.compute(preds, labels), prog_bar=True)\n",
+ " return result\n",
+ "\n",
+ " def setup(self, stage):\n",
+ " if stage == 'fit':\n",
+ " # Get dataloader by calling it - train_dataloader() is called after setup() by default\n",
+ " train_loader = self.train_dataloader()\n",
+ "\n",
+ " # Calculate total steps\n",
+ " self.total_steps = (\n",
+ " (len(train_loader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.gpus)))\n",
+ " // self.hparams.accumulate_grad_batches\n",
+ " * float(self.hparams.max_epochs)\n",
+ " )\n",
+ "\n",
+ " def configure_optimizers(self):\n",
+ " \"Prepare optimizer and schedule (linear warmup and decay)\"\n",
+ " model = self.model\n",
+ " no_decay = [\"bias\", \"LayerNorm.weight\"]\n",
+ " optimizer_grouped_parameters = [\n",
+ " {\n",
+ " \"params\": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],\n",
+ " \"weight_decay\": self.hparams.weight_decay,\n",
+ " },\n",
+ " {\n",
+ " \"params\": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],\n",
+ " \"weight_decay\": 0.0,\n",
+ " },\n",
+ " ]\n",
+ " optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)\n",
+ "\n",
+ " scheduler = get_linear_schedule_with_warmup(\n",
+ " optimizer, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps\n",
+ " )\n",
+ " scheduler = {\n",
+ " 'scheduler': scheduler,\n",
+ " 'interval': 'step',\n",
+ " 'frequency': 1\n",
+ " }\n",
+ " return [optimizer], [scheduler]\n",
+ "\n",
+ " @staticmethod\n",
+ " def add_model_specific_args(parent_parser):\n",
+ " parser = ArgumentParser(parents=[parent_parser], add_help=False)\n",
+ " parser.add_argument(\"--learning_rate\", default=2e-5, type=float)\n",
+ " parser.add_argument(\"--adam_epsilon\", default=1e-8, type=float)\n",
+ " parser.add_argument(\"--warmup_steps\", default=0, type=int)\n",
+ " parser.add_argument(\"--weight_decay\", default=0.0, type=float)\n",
+ " return parser"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ha-NdIP_xbd3",
+ "colab_type": "text"
+ },
+ "source": [
+ "### ⚡ Quick Tip \n",
+ " - Combine arguments from your DataModule, Model, and Trainer into one for easy and robust configuration"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "3dEHnl3RPlAR",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "def parse_args(args=None):\n",
+ " parser = ArgumentParser()\n",
+ " parser = pl.Trainer.add_argparse_args(parser)\n",
+ " parser = GLUEDataModule.add_argparse_args(parser)\n",
+ " parser = GLUETransformer.add_model_specific_args(parser)\n",
+ " parser.add_argument('--seed', type=int, default=42)\n",
+ " return parser.parse_args(args)\n",
+ "\n",
+ "\n",
+ "def main(args):\n",
+ " pl.seed_everything(args.seed)\n",
+ " dm = GLUEDataModule.from_argparse_args(args)\n",
+ " dm.prepare_data()\n",
+ " dm.setup('fit')\n",
+ " model = GLUETransformer(num_labels=dm.num_labels, eval_splits=dm.eval_splits, **vars(args))\n",
+ " trainer = pl.Trainer.from_argparse_args(args)\n",
+ " return dm, model, trainer"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "PkuLaeec3sJ-",
+ "colab_type": "text"
+ },
+ "source": [
+ "# Training"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "QSpueK5UPsN7",
+ "colab_type": "text"
+ },
+ "source": [
+ "## CoLA\n",
+ "\n",
+ "See an interactive view of the CoLA dataset in [NLP Viewer](https://huggingface.co/nlp/viewer/?dataset=glue&config=cola)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "NJnFmtpnPu0Y",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "mocked_args = \"\"\"\n",
+ " --model_name_or_path albert-base-v2\n",
+ " --task_name cola\n",
+ " --max_epochs 3\n",
+ " --gpus 1\"\"\".split()\n",
+ "\n",
+ "args = parse_args(mocked_args)\n",
+ "dm, model, trainer = main(args)\n",
+ "trainer.fit(model, dm)"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "_MrNsTnqdz4z",
+ "colab_type": "text"
+ },
+ "source": [
+ "## MRPC\n",
+ "\n",
+ "See an interactive view of the MRPC dataset in [NLP Viewer](https://huggingface.co/nlp/viewer/?dataset=glue&config=mrpc)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "LBwRxg9Cb3d-",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "mocked_args = \"\"\"\n",
+ " --model_name_or_path distilbert-base-cased\n",
+ " --task_name mrpc\n",
+ " --max_epochs 3\n",
+ " --gpus 1\"\"\".split()\n",
+ "\n",
+ "args = parse_args(mocked_args)\n",
+ "dm, model, trainer = main(args)\n",
+ "trainer.fit(model, dm)"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "iZhbn0HzfdCu",
+ "colab_type": "text"
+ },
+ "source": [
+ "## MNLI\n",
+ "\n",
+ " - The MNLI dataset is huge, so we aren't going to bother trying to train it here.\n",
+ "\n",
+ " - Let's just make sure our multi-dataloader logic is right by skipping over training and going straight to validation.\n",
+ "\n",
+ "See an interactive view of the MRPC dataset in [NLP Viewer](https://huggingface.co/nlp/viewer/?dataset=glue&config=mnli)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "AvsZMOggfcWW",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "mocked_args = \"\"\"\n",
+ " --model_name_or_path distilbert-base-uncased\n",
+ " --task_name mnli\n",
+ " --max_epochs 1\n",
+ " --gpus 1\n",
+ " --limit_train_batches 10\n",
+ " --progress_bar_refresh_rate 20\"\"\".split()\n",
+ "\n",
+ "args = parse_args(mocked_args)\n",
+ "dm, model, trainer = main(args)\n",
+ "trainer.fit(model, dm)"
+ ],
+ "execution_count": null,
+ "outputs": []
+ }
+ ]
+}
diff --git a/notebooks/README.md b/notebooks/README.md
new file mode 100644
index 0000000000..1f946ee196
--- /dev/null
+++ b/notebooks/README.md
@@ -0,0 +1,12 @@
+# Lightning Notebooks ⚡
+
+## Official Notebooks
+
+You can easily run any of the official notebooks by clicking the 'Open in Colab' links in the table below :smile:
+
+| Notebook | Description | Colab Link |
+| :--- | :--- | :---: |
+| __MNIST Hello World__ | Train your first Lightning Module on the classic MNIST Handwritten Digits Dataset. | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/01-mnist-hello-world.ipynb) |
+| __Datamodules__ | Learn about DataModules and train a dataset-agnostic model on MNIST and CIFAR10. | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/02-datamodules.ipynb) |
+| __GAN__ | Train a GAN on the MNIST Dataset. Learn how to use multiple optimizers in Lightning. | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/03-basic-gan.ipynb) |
+| __BERT__ | Fine-tune HuggingFace Transformers models on the GLUE Benchmark. | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/04-transformers-text-classification.ipynb) |