add lightning colab tutorial notebooks (#3591)
* ⚡ add lightning colab tutorial notebooks
* Update notebooks/01-mnist-hello-world.ipynb
* Update notebooks/03-basic-gan.ipynb
* Update notebooks/04-transformers-text-classification.ipynb
* Update notebooks/02-datamodules.ipynb
* add jupytext
* 🐛 comment ipynb related things in conf.py
* 🔥 remove .py files

Co-authored-by: Jirka Borovec <jirka@pytorchlightning.ai>
parent 674d33061f
commit 6685b49db3
@ -304,12 +304,12 @@ def setup(app):
# copy all notebooks to local folder
path_nbs = os.path.join(PATH_HERE, 'notebooks')
if not os.path.isdir(path_nbs):
    os.mkdir(path_nbs)
for path_ipynb in glob.glob(os.path.join(PATH_ROOT, 'notebooks', '*.ipynb')):
    path_ipynb2 = os.path.join(path_nbs, os.path.basename(path_ipynb))
    shutil.copy(path_ipynb, path_ipynb2)
# path_nbs = os.path.join(PATH_HERE, 'notebooks')
# if not os.path.isdir(path_nbs):
#     os.mkdir(path_nbs)
# for path_ipynb in glob.glob(os.path.join(PATH_ROOT, 'notebooks', '*.ipynb')):
#     path_ipynb2 = os.path.join(path_nbs, os.path.basename(path_ipynb))
#     shutil.copy(path_ipynb, path_ipynb2)


# Ignoring Third-party packages
@ -0,0 +1,401 @@
{
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"name": "01-mnist-hello-world.ipynb",
|
||||
"provenance": [],
|
||||
"collapsed_sections": [],
|
||||
"authorship_tag": "ABX9TyOtAKVa5POQ6Xg3UcTQqXDJ",
|
||||
"include_colab_link": true
|
||||
},
|
||||
"kernelspec": {
|
||||
"name": "python3",
|
||||
"display_name": "Python 3"
|
||||
},
|
||||
"accelerator": "GPU"
|
||||
},
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "view-in-github",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"<a href=\"https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/01-mnist-hello-world.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "i7XbLCXGkll9",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"# Introduction to Pytorch Lightning ⚡\n",
|
||||
"\n",
|
||||
"In this notebook, we'll go over the basics of lightning by preparing models to train on the [MNIST Handwritten Digits dataset](https://en.wikipedia.org/wiki/MNIST_database).\n",
|
||||
"\n",
|
||||
"---\n",
|
||||
" - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n",
|
||||
" - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n",
|
||||
" - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "2LODD6w9ixlT",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"### Setup \n",
|
||||
"Lightning is easy to install. Simply ```pip install pytorch-lightning```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "zK7-Gg69kMnG",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"! pip install pytorch-lightning --quiet"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "w4_TYnt_keJi",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"import torch\n",
|
||||
"from torch import nn\n",
|
||||
"from torch.nn import functional as F\n",
|
||||
"from torch.utils.data import DataLoader, random_split\n",
|
||||
"from torchvision.datasets import MNIST\n",
|
||||
"from torchvision import transforms\n",
|
||||
"import pytorch_lightning as pl\n",
|
||||
"from pytorch_lightning.metrics.functional import accuracy"
|
||||
],
|
||||
"execution_count": 2,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "EHpyMPKFkVbZ",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"## Simplest example\n",
|
||||
"\n",
|
||||
"Here's the simplest most minimal example with just a training loop (no validation, no testing).\n",
|
||||
"\n",
|
||||
"**Keep in Mind** - A `LightningModule` *is* a PyTorch `nn.Module` - it just has a few more helpful features."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "V7ELesz1kVQo",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"class MNISTModel(pl.LightningModule):\n",
|
||||
"\n",
|
||||
" def __init__(self):\n",
|
||||
" super(MNISTModel, self).__init__()\n",
|
||||
" self.l1 = torch.nn.Linear(28 * 28, 10)\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" return torch.relu(self.l1(x.view(x.size(0), -1)))\n",
|
||||
"\n",
|
||||
" def training_step(self, batch, batch_nb):\n",
|
||||
" x, y = batch\n",
|
||||
" loss = F.cross_entropy(self(x), y)\n",
|
||||
" return pl.TrainResult(loss)\n",
|
||||
"\n",
|
||||
" def configure_optimizers(self):\n",
|
||||
" return torch.optim.Adam(self.parameters(), lr=0.02)"
|
||||
],
|
||||
"execution_count": 3,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "hIrtHg-Dv8TJ",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"By using the `Trainer` you automatically get:\n",
|
||||
"1. Tensorboard logging\n",
|
||||
"2. Model checkpointing\n",
|
||||
"3. Training and validation loop\n",
|
||||
"4. early-stopping"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "4Dk6Ykv8lI7X",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"# Init our model\n",
|
||||
"mnist_model = MNISTModel()\n",
|
||||
"\n",
|
||||
"# Init DataLoader from MNIST Dataset\n",
|
||||
"train_ds = MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())\n",
|
||||
"train_loader = DataLoader(train_ds, batch_size=32)\n",
|
||||
"\n",
|
||||
"# Initialize a trainer\n",
|
||||
"trainer = pl.Trainer(gpus=1, max_epochs=3, progress_bar_refresh_rate=20)\n",
|
||||
"\n",
|
||||
"# Train the model ⚡\n",
|
||||
"trainer.fit(mnist_model, train_loader)"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "KNpOoBeIjscS",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"## A more complete MNIST Lightning Module Example\n",
|
||||
"\n",
|
||||
"That wasn't so hard was it?\n",
|
||||
"\n",
|
||||
"Now that we've got our feet wet, let's dive in a bit deeper and write a more complete `LightningModule` for MNIST...\n",
|
||||
"\n",
|
||||
"This time, we'll bake in all the dataset specific pieces directly in the `LightningModule`. This way, we can avoid writing extra code at the beginning of our script every time we want to run it.\n",
|
||||
"\n",
|
||||
"---\n",
|
||||
"\n",
|
||||
"### Note what the following built-in functions are doing:\n",
|
||||
"\n",
|
||||
"1. [prepare_data()](https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.core.lightning.html#pytorch_lightning.core.lightning.LightningModule.prepare_data) 💾\n",
|
||||
" - This is where we can download the dataset. We point to our desired dataset and ask torchvision's `MNIST` dataset class to download if the dataset isn't found there.\n",
|
||||
" - **Note we do not make any state assignments in this function** (i.e. `self.something = ...`)\n",
|
||||
"\n",
|
||||
"2. [setup(stage)](https://pytorch-lightning.readthedocs.io/en/latest/lightning-module.html#setup) ⚙️\n",
|
||||
" - Loads in data from file and prepares PyTorch tensor datasets for each split (train, val, test). \n",
|
||||
" - Setup expects a 'stage' arg which is used to separate logic for 'fit' and 'test'.\n",
|
||||
" - If you don't mind loading all your datasets at once, you can set up a condition to allow for both 'fit' related setup and 'test' related setup to run whenever `None` is passed to `stage` (or ignore it altogether and exclude any conditionals).\n",
|
||||
" - **Note this runs across all GPUs and it *is* safe to make state assignments here**\n",
|
||||
"\n",
|
||||
"3. [x_dataloader()](https://pytorch-lightning.readthedocs.io/en/latest/lightning-module.html#data-hooks) ♻️\n",
|
||||
" - `train_dataloader()`, `val_dataloader()`, and `test_dataloader()` all return PyTorch `DataLoader` instances that are created by wrapping their respective datasets that we prepared in `setup()`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "4DNItffri95Q",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"class LitMNIST(pl.LightningModule):\n",
|
||||
" \n",
|
||||
" def __init__(self, data_dir='./', hidden_size=64, learning_rate=2e-4):\n",
|
||||
"\n",
|
||||
" super().__init__()\n",
|
||||
"\n",
|
||||
" # Set our init args as class attributes\n",
|
||||
" self.data_dir = data_dir\n",
|
||||
" self.hidden_size = hidden_size\n",
|
||||
" self.learning_rate = learning_rate\n",
|
||||
"\n",
|
||||
" # Hardcode some dataset specific attributes\n",
|
||||
" self.num_classes = 10\n",
|
||||
" self.dims = (1, 28, 28)\n",
|
||||
" channels, width, height = self.dims\n",
|
||||
" self.transform = transforms.Compose([\n",
|
||||
" transforms.ToTensor(),\n",
|
||||
" transforms.Normalize((0.1307,), (0.3081,))\n",
|
||||
" ])\n",
|
||||
"\n",
|
||||
" # Define PyTorch model\n",
|
||||
" self.model = nn.Sequential(\n",
|
||||
" nn.Flatten(),\n",
|
||||
" nn.Linear(channels * width * height, hidden_size),\n",
|
||||
" nn.ReLU(),\n",
|
||||
" nn.Dropout(0.1),\n",
|
||||
" nn.Linear(hidden_size, hidden_size),\n",
|
||||
" nn.ReLU(),\n",
|
||||
" nn.Dropout(0.1),\n",
|
||||
" nn.Linear(hidden_size, self.num_classes)\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" x = self.model(x)\n",
|
||||
" return F.log_softmax(x, dim=1)\n",
|
||||
"\n",
|
||||
" def training_step(self, batch, batch_idx):\n",
|
||||
" x, y = batch\n",
|
||||
" logits = self(x)\n",
|
||||
" loss = F.nll_loss(logits, y)\n",
|
||||
" return pl.TrainResult(loss)\n",
|
||||
"\n",
|
||||
" def validation_step(self, batch, batch_idx):\n",
|
||||
" x, y = batch\n",
|
||||
" logits = self(x)\n",
|
||||
" loss = F.nll_loss(logits, y)\n",
|
||||
" preds = torch.argmax(logits, dim=1)\n",
|
||||
" acc = accuracy(preds, y)\n",
|
||||
" result = pl.EvalResult(checkpoint_on=loss)\n",
|
||||
"\n",
|
||||
" # Calling result.log will surface up scalars for you in TensorBoard\n",
|
||||
" result.log('val_loss', loss, prog_bar=True)\n",
|
||||
" result.log('val_acc', acc, prog_bar=True)\n",
|
||||
" return result\n",
|
||||
"\n",
|
||||
" def test_step(self, batch, batch_idx):\n",
|
||||
" # Here we just reuse the validation_step for testing\n",
|
||||
" return self.validation_step(batch, batch_idx)\n",
|
||||
"\n",
|
||||
" def configure_optimizers(self):\n",
|
||||
" optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
|
||||
" return optimizer\n",
|
||||
"\n",
|
||||
" ####################\n",
|
||||
" # DATA RELATED HOOKS\n",
|
||||
" ####################\n",
|
||||
"\n",
|
||||
" def prepare_data(self):\n",
|
||||
" # download\n",
|
||||
" MNIST(self.data_dir, train=True, download=True)\n",
|
||||
" MNIST(self.data_dir, train=False, download=True)\n",
|
||||
"\n",
|
||||
" def setup(self, stage=None):\n",
|
||||
"\n",
|
||||
" # Assign train/val datasets for use in dataloaders\n",
|
||||
" if stage == 'fit' or stage is None:\n",
|
||||
" mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n",
|
||||
" self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n",
|
||||
"\n",
|
||||
" # Assign test dataset for use in dataloader(s)\n",
|
||||
" if stage == 'test' or stage is None:\n",
|
||||
" self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n",
|
||||
"\n",
|
||||
" def train_dataloader(self):\n",
|
||||
" return DataLoader(self.mnist_train, batch_size=32)\n",
|
||||
"\n",
|
||||
" def val_dataloader(self):\n",
|
||||
" return DataLoader(self.mnist_val, batch_size=32)\n",
|
||||
"\n",
|
||||
" def test_dataloader(self):\n",
|
||||
" return DataLoader(self.mnist_test, batch_size=32)"
|
||||
],
|
||||
"execution_count": 5,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "Mb0U5Rk2kLBy",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"model = LitMNIST()\n",
|
||||
"trainer = pl.Trainer(gpus=1, max_epochs=3, progress_bar_refresh_rate=20)\n",
|
||||
"trainer.fit(model)"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "nht8AvMptY6I",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"### Testing\n",
|
||||
"\n",
|
||||
"To test a model, call `trainer.test(model)`.\n",
|
||||
"\n",
|
||||
"Or, if you've just trained a model, you can just call `trainer.test()` and Lightning will automatically test using the best saved checkpoint (conditioned on val_loss)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "PA151FkLtprO",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"trainer.test()"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "T3-3lbbNtr5T",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"### Bonus Tip\n",
|
||||
"\n",
|
||||
"You can keep calling `trainer.fit(model)` as many times as you'd like to continue training"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "IFBwCbLet2r6",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"trainer.fit(model)"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "8TRyS5CCt3n9",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"In Colab, you can use the TensorBoard magic function to view the logs that Lightning has created for you!"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "wizS-QiLuAYo",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"# Start tensorboard.\n",
|
||||
"%load_ext tensorboard\n",
|
||||
"%tensorboard --logdir lightning_logs/"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
}
|
||||
]
|
||||
}
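The hello-world notebook above points out that the Trainer handles model checkpointing for you. As a hedged illustration (not part of this commit), restoring such a checkpoint could look roughly like the sketch below; the file name under lightning_logs/ is a made-up example and will differ per run.

# Hypothetical sketch, not from the commit: reload a Trainer-written checkpoint.
# The path is an assumption; look under lightning_logs/version_*/checkpoints/ for the real file.
ckpt_path = 'lightning_logs/version_0/checkpoints/epoch=2.ckpt'
model = LitMNIST.load_from_checkpoint(ckpt_path)  # standard LightningModule classmethod
model.eval()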
@ -0,0 +1,542 @@
{
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"name": "02-datamodules.ipynb",
|
||||
"provenance": [],
|
||||
"collapsed_sections": [],
|
||||
"toc_visible": true,
|
||||
"include_colab_link": true
|
||||
},
|
||||
"kernelspec": {
|
||||
"name": "python3",
|
||||
"display_name": "Python 3"
|
||||
},
|
||||
"accelerator": "GPU"
|
||||
},
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "view-in-github",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"<a href=\"https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/02-datamodules.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "2O5r7QvP8-rt",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"# PyTorch Lightning DataModules ⚡\n",
|
||||
"\n",
|
||||
"With the release of `pytorch-lightning` version 0.9.0, we have included a new class called `LightningDataModule` to help you decouple data related hooks from your `LightningModule`.\n",
|
||||
"\n",
|
||||
"This notebook will walk you through how to start using Datamodules.\n",
|
||||
"\n",
|
||||
"The most up to date documentation on datamodules can be found [here](https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html).\n",
|
||||
"\n",
|
||||
"---\n",
|
||||
"\n",
|
||||
" - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n",
|
||||
" - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n",
|
||||
" - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "6RYMhmfA9ATN",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"### Setup\n",
|
||||
"Lightning is easy to install. Simply ```pip install pytorch-lightning```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "lj2zD-wsbvGr",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"! pip install pytorch-lightning --quiet"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "8g2mbvy-9xDI",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"# Introduction\n",
|
||||
"\n",
|
||||
"First, we'll go over a regular `LightningModule` implementation without the use of a `LightningDataModule`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "eg-xDlmDdAwy",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"import pytorch_lightning as pl\n",
|
||||
"from pytorch_lightning.metrics.functional import accuracy\n",
|
||||
"import torch\n",
|
||||
"from torch import nn\n",
|
||||
"import torch.nn.functional as F\n",
|
||||
"from torch.utils.data import random_split, DataLoader\n",
|
||||
"\n",
|
||||
"# Note - you must have torchvision installed for this example\n",
|
||||
"from torchvision.datasets import MNIST, CIFAR10\n",
|
||||
"from torchvision import transforms"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "DzgY7wi88UuG",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"## Defining the LitMNISTModel\n",
|
||||
"\n",
|
||||
"Below, we reuse a `LightningModule` from our hello world tutorial that classifies MNIST Handwritten Digits.\n",
|
||||
"\n",
|
||||
"Unfortunately, we have hardcoded dataset-specific items within the model, forever limiting it to working with MNIST Data. 😢\n",
|
||||
"\n",
|
||||
"This is fine if you don't plan on training/evaluating your model on different datasets. However, in many cases, this can become bothersome when you want to try out your architecture with different datasets."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "IQkW8_FF5nU2",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"class LitMNIST(pl.LightningModule):\n",
|
||||
" \n",
|
||||
" def __init__(self, data_dir='./', hidden_size=64, learning_rate=2e-4):\n",
|
||||
"\n",
|
||||
" super().__init__()\n",
|
||||
"\n",
|
||||
" # We hardcode dataset specific stuff here.\n",
|
||||
" self.data_dir = data_dir\n",
|
||||
" self.num_classes = 10\n",
|
||||
" self.dims = (1, 28, 28)\n",
|
||||
" channels, width, height = self.dims\n",
|
||||
" self.transform = transforms.Compose([\n",
|
||||
" transforms.ToTensor(),\n",
|
||||
" transforms.Normalize((0.1307,), (0.3081,))\n",
|
||||
" ])\n",
|
||||
"\n",
|
||||
" self.hidden_size = hidden_size\n",
|
||||
" self.learning_rate = learning_rate\n",
|
||||
"\n",
|
||||
" # Build model\n",
|
||||
" self.model = nn.Sequential(\n",
|
||||
" nn.Flatten(),\n",
|
||||
" nn.Linear(channels * width * height, hidden_size),\n",
|
||||
" nn.ReLU(),\n",
|
||||
" nn.Dropout(0.1),\n",
|
||||
" nn.Linear(hidden_size, hidden_size),\n",
|
||||
" nn.ReLU(),\n",
|
||||
" nn.Dropout(0.1),\n",
|
||||
" nn.Linear(hidden_size, self.num_classes)\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" x = self.model(x)\n",
|
||||
" return F.log_softmax(x, dim=1)\n",
|
||||
"\n",
|
||||
" def training_step(self, batch, batch_idx):\n",
|
||||
" x, y = batch\n",
|
||||
" logits = self(x)\n",
|
||||
" loss = F.nll_loss(logits, y)\n",
|
||||
" return pl.TrainResult(loss)\n",
|
||||
"\n",
|
||||
" def validation_step(self, batch, batch_idx):\n",
|
||||
" x, y = batch\n",
|
||||
" logits = self(x)\n",
|
||||
" loss = F.nll_loss(logits, y)\n",
|
||||
" preds = torch.argmax(logits, dim=1)\n",
|
||||
" acc = accuracy(preds, y)\n",
|
||||
" result = pl.EvalResult(checkpoint_on=loss)\n",
|
||||
" result.log('val_loss', loss, prog_bar=True)\n",
|
||||
" result.log('val_acc', acc, prog_bar=True)\n",
|
||||
" return result\n",
|
||||
"\n",
|
||||
" def configure_optimizers(self):\n",
|
||||
" optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
|
||||
" return optimizer\n",
|
||||
"\n",
|
||||
" ####################\n",
|
||||
" # DATA RELATED HOOKS\n",
|
||||
" ####################\n",
|
||||
"\n",
|
||||
" def prepare_data(self):\n",
|
||||
" # download\n",
|
||||
" MNIST(self.data_dir, train=True, download=True)\n",
|
||||
" MNIST(self.data_dir, train=False, download=True)\n",
|
||||
"\n",
|
||||
" def setup(self, stage=None):\n",
|
||||
"\n",
|
||||
" # Assign train/val datasets for use in dataloaders\n",
|
||||
" if stage == 'fit' or stage is None:\n",
|
||||
" mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n",
|
||||
" self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n",
|
||||
"\n",
|
||||
" # Assign test dataset for use in dataloader(s)\n",
|
||||
" if stage == 'test' or stage is None:\n",
|
||||
" self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n",
|
||||
"\n",
|
||||
" def train_dataloader(self):\n",
|
||||
" return DataLoader(self.mnist_train, batch_size=32)\n",
|
||||
"\n",
|
||||
" def val_dataloader(self):\n",
|
||||
" return DataLoader(self.mnist_val, batch_size=32)\n",
|
||||
"\n",
|
||||
" def test_dataloader(self):\n",
|
||||
" return DataLoader(self.mnist_test, batch_size=32)"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "K7sg9KQd-QIO",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"## Training the ListMNIST Model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "QxDNDaus6byD",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"model = LitMNIST()\n",
|
||||
"trainer = pl.Trainer(max_epochs=2, gpus=1, progress_bar_refresh_rate=20)\n",
|
||||
"trainer.fit(model)"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "dY8d6GxmB0YU",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"# Using DataModules\n",
|
||||
"\n",
|
||||
"DataModules are a way of decoupling data-related hooks from the `LightningModule` so you can develop dataset agnostic models."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "eJeT5bW081wn",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"## Defining The MNISTDataModule\n",
|
||||
"\n",
|
||||
"Let's go over each function in the class below and talk about what they're doing:\n",
|
||||
"\n",
|
||||
"1. ```__init__```\n",
|
||||
" - Takes in a `data_dir` arg that points to where you have downloaded/wish to download the MNIST dataset.\n",
|
||||
" - Defines a transform that will be applied across train, val, and test dataset splits.\n",
|
||||
" - Defines default `self.dims`, which is a tuple returned from `datamodule.size()` that can help you initialize models.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"2. ```prepare_data```\n",
|
||||
" - This is where we can download the dataset. We point to our desired dataset and ask torchvision's `MNIST` dataset class to download if the dataset isn't found there.\n",
|
||||
" - **Note we do not make any state assignments in this function** (i.e. `self.something = ...`)\n",
|
||||
"\n",
|
||||
"3. ```setup```\n",
|
||||
" - Loads in data from file and prepares PyTorch tensor datasets for each split (train, val, test). \n",
|
||||
" - Setup expects a 'stage' arg which is used to separate logic for 'fit' and 'test'.\n",
|
||||
" - If you don't mind loading all your datasets at once, you can set up a condition to allow for both 'fit' related setup and 'test' related setup to run whenever `None` is passed to `stage`.\n",
|
||||
" - **Note this runs across all GPUs and it *is* safe to make state assignments here**\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"4. ```x_dataloader```\n",
|
||||
" - `train_dataloader()`, `val_dataloader()`, and `test_dataloader()` all return PyTorch `DataLoader` instances that are created by wrapping their respective datasets that we prepared in `setup()`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "DfGKyGwG_X9v",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"class MNISTDataModule(pl.LightningDataModule):\n",
|
||||
"\n",
|
||||
" def __init__(self, data_dir: str = './'):\n",
|
||||
" super().__init__()\n",
|
||||
" self.data_dir = data_dir\n",
|
||||
" self.transform = transforms.Compose([\n",
|
||||
" transforms.ToTensor(),\n",
|
||||
" transforms.Normalize((0.1307,), (0.3081,))\n",
|
||||
" ])\n",
|
||||
"\n",
|
||||
" # self.dims is returned when you call dm.size()\n",
|
||||
" # Setting default dims here because we know them.\n",
|
||||
" # Could optionally be assigned dynamically in dm.setup()\n",
|
||||
" self.dims = (1, 28, 28)\n",
|
||||
" self.num_classes = 10\n",
|
||||
"\n",
|
||||
" def prepare_data(self):\n",
|
||||
" # download\n",
|
||||
" MNIST(self.data_dir, train=True, download=True)\n",
|
||||
" MNIST(self.data_dir, train=False, download=True)\n",
|
||||
"\n",
|
||||
" def setup(self, stage=None):\n",
|
||||
"\n",
|
||||
" # Assign train/val datasets for use in dataloaders\n",
|
||||
" if stage == 'fit' or stage is None:\n",
|
||||
" mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n",
|
||||
" self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n",
|
||||
"\n",
|
||||
" # Assign test dataset for use in dataloader(s)\n",
|
||||
" if stage == 'test' or stage is None:\n",
|
||||
" self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n",
|
||||
"\n",
|
||||
" def train_dataloader(self):\n",
|
||||
" return DataLoader(self.mnist_train, batch_size=32)\n",
|
||||
"\n",
|
||||
" def val_dataloader(self):\n",
|
||||
" return DataLoader(self.mnist_val, batch_size=32)\n",
|
||||
"\n",
|
||||
" def test_dataloader(self):\n",
|
||||
" return DataLoader(self.mnist_test, batch_size=32)"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "H2Yoj-9M9dS7",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"## Defining the dataset agnostic `LitModel`\n",
|
||||
"\n",
|
||||
"Below, we define the same model as the `LitMNIST` model we made earlier. \n",
|
||||
"\n",
|
||||
"However, this time our model has the freedom to use any input data that we'd like 🔥."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "PM2IISuOBDIu",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"class LitModel(pl.LightningModule):\n",
|
||||
" \n",
|
||||
" def __init__(self, channels, width, height, num_classes, hidden_size=64, learning_rate=2e-4):\n",
|
||||
"\n",
|
||||
" super().__init__()\n",
|
||||
"\n",
|
||||
" # We take in input dimensions as parameters and use those to dynamically build model.\n",
|
||||
" self.channels = channels\n",
|
||||
" self.width = width\n",
|
||||
" self.height = height\n",
|
||||
" self.num_classes = num_classes\n",
|
||||
" self.hidden_size = hidden_size\n",
|
||||
" self.learning_rate = learning_rate\n",
|
||||
"\n",
|
||||
" self.model = nn.Sequential(\n",
|
||||
" nn.Flatten(),\n",
|
||||
" nn.Linear(channels * width * height, hidden_size),\n",
|
||||
" nn.ReLU(),\n",
|
||||
" nn.Dropout(0.1),\n",
|
||||
" nn.Linear(hidden_size, hidden_size),\n",
|
||||
" nn.ReLU(),\n",
|
||||
" nn.Dropout(0.1),\n",
|
||||
" nn.Linear(hidden_size, num_classes)\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" x = self.model(x)\n",
|
||||
" return F.log_softmax(x, dim=1)\n",
|
||||
"\n",
|
||||
" def training_step(self, batch, batch_idx):\n",
|
||||
" x, y = batch\n",
|
||||
" logits = self(x)\n",
|
||||
" loss = F.nll_loss(logits, y)\n",
|
||||
" return pl.TrainResult(loss)\n",
|
||||
"\n",
|
||||
" def validation_step(self, batch, batch_idx):\n",
|
||||
"\n",
|
||||
" x, y = batch\n",
|
||||
" logits = self(x)\n",
|
||||
" loss = F.nll_loss(logits, y)\n",
|
||||
" preds = torch.argmax(logits, dim=1)\n",
|
||||
" acc = accuracy(preds, y)\n",
|
||||
" result = pl.EvalResult(checkpoint_on=loss)\n",
|
||||
" result.log('val_loss', loss, prog_bar=True)\n",
|
||||
" result.log('val_acc', acc, prog_bar=True)\n",
|
||||
" return result\n",
|
||||
"\n",
|
||||
" def configure_optimizers(self):\n",
|
||||
" optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
|
||||
" return optimizer"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "G4Z5olPe-xEo",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"## Training the `LitModel` using the `MNISTDataModule`\n",
|
||||
"\n",
|
||||
"Now, we initialize and train the `LitModel` using the `MNISTDataModule`'s configuration settings and dataloaders."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "kV48vP_9mEli",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"# Init DataModule\n",
|
||||
"dm = MNISTDataModule()\n",
|
||||
"# Init model from datamodule's attributes\n",
|
||||
"model = LitModel(*dm.size(), dm.num_classes)\n",
|
||||
"# Init trainer\n",
|
||||
"trainer = pl.Trainer(max_epochs=3, progress_bar_refresh_rate=20, gpus=1)\n",
|
||||
"# Pass the datamodule as arg to trainer.fit to override model hooks :)\n",
|
||||
"trainer.fit(model, dm)"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "WNxrugIGRRv5",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"## Defining the CIFAR10 DataModule\n",
|
||||
"\n",
|
||||
"Lets prove the `LitModel` we made earlier is dataset agnostic by defining a new datamodule for the CIFAR10 dataset."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "1tkaYLU7RT5P",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"class CIFAR10DataModule(pl.LightningDataModule):\n",
|
||||
"\n",
|
||||
" def __init__(self, data_dir: str = './'):\n",
|
||||
" super().__init__()\n",
|
||||
" self.data_dir = data_dir\n",
|
||||
" self.transform = transforms.Compose([\n",
|
||||
" transforms.ToTensor(),\n",
|
||||
" transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n",
|
||||
" ])\n",
|
||||
"\n",
|
||||
" self.dims = (3, 32, 32)\n",
|
||||
" self.num_classes = 10\n",
|
||||
"\n",
|
||||
" def prepare_data(self):\n",
|
||||
" # download\n",
|
||||
" CIFAR10(self.data_dir, train=True, download=True)\n",
|
||||
" CIFAR10(self.data_dir, train=False, download=True)\n",
|
||||
"\n",
|
||||
" def setup(self, stage=None):\n",
|
||||
"\n",
|
||||
" # Assign train/val datasets for use in dataloaders\n",
|
||||
" if stage == 'fit' or stage is None:\n",
|
||||
" cifar_full = CIFAR10(self.data_dir, train=True, transform=self.transform)\n",
|
||||
" self.cifar_train, self.cifar_val = random_split(cifar_full, [45000, 5000])\n",
|
||||
"\n",
|
||||
" # Assign test dataset for use in dataloader(s)\n",
|
||||
" if stage == 'test' or stage is None:\n",
|
||||
" self.cifar_test = CIFAR10(self.data_dir, train=False, transform=self.transform)\n",
|
||||
"\n",
|
||||
" def train_dataloader(self):\n",
|
||||
" return DataLoader(self.cifar_train, batch_size=32)\n",
|
||||
"\n",
|
||||
" def val_dataloader(self):\n",
|
||||
" return DataLoader(self.cifar_val, batch_size=32)\n",
|
||||
"\n",
|
||||
" def test_dataloader(self):\n",
|
||||
" return DataLoader(self.cifar_test, batch_size=32)"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "BrXxf3oX_gsZ",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"## Training the `LitModel` using the `CIFAR10DataModule`\n",
|
||||
"\n",
|
||||
"Our model isn't very good, so it will perform pretty badly on the CIFAR10 dataset.\n",
|
||||
"\n",
|
||||
"The point here is that we can see that our `LitModel` has no problem using a different datamodule as its input data."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "sd-SbWi_krdj",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"dm = CIFAR10DataModule()\n",
|
||||
"model = LitModel(*dm.size(), dm.num_classes, hidden_size=256)\n",
|
||||
"trainer = pl.Trainer(max_epochs=5, progress_bar_refresh_rate=20, gpus=1)\n",
|
||||
"trainer.fit(model, dm)"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
}
|
||||
]
|
||||
}
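Since the datamodule notebook's point is dataset-agnostic models, here is a hedged sketch (not part of this commit) of wrapping another torchvision dataset the same way. FashionMNIST shares MNIST's 1x28x28 images and 10 classes, so the `LitModel` defined above would work unchanged.

# Hypothetical sketch, not from the commit: reuse the MNISTDataModule structure for FashionMNIST.
from torchvision.datasets import FashionMNIST

class FashionMNISTDataModule(MNISTDataModule):

    def prepare_data(self):
        # download both splits up front; no state is assigned here
        FashionMNIST(self.data_dir, train=True, download=True)
        FashionMNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        if stage == 'fit' or stage is None:
            full = FashionMNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(full, [55000, 5000])
        if stage == 'test' or stage is None:
            self.mnist_test = FashionMNIST(self.data_dir, train=False, transform=self.transform)

Training would then mirror the CIFAR10 cell: `model = LitModel(*dm.size(), dm.num_classes)` followed by `trainer.fit(model, dm)`.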
@ -0,0 +1,424 @@
{
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"name": "03-basic-gan.ipynb",
|
||||
"provenance": [],
|
||||
"collapsed_sections": [],
|
||||
"include_colab_link": true
|
||||
},
|
||||
"kernelspec": {
|
||||
"name": "python3",
|
||||
"display_name": "Python 3"
|
||||
},
|
||||
"accelerator": "GPU"
|
||||
},
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "view-in-github",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"<a href=\"https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/03-basic-gan.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "J37PBnE_x7IW",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"# PyTorch Lightning Basic GAN Tutorial ⚡\n",
|
||||
"\n",
|
||||
"How to train a GAN!\n",
|
||||
"\n",
|
||||
"Main takeaways:\n",
|
||||
"1. Generator and discriminator are arbitraty PyTorch modules.\n",
|
||||
"2. training_step does both the generator and discriminator training.\n",
|
||||
"\n",
|
||||
"---\n",
|
||||
"\n",
|
||||
" - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n",
|
||||
" - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n",
|
||||
" - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "kg2MKpRmybht",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"### Setup\n",
|
||||
"Lightning is easy to install. Simply `pip install pytorch-lightning`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "LfrJLKPFyhsK",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"! pip install pytorch-lightning --quiet"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "BjEPuiVLyanw",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"import os\n",
|
||||
"from argparse import ArgumentParser\n",
|
||||
"from collections import OrderedDict\n",
|
||||
"\n",
|
||||
"import numpy as np\n",
|
||||
"import torch\n",
|
||||
"import torch.nn as nn\n",
|
||||
"import torch.nn.functional as F\n",
|
||||
"import torchvision\n",
|
||||
"import torchvision.transforms as transforms\n",
|
||||
"from torch.utils.data import DataLoader, random_split\n",
|
||||
"from torchvision.datasets import MNIST\n",
|
||||
"\n",
|
||||
"import pytorch_lightning as pl"
|
||||
],
|
||||
"execution_count": 2,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "OuXJzr4G2uHV",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"### MNIST DataModule\n",
|
||||
"\n",
|
||||
"Below, we define a DataModule for the MNIST Dataset. To learn more about DataModules, check out our tutorial on them or see the [latest docs](https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "DOY_nHu328g7",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"class MNISTDataModule(pl.LightningDataModule):\n",
|
||||
"\n",
|
||||
" def __init__(self, data_dir: str = './', batch_size: int = 64, num_workers: int = 8):\n",
|
||||
" super().__init__()\n",
|
||||
" self.data_dir = data_dir\n",
|
||||
" self.batch_size = batch_size\n",
|
||||
" self.num_workers = num_workers\n",
|
||||
"\n",
|
||||
" self.transform = transforms.Compose([\n",
|
||||
" transforms.ToTensor(),\n",
|
||||
" transforms.Normalize((0.1307,), (0.3081,))\n",
|
||||
" ])\n",
|
||||
"\n",
|
||||
" # self.dims is returned when you call dm.size()\n",
|
||||
" # Setting default dims here because we know them.\n",
|
||||
" # Could optionally be assigned dynamically in dm.setup()\n",
|
||||
" self.dims = (1, 28, 28)\n",
|
||||
" self.num_classes = 10\n",
|
||||
"\n",
|
||||
" def prepare_data(self):\n",
|
||||
" # download\n",
|
||||
" MNIST(self.data_dir, train=True, download=True)\n",
|
||||
" MNIST(self.data_dir, train=False, download=True)\n",
|
||||
"\n",
|
||||
" def setup(self, stage=None):\n",
|
||||
"\n",
|
||||
" # Assign train/val datasets for use in dataloaders\n",
|
||||
" if stage == 'fit' or stage is None:\n",
|
||||
" mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n",
|
||||
" self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n",
|
||||
"\n",
|
||||
" # Assign test dataset for use in dataloader(s)\n",
|
||||
" if stage == 'test' or stage is None:\n",
|
||||
" self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n",
|
||||
"\n",
|
||||
" def train_dataloader(self):\n",
|
||||
" return DataLoader(self.mnist_train, batch_size=self.batch_size, num_workers=self.num_workers)\n",
|
||||
"\n",
|
||||
" def val_dataloader(self):\n",
|
||||
" return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=self.num_workers)\n",
|
||||
"\n",
|
||||
" def test_dataloader(self):\n",
|
||||
" return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=self.num_workers)"
|
||||
],
|
||||
"execution_count": 3,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "tW3c0QrQyF9P",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"### A. Generator"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "0E2QDjl5yWtz",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"class Generator(nn.Module):\n",
|
||||
" def __init__(self, latent_dim, img_shape):\n",
|
||||
" super().__init__()\n",
|
||||
" self.img_shape = img_shape\n",
|
||||
"\n",
|
||||
" def block(in_feat, out_feat, normalize=True):\n",
|
||||
" layers = [nn.Linear(in_feat, out_feat)]\n",
|
||||
" if normalize:\n",
|
||||
" layers.append(nn.BatchNorm1d(out_feat, 0.8))\n",
|
||||
" layers.append(nn.LeakyReLU(0.2, inplace=True))\n",
|
||||
" return layers\n",
|
||||
"\n",
|
||||
" self.model = nn.Sequential(\n",
|
||||
" *block(latent_dim, 128, normalize=False),\n",
|
||||
" *block(128, 256),\n",
|
||||
" *block(256, 512),\n",
|
||||
" *block(512, 1024),\n",
|
||||
" nn.Linear(1024, int(np.prod(img_shape))),\n",
|
||||
" nn.Tanh()\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" def forward(self, z):\n",
|
||||
" img = self.model(z)\n",
|
||||
" img = img.view(img.size(0), *self.img_shape)\n",
|
||||
" return img"
|
||||
],
|
||||
"execution_count": 4,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "uyrltsGvyaI3",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"### B. Discriminator"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "Ed3MR3vnyxyW",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"class Discriminator(nn.Module):\n",
|
||||
" def __init__(self, img_shape):\n",
|
||||
" super().__init__()\n",
|
||||
"\n",
|
||||
" self.model = nn.Sequential(\n",
|
||||
" nn.Linear(int(np.prod(img_shape)), 512),\n",
|
||||
" nn.LeakyReLU(0.2, inplace=True),\n",
|
||||
" nn.Linear(512, 256),\n",
|
||||
" nn.LeakyReLU(0.2, inplace=True),\n",
|
||||
" nn.Linear(256, 1),\n",
|
||||
" nn.Sigmoid(),\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" def forward(self, img):\n",
|
||||
" img_flat = img.view(img.size(0), -1)\n",
|
||||
" validity = self.model(img_flat)\n",
|
||||
"\n",
|
||||
" return validity"
|
||||
],
|
||||
"execution_count": 5,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "BwUMom3ryySK",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"### C. GAN\n",
|
||||
"\n",
|
||||
"#### A couple of cool features to check out in this example...\n",
|
||||
"\n",
|
||||
" - We use `some_tensor.type_as(another_tensor)` to make sure we initialize new tensors on the right device (i.e. GPU, CPU).\n",
|
||||
" - Lightning will put your dataloader data on the right device automatically\n",
|
||||
" - In this example, we pull from latent dim on the fly, so we need to dynamically add tensors to the right device.\n",
|
||||
" - `type_as` is the way we recommend to do this.\n",
|
||||
" - This example shows how to use multiple dataloaders in your `LightningModule`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "3vKszYf6y1Vv",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
" class GAN(pl.LightningModule):\n",
|
||||
"\n",
|
||||
" def __init__(\n",
|
||||
" self,\n",
|
||||
" channels,\n",
|
||||
" width,\n",
|
||||
" height,\n",
|
||||
" latent_dim: int = 100,\n",
|
||||
" lr: float = 0.0002,\n",
|
||||
" b1: float = 0.5,\n",
|
||||
" b2: float = 0.999,\n",
|
||||
" batch_size: int = 64,\n",
|
||||
" **kwargs\n",
|
||||
" ):\n",
|
||||
" super().__init__()\n",
|
||||
" self.save_hyperparameters()\n",
|
||||
"\n",
|
||||
" # networks\n",
|
||||
" data_shape = (channels, width, height)\n",
|
||||
" self.generator = Generator(latent_dim=self.hparams.latent_dim, img_shape=data_shape)\n",
|
||||
" self.discriminator = Discriminator(img_shape=data_shape)\n",
|
||||
"\n",
|
||||
" self.validation_z = torch.randn(8, self.hparams.latent_dim)\n",
|
||||
"\n",
|
||||
" self.example_input_array = torch.zeros(2, self.hparams.latent_dim)\n",
|
||||
"\n",
|
||||
" def forward(self, z):\n",
|
||||
" return self.generator(z)\n",
|
||||
"\n",
|
||||
" def adversarial_loss(self, y_hat, y):\n",
|
||||
" return F.binary_cross_entropy(y_hat, y)\n",
|
||||
"\n",
|
||||
" def training_step(self, batch, batch_idx, optimizer_idx):\n",
|
||||
" imgs, _ = batch\n",
|
||||
"\n",
|
||||
" # sample noise\n",
|
||||
" z = torch.randn(imgs.shape[0], self.hparams.latent_dim)\n",
|
||||
" z = z.type_as(imgs)\n",
|
||||
"\n",
|
||||
" # train generator\n",
|
||||
" if optimizer_idx == 0:\n",
|
||||
"\n",
|
||||
" # generate images\n",
|
||||
" self.generated_imgs = self(z)\n",
|
||||
"\n",
|
||||
" # log sampled images\n",
|
||||
" sample_imgs = self.generated_imgs[:6]\n",
|
||||
" grid = torchvision.utils.make_grid(sample_imgs)\n",
|
||||
" self.logger.experiment.add_image('generated_images', grid, 0)\n",
|
||||
"\n",
|
||||
" # ground truth result (ie: all fake)\n",
|
||||
" # put on GPU because we created this tensor inside training_loop\n",
|
||||
" valid = torch.ones(imgs.size(0), 1)\n",
|
||||
" valid = valid.type_as(imgs)\n",
|
||||
"\n",
|
||||
" # adversarial loss is binary cross-entropy\n",
|
||||
" g_loss = self.adversarial_loss(self.discriminator(self(z)), valid)\n",
|
||||
" tqdm_dict = {'g_loss': g_loss}\n",
|
||||
" output = OrderedDict({\n",
|
||||
" 'loss': g_loss,\n",
|
||||
" 'progress_bar': tqdm_dict,\n",
|
||||
" 'log': tqdm_dict\n",
|
||||
" })\n",
|
||||
" return output\n",
|
||||
"\n",
|
||||
" # train discriminator\n",
|
||||
" if optimizer_idx == 1:\n",
|
||||
" # Measure discriminator's ability to classify real from generated samples\n",
|
||||
"\n",
|
||||
" # how well can it label as real?\n",
|
||||
" valid = torch.ones(imgs.size(0), 1)\n",
|
||||
" valid = valid.type_as(imgs)\n",
|
||||
"\n",
|
||||
" real_loss = self.adversarial_loss(self.discriminator(imgs), valid)\n",
|
||||
"\n",
|
||||
" # how well can it label as fake?\n",
|
||||
" fake = torch.zeros(imgs.size(0), 1)\n",
|
||||
" fake = fake.type_as(imgs)\n",
|
||||
"\n",
|
||||
" fake_loss = self.adversarial_loss(\n",
|
||||
" self.discriminator(self(z).detach()), fake)\n",
|
||||
"\n",
|
||||
" # discriminator loss is the average of these\n",
|
||||
" d_loss = (real_loss + fake_loss) / 2\n",
|
||||
" tqdm_dict = {'d_loss': d_loss}\n",
|
||||
" output = OrderedDict({\n",
|
||||
" 'loss': d_loss,\n",
|
||||
" 'progress_bar': tqdm_dict,\n",
|
||||
" 'log': tqdm_dict\n",
|
||||
" })\n",
|
||||
" return output\n",
|
||||
"\n",
|
||||
" def configure_optimizers(self):\n",
|
||||
" lr = self.hparams.lr\n",
|
||||
" b1 = self.hparams.b1\n",
|
||||
" b2 = self.hparams.b2\n",
|
||||
"\n",
|
||||
" opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))\n",
|
||||
" opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))\n",
|
||||
" return [opt_g, opt_d], []\n",
|
||||
"\n",
|
||||
" def on_epoch_end(self):\n",
|
||||
" z = self.validation_z.type_as(self.generator.model[0].weight)\n",
|
||||
"\n",
|
||||
" # log sampled images\n",
|
||||
" sample_imgs = self(z)\n",
|
||||
" grid = torchvision.utils.make_grid(sample_imgs)\n",
|
||||
" self.logger.experiment.add_image('generated_images', grid, self.current_epoch)"
|
||||
],
|
||||
"execution_count": 6,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "Ey5FmJPnzm_E",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"dm = MNISTDataModule()\n",
|
||||
"model = GAN(*dm.size())\n",
|
||||
"trainer = pl.Trainer(gpus=1, max_epochs=5, progress_bar_refresh_rate=20)\n",
|
||||
"trainer.fit(model, dm)"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "MlECc7cHzolp",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"# Start tensorboard.\n",
|
||||
"%load_ext tensorboard\n",
|
||||
"%tensorboard --logdir lightning_logs/"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
}
|
||||
]
|
||||
}
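As a follow-up to the GAN notebook, a minimal hedged sketch (not part of this commit) of drawing samples from the generator once `trainer.fit(model, dm)` has run; it only uses attributes defined in the cells above.

# Hypothetical sketch, not from the commit: sample images from the fitted generator.
z = torch.randn(16, model.hparams.latent_dim)
z = z.type_as(model.generator.model[0].weight)  # put the noise on the generator's device
with torch.no_grad():
    samples = model(z)  # forward() delegates to the generator; outputs lie in [-1, 1] from the Tanh
grid = torchvision.utils.make_grid(samples)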
@ -0,0 +1,556 @@
{
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"name": "04-transformers-text-classification.ipynb",
|
||||
"provenance": [],
|
||||
"collapsed_sections": [],
|
||||
"toc_visible": true,
|
||||
"include_colab_link": true
|
||||
},
|
||||
"kernelspec": {
|
||||
"name": "python3",
|
||||
"display_name": "Python 3"
|
||||
},
|
||||
"accelerator": "GPU"
|
||||
},
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "view-in-github",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"<a href=\"https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/04-transformers-text-classification.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "8ag5ANQPJ_j9",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"# Finetune 🤗 Transformers Models with PyTorch Lightning ⚡\n",
|
||||
"\n",
|
||||
"This notebook will use HuggingFace's `datasets` library to get data, which will be wrapped in a `LightningDataModule`. Then, we write a class to perform text classification on any dataset from the[ GLUE Benchmark](https://gluebenchmark.com/). (We just show CoLA and MRPC due to constraint on compute/disk)\n",
|
||||
"\n",
|
||||
"[HuggingFace's NLP Viewer](https://huggingface.co/nlp/viewer/?dataset=glue&config=cola) can help you get a feel for the two datasets we will use and what tasks they are solving for.\n",
|
||||
"\n",
|
||||
"---\n",
|
||||
" - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n",
|
||||
" - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n",
|
||||
" - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)\n",
|
||||
"\n",
|
||||
" - [HuggingFace nlp](https://github.com/huggingface/nlp)\n",
|
||||
" - [HuggingFace transformers](https://github.com/huggingface/transformers)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "fqlsVTj7McZ3",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"### Setup"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "OIhHrRL-MnKK",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"!pip install pytorch-lightning datasets transformers"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "6yuQT_ZQMpCg",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"from argparse import ArgumentParser\n",
|
||||
"from datetime import datetime\n",
|
||||
"from typing import Optional\n",
|
||||
"\n",
|
||||
"import nlp\n",
|
||||
"import numpy as np\n",
|
||||
"import pytorch_lightning as pl\n",
|
||||
"import torch\n",
|
||||
"from torch.utils.data import DataLoader\n",
|
||||
"from transformers import (\n",
|
||||
" AdamW,\n",
|
||||
" AutoModelForSequenceClassification,\n",
|
||||
" AutoConfig,\n",
|
||||
" AutoTokenizer,\n",
|
||||
" get_linear_schedule_with_warmup,\n",
|
||||
" glue_compute_metrics\n",
|
||||
")"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "9ORJfiuiNZ_N",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"## GLUE DataModule"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "jW9xQhZxMz1G",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"class GLUEDataModule(pl.LightningDataModule):\n",
|
||||
"\n",
|
||||
" task_text_field_map = {\n",
|
||||
" 'cola': ['sentence'],\n",
|
||||
" 'sst2': ['sentence'],\n",
|
||||
" 'mrpc': ['sentence1', 'sentence2'],\n",
|
||||
" 'qqp': ['question1', 'question2'],\n",
|
||||
" 'stsb': ['sentence1', 'sentence2'],\n",
|
||||
" 'mnli': ['premise', 'hypothesis'],\n",
|
||||
" 'qnli': ['question', 'sentence'],\n",
|
||||
" 'rte': ['sentence1', 'sentence2'],\n",
|
||||
" 'wnli': ['sentence1', 'sentence2'],\n",
|
||||
" 'ax': ['premise', 'hypothesis']\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" glue_task_num_labels = {\n",
|
||||
" 'cola': 2,\n",
|
||||
" 'sst2': 2,\n",
|
||||
" 'mrpc': 2,\n",
|
||||
" 'qqp': 2,\n",
|
||||
" 'stsb': 1,\n",
|
||||
" 'mnli': 3,\n",
|
||||
" 'qnli': 2,\n",
|
||||
" 'rte': 2,\n",
|
||||
" 'wnli': 2,\n",
|
||||
" 'ax': 3\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" loader_columns = [\n",
|
||||
" 'nlp_idx',\n",
|
||||
" 'input_ids',\n",
|
||||
" 'token_type_ids',\n",
|
||||
" 'attention_mask',\n",
|
||||
" 'start_positions',\n",
|
||||
" 'end_positions',\n",
|
||||
" 'labels'\n",
|
||||
" ]\n",
|
||||
"\n",
|
||||
" def __init__(\n",
|
||||
" self,\n",
|
||||
" model_name_or_path: str,\n",
|
||||
" task_name: str ='mrpc',\n",
|
||||
" max_seq_length: int = 128,\n",
|
||||
" train_batch_size: int = 32,\n",
|
||||
" eval_batch_size: int = 32,\n",
|
||||
" **kwargs\n",
|
||||
" ):\n",
|
||||
" super().__init__()\n",
|
||||
" self.model_name_or_path = model_name_or_path\n",
|
||||
" self.task_name = task_name\n",
|
||||
" self.max_seq_length = max_seq_length\n",
|
||||
" self.train_batch_size = train_batch_size\n",
|
||||
" self.eval_batch_size = eval_batch_size\n",
|
||||
"\n",
|
||||
" self.text_fields = self.task_text_field_map[task_name]\n",
|
||||
" self.num_labels = self.glue_task_num_labels[task_name]\n",
|
||||
" self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)\n",
|
||||
"\n",
|
||||
" def setup(self, stage):\n",
|
||||
" self.dataset = nlp.load_dataset('glue', self.task_name)\n",
|
||||
"\n",
|
||||
" for split in self.dataset.keys():\n",
|
||||
" self.dataset[split] = self.dataset[split].map(\n",
|
||||
" self.convert_to_features,\n",
|
||||
" batched=True,\n",
|
||||
" remove_columns=['label'],\n",
|
||||
" )\n",
|
||||
" self.columns = [c for c in self.dataset[split].column_names if c in self.loader_columns]\n",
|
||||
" self.dataset[split].set_format(type=\"torch\", columns=self.columns)\n",
|
||||
"\n",
|
||||
" self.eval_splits = [x for x in self.dataset.keys() if 'validation' in x]\n",
|
||||
"\n",
|
||||
" def prepare_data(self):\n",
|
||||
" nlp.load_dataset('glue', self.task_name)\n",
|
||||
" AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)\n",
|
||||
" \n",
|
||||
" def train_dataloader(self):\n",
|
||||
" return DataLoader(self.dataset['train'], batch_size=self.train_batch_size)\n",
|
||||
" \n",
|
||||
" def val_dataloader(self):\n",
|
||||
" if len(self.eval_splits) == 1:\n",
|
||||
" return DataLoader(self.dataset['validation'], batch_size=self.eval_batch_size)\n",
|
||||
" elif len(self.eval_splits) > 1:\n",
|
||||
" return [DataLoader(self.dataset[x], batch_size=self.eval_batch_size) for x in self.eval_splits]\n",
|
||||
"\n",
|
||||
" def test_dataloader(self):\n",
|
||||
" if len(self.eval_splits) == 1:\n",
|
||||
" return DataLoader(self.dataset['test'], batch_size=self.eval_batch_size)\n",
|
||||
" elif len(self.eval_splits) > 1:\n",
|
||||
" return [DataLoader(self.dataset[x], batch_size=self.eval_batch_size) for x in self.eval_splits]\n",
|
||||
"\n",
|
||||
" def convert_to_features(self, example_batch, indices=None):\n",
|
||||
"\n",
|
||||
" # Either encode single sentence or sentence pairs\n",
|
||||
" if len(self.text_fields) > 1:\n",
|
||||
" texts_or_text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]]))\n",
|
||||
" else:\n",
|
||||
" texts_or_text_pairs = example_batch[self.text_fields[0]]\n",
|
||||
"\n",
|
||||
" # Tokenize the text/text pairs\n",
|
||||
" features = self.tokenizer.batch_encode_plus(\n",
|
||||
" texts_or_text_pairs,\n",
|
||||
" max_length=self.max_seq_length,\n",
|
||||
" pad_to_max_length=True,\n",
|
||||
" truncation=True\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" # Rename label to labels to make it easier to pass to model forward\n",
|
||||
" features['labels'] = example_batch['label']\n",
|
||||
"\n",
|
||||
" return features"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "jQC3a6KuOpX3",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"#### You could use this datamodule with standalone PyTorch if you wanted..."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "JCMH3IAsNffF",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"dm = GLUEDataModule('distilbert-base-uncased')\n",
|
||||
"dm.prepare_data()\n",
|
||||
"dm.setup('fit')\n",
|
||||
"next(iter(dm.train_dataloader()))"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "l9fQ_67BO2Lj",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"## GLUE Model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "gtn5YGKYO65B",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"class GLUETransformer(pl.LightningModule):\n",
|
||||
" def __init__(\n",
|
||||
" self,\n",
|
||||
" model_name_or_path: str,\n",
|
||||
" num_labels: int,\n",
|
||||
" learning_rate: float = 2e-5,\n",
|
||||
" adam_epsilon: float = 1e-8,\n",
|
||||
" warmup_steps: int = 0,\n",
|
||||
" weight_decay: float = 0.0,\n",
|
||||
" train_batch_size: int = 32,\n",
|
||||
" eval_batch_size: int = 32,\n",
|
||||
" eval_splits: Optional[list] = None,\n",
|
||||
" **kwargs\n",
|
||||
" ):\n",
|
||||
" super().__init__()\n",
|
||||
"\n",
|
||||
" self.save_hyperparameters()\n",
|
||||
"\n",
|
||||
" self.config = AutoConfig.from_pretrained(model_name_or_path, num_labels=num_labels)\n",
|
||||
" self.model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, config=self.config)\n",
|
||||
" self.metric = nlp.load_metric(\n",
|
||||
" 'glue',\n",
|
||||
" self.hparams.task_name,\n",
|
||||
" experiment_id=datetime.now().strftime(\"%d-%m-%Y_%H-%M-%S\")\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" def forward(self, **inputs):\n",
|
||||
" return self.model(**inputs)\n",
|
||||
"\n",
|
||||
" def training_step(self, batch, batch_idx):\n",
|
||||
" outputs = self(**batch)\n",
|
||||
" loss = outputs[0]\n",
|
||||
" return pl.TrainResult(loss)\n",
|
||||
"\n",
|
||||
" def validation_step(self, batch, batch_idx, dataloader_idx=0):\n",
|
||||
" outputs = self(**batch)\n",
|
||||
" val_loss, logits = outputs[:2]\n",
|
||||
"\n",
|
||||
" if self.hparams.num_labels >= 1:\n",
"            preds = torch.argmax(logits, axis=1)\n",
"        elif self.hparams.num_labels == 1:\n",
"            preds = logits.squeeze()\n",
"\n",
"        labels = batch[\"labels\"]\n",
"\n",
"        return {'loss': val_loss, \"preds\": preds, \"labels\": labels}\n",
"\n",
"    def validation_epoch_end(self, outputs):\n",
"        if self.hparams.task_name == 'mnli':\n",
"            for i, output in enumerate(outputs):\n",
"                # matched or mismatched\n",
"                split = self.hparams.eval_splits[i].split('_')[-1]\n",
"                preds = torch.cat([x['preds'] for x in output]).detach().cpu().numpy()\n",
"                labels = torch.cat([x['labels'] for x in output]).detach().cpu().numpy()\n",
"                loss = torch.stack([x['loss'] for x in output]).mean()\n",
"                if i == 0:\n",
"                    result = pl.EvalResult(checkpoint_on=loss)\n",
"                result.log(f'val_loss_{split}', loss, prog_bar=True)\n",
"                split_metrics = {f\"{k}_{split}\": v for k, v in self.metric.compute(preds, labels).items()}\n",
"                result.log_dict(split_metrics, prog_bar=True)\n",
"            return result\n",
"\n",
"        preds = torch.cat([x['preds'] for x in outputs]).detach().cpu().numpy()\n",
"        labels = torch.cat([x['labels'] for x in outputs]).detach().cpu().numpy()\n",
"        loss = torch.stack([x['loss'] for x in outputs]).mean()\n",
"        result = pl.EvalResult(checkpoint_on=loss)\n",
"        result.log('val_loss', loss, prog_bar=True)\n",
"        result.log_dict(self.metric.compute(preds, labels), prog_bar=True)\n",
"        return result\n",
"\n",
"    def setup(self, stage):\n",
"        if stage == 'fit':\n",
"            # Get dataloader by calling it - train_dataloader() is called after setup() by default\n",
"            train_loader = self.train_dataloader()\n",
"\n",
"            # Calculate total steps\n",
"            self.total_steps = (\n",
"                (len(train_loader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.gpus)))\n",
"                // self.hparams.accumulate_grad_batches\n",
"                * float(self.hparams.max_epochs)\n",
"            )\n",
"\n",
"    def configure_optimizers(self):\n",
"        \"Prepare optimizer and schedule (linear warmup and decay)\"\n",
"        model = self.model\n",
"        no_decay = [\"bias\", \"LayerNorm.weight\"]\n",
"        optimizer_grouped_parameters = [\n",
"            {\n",
"                \"params\": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],\n",
"                \"weight_decay\": self.hparams.weight_decay,\n",
"            },\n",
"            {\n",
"                \"params\": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],\n",
"                \"weight_decay\": 0.0,\n",
"            },\n",
"        ]\n",
"        optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)\n",
"\n",
"        scheduler = get_linear_schedule_with_warmup(\n",
"            optimizer, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps\n",
"        )\n",
"        scheduler = {\n",
"            'scheduler': scheduler,\n",
"            'interval': 'step',\n",
"            'frequency': 1\n",
"        }\n",
"        return [optimizer], [scheduler]\n",
"\n",
"    @staticmethod\n",
"    def add_model_specific_args(parent_parser):\n",
"        parser = ArgumentParser(parents=[parent_parser], add_help=False)\n",
"        parser.add_argument(\"--learning_rate\", default=2e-5, type=float)\n",
"        parser.add_argument(\"--adam_epsilon\", default=1e-8, type=float)\n",
"        parser.add_argument(\"--warmup_steps\", default=0, type=int)\n",
"        parser.add_argument(\"--weight_decay\", default=0.0, type=float)\n",
"        return parser"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "ha-NdIP_xbd3",
"colab_type": "text"
},
"source": [
"### ⚡ Quick Tip \n",
" - Combine arguments from your DataModule, Model, and Trainer into one for easy and robust configuration"
]
},
{
"cell_type": "code",
"metadata": {
"id": "3dEHnl3RPlAR",
"colab_type": "code",
"colab": {}
},
"source": [
"def parse_args(args=None):\n",
"    parser = ArgumentParser()\n",
"    parser = pl.Trainer.add_argparse_args(parser)\n",
"    parser = GLUEDataModule.add_argparse_args(parser)\n",
"    parser = GLUETransformer.add_model_specific_args(parser)\n",
"    parser.add_argument('--seed', type=int, default=42)\n",
"    return parser.parse_args(args)\n",
"\n",
"\n",
"def main(args):\n",
"    pl.seed_everything(args.seed)\n",
"    dm = GLUEDataModule.from_argparse_args(args)\n",
"    dm.prepare_data()\n",
"    dm.setup('fit')\n",
"    model = GLUETransformer(num_labels=dm.num_labels, eval_splits=dm.eval_splits, **vars(args))\n",
"    trainer = pl.Trainer.from_argparse_args(args)\n",
"    return dm, model, trainer"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "PkuLaeec3sJ-",
"colab_type": "text"
},
"source": [
"# Training"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QSpueK5UPsN7",
"colab_type": "text"
},
"source": [
"## CoLA\n",
"\n",
"See an interactive view of the CoLA dataset in [NLP Viewer](https://huggingface.co/nlp/viewer/?dataset=glue&config=cola)"
]
},
{
"cell_type": "code",
"metadata": {
"id": "NJnFmtpnPu0Y",
"colab_type": "code",
"colab": {}
},
"source": [
"mocked_args = \"\"\"\n",
"    --model_name_or_path albert-base-v2\n",
"    --task_name cola\n",
"    --max_epochs 3\n",
"    --gpus 1\"\"\".split()\n",
"\n",
"args = parse_args(mocked_args)\n",
"dm, model, trainer = main(args)\n",
"trainer.fit(model, dm)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "_MrNsTnqdz4z",
"colab_type": "text"
},
"source": [
"## MRPC\n",
"\n",
"See an interactive view of the MRPC dataset in [NLP Viewer](https://huggingface.co/nlp/viewer/?dataset=glue&config=mrpc)"
]
},
{
"cell_type": "code",
"metadata": {
"id": "LBwRxg9Cb3d-",
"colab_type": "code",
"colab": {}
},
"source": [
"mocked_args = \"\"\"\n",
"    --model_name_or_path distilbert-base-cased\n",
"    --task_name mrpc\n",
"    --max_epochs 3\n",
"    --gpus 1\"\"\".split()\n",
"\n",
"args = parse_args(mocked_args)\n",
"dm, model, trainer = main(args)\n",
"trainer.fit(model, dm)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "iZhbn0HzfdCu",
"colab_type": "text"
},
"source": [
"## MNLI\n",
"\n",
" - The MNLI dataset is huge, so we aren't going to bother trying to train it here.\n",
"\n",
" - Let's just make sure our multi-dataloader logic is right by skipping over training and going straight to validation.\n",
"\n",
"See an interactive view of the MRPC dataset in [NLP Viewer](https://huggingface.co/nlp/viewer/?dataset=glue&config=mnli)"
]
},
{
"cell_type": "code",
"metadata": {
"id": "AvsZMOggfcWW",
"colab_type": "code",
"colab": {}
},
"source": [
"mocked_args = \"\"\"\n",
"    --model_name_or_path distilbert-base-uncased\n",
"    --task_name mnli\n",
"    --max_epochs 1\n",
"    --gpus 1\n",
"    --limit_train_batches 10\n",
"    --progress_bar_refresh_rate 20\"\"\".split()\n",
"\n",
"args = parse_args(mocked_args)\n",
"dm, model, trainer = main(args)\n",
"trainer.fit(model, dm)"
],
"execution_count": null,
"outputs": []
}
]
}
@ -0,0 +1,12 @@
# Lightning Notebooks ⚡

## Official Notebooks

You can easily run any of the official notebooks by clicking the 'Open in Colab' links in the table below :smile:

| Notebook | Description | Colab Link |
| :--- | :--- | :---: |
| __MNIST Hello World__ | Train your first Lightning Module on the classic MNIST Handwritten Digits Dataset. | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/01_mnist_hello_world.ipynb) |
| __Datamodules__ | Learn about DataModules and train a dataset-agnostic model on MNIST and CIFAR10. | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/02_datamodules.ipynb) |
| __GAN__ | Train a GAN on the MNIST Dataset. Learn how to use multiple optimizers in Lightning. | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/03_basic_gan.ipynb) |
| __BERT__ | Fine-tune HuggingFace Transformers models on the GLUE Benchmark. | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/04_transformers_text_classification.ipynb) |
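
If you prefer to work outside Colab, a minimal sketch of executing one of the notebooks headlessly is shown below. This is not part of the repo's tooling: it assumes a local clone with `pytorch-lightning`, `jupyter`, `nbformat`, and `nbconvert` installed, and the notebook path, timeout, and output filename are illustrative.

```python
# Minimal sketch: run an official notebook end-to-end outside Colab.
# Assumes a local clone of the repository and installed dependencies
# (pytorch-lightning, jupyter, nbformat, nbconvert); names below are illustrative.
import nbformat
from nbconvert.preprocessors import ExecutePreprocessor

# Load the notebook as an in-memory nbformat object
nb = nbformat.read("notebooks/01_mnist_hello_world.ipynb", as_version=4)

# Execute every cell top-to-bottom with a Python 3 kernel,
# using the notebooks/ folder as the working directory
ExecutePreprocessor(timeout=600, kernel_name="python3").preprocess(
    nb, {"metadata": {"path": "notebooks/"}}
)

# Save an executed copy alongside the original
nbformat.write(nb, "notebooks/01_mnist_hello_world_executed.ipynb")
```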