add congratulations at the end of our notebooks (#4555)

* add congratulations at the end of our notebooks

* update image
chaton authored on 2020-11-07 12:05:29 +00:00, committed by GitHub
parent 6e5f232f5c
commit 854c13673b
5 changed files with 5004 additions and 4764 deletions

notebooks/01-mnist-hello-world.ipynb

@ -1,400 +1,448 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "view-in-github"
},
"source": [
"<a href=\"https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/01-mnist-hello-world.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/01-mnist-hello-world.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "i7XbLCXGkll9",
"colab_type": "text"
},
"source": [
"# Introduction to Pytorch Lightning ⚡\n",
"\n",
"In this notebook, we'll go over the basics of lightning by preparing models to train on the [MNIST Handwritten Digits dataset](https://en.wikipedia.org/wiki/MNIST_database).\n",
"\n",
"---\n",
" - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n",
" - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n",
" - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2LODD6w9ixlT",
"colab_type": "text"
},
"source": [
"### Setup \n",
"Lightning is easy to install. Simply ```pip install pytorch-lightning```"
]
},
{
"cell_type": "code",
"metadata": {
"id": "zK7-Gg69kMnG",
"colab_type": "code",
"colab": {}
},
"source": [
"! pip install pytorch-lightning --quiet"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "w4_TYnt_keJi",
"colab_type": "code",
"colab": {}
},
"source": [
"import os\n",
"\n",
"import torch\n",
"from torch import nn\n",
"from torch.nn import functional as F\n",
"from torch.utils.data import DataLoader, random_split\n",
"from torchvision.datasets import MNIST\n",
"from torchvision import transforms\n",
"import pytorch_lightning as pl\n",
"from pytorch_lightning.metrics.functional import accuracy"
],
"execution_count": 2,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "EHpyMPKFkVbZ",
"colab_type": "text"
},
"source": [
"## Simplest example\n",
"\n",
"Here's the simplest most minimal example with just a training loop (no validation, no testing).\n",
"\n",
"**Keep in Mind** - A `LightningModule` *is* a PyTorch `nn.Module` - it just has a few more helpful features."
]
},
{
"cell_type": "code",
"metadata": {
"id": "V7ELesz1kVQo",
"colab_type": "code",
"colab": {}
},
"source": [
"class MNISTModel(pl.LightningModule):\n",
"\n",
" def __init__(self):\n",
" super(MNISTModel, self).__init__()\n",
" self.l1 = torch.nn.Linear(28 * 28, 10)\n",
"\n",
" def forward(self, x):\n",
" return torch.relu(self.l1(x.view(x.size(0), -1)))\n",
"\n",
" def training_step(self, batch, batch_nb):\n",
" x, y = batch\n",
" loss = F.cross_entropy(self(x), y)\n",
" return loss\n",
"\n",
" def configure_optimizers(self):\n",
" return torch.optim.Adam(self.parameters(), lr=0.02)"
],
"execution_count": 3,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "hIrtHg-Dv8TJ",
"colab_type": "text"
},
"source": [
"By using the `Trainer` you automatically get:\n",
"1. Tensorboard logging\n",
"2. Model checkpointing\n",
"3. Training and validation loop\n",
"4. early-stopping"
]
},
{
"cell_type": "code",
"metadata": {
"id": "4Dk6Ykv8lI7X",
"colab_type": "code",
"colab": {}
},
"source": [
"# Init our model\n",
"mnist_model = MNISTModel()\n",
"\n",
"# Init DataLoader from MNIST Dataset\n",
"train_ds = MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())\n",
"train_loader = DataLoader(train_ds, batch_size=32)\n",
"\n",
"# Initialize a trainer\n",
"trainer = pl.Trainer(gpus=1, max_epochs=3, progress_bar_refresh_rate=20)\n",
"\n",
"# Train the model ⚡\n",
"trainer.fit(mnist_model, train_loader)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "KNpOoBeIjscS",
"colab_type": "text"
},
"source": [
"## A more complete MNIST Lightning Module Example\n",
"\n",
"That wasn't so hard was it?\n",
"\n",
"Now that we've got our feet wet, let's dive in a bit deeper and write a more complete `LightningModule` for MNIST...\n",
"\n",
"This time, we'll bake in all the dataset specific pieces directly in the `LightningModule`. This way, we can avoid writing extra code at the beginning of our script every time we want to run it.\n",
"\n",
"---\n",
"\n",
"### Note what the following built-in functions are doing:\n",
"\n",
"1. [prepare_data()](https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.core.lightning.html#pytorch_lightning.core.lightning.LightningModule.prepare_data) 💾\n",
" - This is where we can download the dataset. We point to our desired dataset and ask torchvision's `MNIST` dataset class to download if the dataset isn't found there.\n",
" - **Note we do not make any state assignments in this function** (i.e. `self.something = ...`)\n",
"\n",
"2. [setup(stage)](https://pytorch-lightning.readthedocs.io/en/latest/lightning-module.html#setup) ⚙️\n",
" - Loads in data from file and prepares PyTorch tensor datasets for each split (train, val, test). \n",
" - Setup expects a 'stage' arg which is used to separate logic for 'fit' and 'test'.\n",
" - If you don't mind loading all your datasets at once, you can set up a condition to allow for both 'fit' related setup and 'test' related setup to run whenever `None` is passed to `stage` (or ignore it altogether and exclude any conditionals).\n",
" - **Note this runs across all GPUs and it *is* safe to make state assignments here**\n",
"\n",
"3. [x_dataloader()](https://pytorch-lightning.readthedocs.io/en/latest/lightning-module.html#data-hooks) ♻️\n",
" - `train_dataloader()`, `val_dataloader()`, and `test_dataloader()` all return PyTorch `DataLoader` instances that are created by wrapping their respective datasets that we prepared in `setup()`"
]
},
{
"cell_type": "code",
"metadata": {
"id": "4DNItffri95Q",
"colab_type": "code",
"colab": {}
},
"source": [
"class LitMNIST(pl.LightningModule):\n",
" \n",
" def __init__(self, data_dir='./', hidden_size=64, learning_rate=2e-4):\n",
"\n",
" super().__init__()\n",
"\n",
" # Set our init args as class attributes\n",
" self.data_dir = data_dir\n",
" self.hidden_size = hidden_size\n",
" self.learning_rate = learning_rate\n",
"\n",
" # Hardcode some dataset specific attributes\n",
" self.num_classes = 10\n",
" self.dims = (1, 28, 28)\n",
" channels, width, height = self.dims\n",
" self.transform = transforms.Compose([\n",
" transforms.ToTensor(),\n",
" transforms.Normalize((0.1307,), (0.3081,))\n",
" ])\n",
"\n",
" # Define PyTorch model\n",
" self.model = nn.Sequential(\n",
" nn.Flatten(),\n",
" nn.Linear(channels * width * height, hidden_size),\n",
" nn.ReLU(),\n",
" nn.Dropout(0.1),\n",
" nn.Linear(hidden_size, hidden_size),\n",
" nn.ReLU(),\n",
" nn.Dropout(0.1),\n",
" nn.Linear(hidden_size, self.num_classes)\n",
" )\n",
"\n",
" def forward(self, x):\n",
" x = self.model(x)\n",
" return F.log_softmax(x, dim=1)\n",
"\n",
" def training_step(self, batch, batch_idx):\n",
" x, y = batch\n",
" logits = self(x)\n",
" loss = F.nll_loss(logits, y)\n",
" return loss\n",
"\n",
" def validation_step(self, batch, batch_idx):\n",
" x, y = batch\n",
" logits = self(x)\n",
" loss = F.nll_loss(logits, y)\n",
" preds = torch.argmax(logits, dim=1)\n",
" acc = accuracy(preds, y)\n",
"\n",
" # Calling self.log will surface up scalars for you in TensorBoard\n",
" self.log('val_loss', loss, prog_bar=True)\n",
" self.log('val_acc', acc, prog_bar=True)\n",
" return loss\n",
"\n",
" def test_step(self, batch, batch_idx):\n",
" # Here we just reuse the validation_step for testing\n",
" return self.validation_step(batch, batch_idx)\n",
"\n",
" def configure_optimizers(self):\n",
" optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
" return optimizer\n",
"\n",
" ####################\n",
" # DATA RELATED HOOKS\n",
" ####################\n",
"\n",
" def prepare_data(self):\n",
" # download\n",
" MNIST(self.data_dir, train=True, download=True)\n",
" MNIST(self.data_dir, train=False, download=True)\n",
"\n",
" def setup(self, stage=None):\n",
"\n",
" # Assign train/val datasets for use in dataloaders\n",
" if stage == 'fit' or stage is None:\n",
" mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n",
" self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n",
"\n",
" # Assign test dataset for use in dataloader(s)\n",
" if stage == 'test' or stage is None:\n",
" self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n",
"\n",
" def train_dataloader(self):\n",
" return DataLoader(self.mnist_train, batch_size=32)\n",
"\n",
" def val_dataloader(self):\n",
" return DataLoader(self.mnist_val, batch_size=32)\n",
"\n",
" def test_dataloader(self):\n",
" return DataLoader(self.mnist_test, batch_size=32)"
],
"execution_count": 5,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Mb0U5Rk2kLBy",
"colab_type": "code",
"colab": {}
},
"source": [
"model = LitMNIST()\n",
"trainer = pl.Trainer(gpus=1, max_epochs=3, progress_bar_refresh_rate=20)\n",
"trainer.fit(model)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "nht8AvMptY6I",
"colab_type": "text"
},
"source": [
"### Testing\n",
"\n",
"To test a model, call `trainer.test(model)`.\n",
"\n",
"Or, if you've just trained a model, you can just call `trainer.test()` and Lightning will automatically test using the best saved checkpoint (conditioned on val_loss)."
]
},
{
"cell_type": "code",
"metadata": {
"id": "PA151FkLtprO",
"colab_type": "code",
"colab": {}
},
"source": [
"trainer.test()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "T3-3lbbNtr5T",
"colab_type": "text"
},
"source": [
"### Bonus Tip\n",
"\n",
"You can keep calling `trainer.fit(model)` as many times as you'd like to continue training"
]
},
{
"cell_type": "code",
"metadata": {
"id": "IFBwCbLet2r6",
"colab_type": "code",
"colab": {}
},
"source": [
"trainer.fit(model)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "8TRyS5CCt3n9",
"colab_type": "text"
},
"source": [
"In Colab, you can use the TensorBoard magic function to view the logs that Lightning has created for you!"
]
},
{
"cell_type": "code",
"metadata": {
"id": "wizS-QiLuAYo",
"colab_type": "code",
"colab": {}
},
"source": [
"# Start tensorboard.\n",
"%load_ext tensorboard\n",
"%tensorboard --logdir lightning_logs/"
],
"execution_count": null,
"outputs": []
}
]
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "i7XbLCXGkll9"
},
"source": [
"# Introduction to Pytorch Lightning ⚡\n",
"\n",
"In this notebook, we'll go over the basics of lightning by preparing models to train on the [MNIST Handwritten Digits dataset](https://en.wikipedia.org/wiki/MNIST_database).\n",
"\n",
"---\n",
" - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n",
" - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n",
" - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "2LODD6w9ixlT"
},
"source": [
"### Setup \n",
"Lightning is easy to install. Simply ```pip install pytorch-lightning```"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "zK7-Gg69kMnG"
},
"outputs": [],
"source": [
"! pip install pytorch-lightning --quiet"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "w4_TYnt_keJi"
},
"outputs": [],
"source": [
"import os\n",
"\n",
"import torch\n",
"from torch import nn\n",
"from torch.nn import functional as F\n",
"from torch.utils.data import DataLoader, random_split\n",
"from torchvision.datasets import MNIST\n",
"from torchvision import transforms\n",
"import pytorch_lightning as pl\n",
"from pytorch_lightning.metrics.functional import accuracy"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "EHpyMPKFkVbZ"
},
"source": [
"## Simplest example\n",
"\n",
"Here's the simplest most minimal example with just a training loop (no validation, no testing).\n",
"\n",
"**Keep in Mind** - A `LightningModule` *is* a PyTorch `nn.Module` - it just has a few more helpful features."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "V7ELesz1kVQo"
},
"outputs": [],
"source": [
"class MNISTModel(pl.LightningModule):\n",
"\n",
" def __init__(self):\n",
" super(MNISTModel, self).__init__()\n",
" self.l1 = torch.nn.Linear(28 * 28, 10)\n",
"\n",
" def forward(self, x):\n",
" return torch.relu(self.l1(x.view(x.size(0), -1)))\n",
"\n",
" def training_step(self, batch, batch_nb):\n",
" x, y = batch\n",
" loss = F.cross_entropy(self(x), y)\n",
" return loss\n",
"\n",
" def configure_optimizers(self):\n",
" return torch.optim.Adam(self.parameters(), lr=0.02)"
]
},
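{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see that claim in action, here's a small added sketch (not part of the original notebook): an untrained `MNISTModel` can be instantiated and called directly, like any `nn.Module`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch: a LightningModule is a plain nn.Module underneath,\n",
"# so we can run a forward pass on random data without any Trainer involved.\n",
"model = MNISTModel()\n",
"out = model(torch.randn(2, 1, 28, 28))\n",
"print(out.shape)  # torch.Size([2, 10])"
]
},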
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "hIrtHg-Dv8TJ"
},
"source": [
"By using the `Trainer` you automatically get:\n",
"1. Tensorboard logging\n",
"2. Model checkpointing\n",
"3. Training and validation loop\n",
"4. early-stopping"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "4Dk6Ykv8lI7X"
},
"outputs": [],
"source": [
"# Init our model\n",
"mnist_model = MNISTModel()\n",
"\n",
"# Init DataLoader from MNIST Dataset\n",
"train_ds = MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())\n",
"train_loader = DataLoader(train_ds, batch_size=32)\n",
"\n",
"# Initialize a trainer\n",
"trainer = pl.Trainer(gpus=1, max_epochs=3, progress_bar_refresh_rate=20)\n",
"\n",
"# Train the model ⚡\n",
"trainer.fit(mnist_model, train_loader)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "KNpOoBeIjscS"
},
"source": [
"## A more complete MNIST Lightning Module Example\n",
"\n",
"That wasn't so hard was it?\n",
"\n",
"Now that we've got our feet wet, let's dive in a bit deeper and write a more complete `LightningModule` for MNIST...\n",
"\n",
"This time, we'll bake in all the dataset specific pieces directly in the `LightningModule`. This way, we can avoid writing extra code at the beginning of our script every time we want to run it.\n",
"\n",
"---\n",
"\n",
"### Note what the following built-in functions are doing:\n",
"\n",
"1. [prepare_data()](https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.core.lightning.html#pytorch_lightning.core.lightning.LightningModule.prepare_data) 💾\n",
" - This is where we can download the dataset. We point to our desired dataset and ask torchvision's `MNIST` dataset class to download if the dataset isn't found there.\n",
" - **Note we do not make any state assignments in this function** (i.e. `self.something = ...`)\n",
"\n",
"2. [setup(stage)](https://pytorch-lightning.readthedocs.io/en/latest/lightning-module.html#setup) ⚙️\n",
" - Loads in data from file and prepares PyTorch tensor datasets for each split (train, val, test). \n",
" - Setup expects a 'stage' arg which is used to separate logic for 'fit' and 'test'.\n",
" - If you don't mind loading all your datasets at once, you can set up a condition to allow for both 'fit' related setup and 'test' related setup to run whenever `None` is passed to `stage` (or ignore it altogether and exclude any conditionals).\n",
" - **Note this runs across all GPUs and it *is* safe to make state assignments here**\n",
"\n",
"3. [x_dataloader()](https://pytorch-lightning.readthedocs.io/en/latest/lightning-module.html#data-hooks) ♻️\n",
" - `train_dataloader()`, `val_dataloader()`, and `test_dataloader()` all return PyTorch `DataLoader` instances that are created by wrapping their respective datasets that we prepared in `setup()`"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "4DNItffri95Q"
},
"outputs": [],
"source": [
"class LitMNIST(pl.LightningModule):\n",
" \n",
" def __init__(self, data_dir='./', hidden_size=64, learning_rate=2e-4):\n",
"\n",
" super().__init__()\n",
"\n",
" # Set our init args as class attributes\n",
" self.data_dir = data_dir\n",
" self.hidden_size = hidden_size\n",
" self.learning_rate = learning_rate\n",
"\n",
" # Hardcode some dataset specific attributes\n",
" self.num_classes = 10\n",
" self.dims = (1, 28, 28)\n",
" channels, width, height = self.dims\n",
" self.transform = transforms.Compose([\n",
" transforms.ToTensor(),\n",
" transforms.Normalize((0.1307,), (0.3081,))\n",
" ])\n",
"\n",
" # Define PyTorch model\n",
" self.model = nn.Sequential(\n",
" nn.Flatten(),\n",
" nn.Linear(channels * width * height, hidden_size),\n",
" nn.ReLU(),\n",
" nn.Dropout(0.1),\n",
" nn.Linear(hidden_size, hidden_size),\n",
" nn.ReLU(),\n",
" nn.Dropout(0.1),\n",
" nn.Linear(hidden_size, self.num_classes)\n",
" )\n",
"\n",
" def forward(self, x):\n",
" x = self.model(x)\n",
" return F.log_softmax(x, dim=1)\n",
"\n",
" def training_step(self, batch, batch_idx):\n",
" x, y = batch\n",
" logits = self(x)\n",
" loss = F.nll_loss(logits, y)\n",
" return loss\n",
"\n",
" def validation_step(self, batch, batch_idx):\n",
" x, y = batch\n",
" logits = self(x)\n",
" loss = F.nll_loss(logits, y)\n",
" preds = torch.argmax(logits, dim=1)\n",
" acc = accuracy(preds, y)\n",
"\n",
" # Calling self.log will surface up scalars for you in TensorBoard\n",
" self.log('val_loss', loss, prog_bar=True)\n",
" self.log('val_acc', acc, prog_bar=True)\n",
" return loss\n",
"\n",
" def test_step(self, batch, batch_idx):\n",
" # Here we just reuse the validation_step for testing\n",
" return self.validation_step(batch, batch_idx)\n",
"\n",
" def configure_optimizers(self):\n",
" optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
" return optimizer\n",
"\n",
" ####################\n",
" # DATA RELATED HOOKS\n",
" ####################\n",
"\n",
" def prepare_data(self):\n",
" # download\n",
" MNIST(self.data_dir, train=True, download=True)\n",
" MNIST(self.data_dir, train=False, download=True)\n",
"\n",
" def setup(self, stage=None):\n",
"\n",
" # Assign train/val datasets for use in dataloaders\n",
" if stage == 'fit' or stage is None:\n",
" mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n",
" self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n",
"\n",
" # Assign test dataset for use in dataloader(s)\n",
" if stage == 'test' or stage is None:\n",
" self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n",
"\n",
" def train_dataloader(self):\n",
" return DataLoader(self.mnist_train, batch_size=32)\n",
"\n",
" def val_dataloader(self):\n",
" return DataLoader(self.mnist_val, batch_size=32)\n",
"\n",
" def test_dataloader(self):\n",
" return DataLoader(self.mnist_test, batch_size=32)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "Mb0U5Rk2kLBy"
},
"outputs": [],
"source": [
"model = LitMNIST()\n",
"trainer = pl.Trainer(gpus=1, max_epochs=3, progress_bar_refresh_rate=20)\n",
"trainer.fit(model)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "nht8AvMptY6I"
},
"source": [
"### Testing\n",
"\n",
"To test a model, call `trainer.test(model)`.\n",
"\n",
"Or, if you've just trained a model, you can just call `trainer.test()` and Lightning will automatically test using the best saved checkpoint (conditioned on val_loss)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "PA151FkLtprO"
},
"outputs": [],
"source": [
"trainer.test()"
]
},
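{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an added sketch (an assumption, not shown in the original notebook), you can also restore the best checkpoint explicitly before testing. This relies on the default `ModelCheckpoint` callback having recorded `best_model_path` during `trainer.fit`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hedged sketch: load the best checkpoint tracked by the default\n",
"# ModelCheckpoint callback, then test that specific model.\n",
"best_model = LitMNIST.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)\n",
"trainer.test(best_model)"
]
},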
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "T3-3lbbNtr5T"
},
"source": [
"### Bonus Tip\n",
"\n",
"You can keep calling `trainer.fit(model)` as many times as you'd like to continue training"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "IFBwCbLet2r6"
},
"outputs": [],
"source": [
"trainer.fit(model)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "8TRyS5CCt3n9"
},
"source": [
"In Colab, you can use the TensorBoard magic function to view the logs that Lightning has created for you!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "wizS-QiLuAYo"
},
"outputs": [],
"source": [
"# Start tensorboard.\n",
"%load_ext tensorboard\n",
"%tensorboard --logdir lightning_logs/"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<code style=\"color:#792ee5;\">\n",
" <h1> <strong> Congratulations - Time to Join the Community! </strong> </h1>\n",
"</code>\n",
"\n",
"Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!\n",
"\n",
"### Star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) on GitHub\n",
"The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.\n",
"\n",
"* Please, star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning)\n",
"\n",
"### Join our [Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)!\n",
"The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in `#general` channel\n",
"\n",
"### Interested by SOTA AI models ! Check out [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n",
"Bolts has a collection of state-of-the-art models, all implemented in [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) and can be easily integrated within your own projects.\n",
"\n",
"* Please, star [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n",
"\n",
"### Contributions !\n",
"The best way to contribute to our community is to become a code contributor! At any time you can go to [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) or [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts) GitHub Issues page and filter for \"good first issue\". \n",
"\n",
"* [Lightning good first issue](https://github.com/PyTorchLightning/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
"* [Bolt good first issue](https://github.com/PyTorchLightning/pytorch-lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
"* You can also contribute your own notebooks with useful examples !\n",
"\n",
"### Great thanks from the entire Pytorch Lightning Team for your interest !\n",
"\n",
"<img src=\"https://github.com/PyTorchLightning/pytorch-lightning/blob/master/docs/source/_images/logos/lightning_logo-name.png?raw=true\" width=\"800\" height=\"200\" />"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"authorship_tag": "ABX9TyOtAKVa5POQ6Xg3UcTQqXDJ",
"collapsed_sections": [],
"include_colab_link": true,
"name": "01-mnist-hello-world.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

File diff suppressed because it is too large.

notebooks/03-basic-gan.ipynb

@ -1,424 +1,472 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "view-in-github"
},
"source": [
"<a href=\"https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/03-basic-gan.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/03-basic-gan.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "J37PBnE_x7IW",
"colab_type": "text"
},
"source": [
"# PyTorch Lightning Basic GAN Tutorial ⚡\n",
"\n",
"How to train a GAN!\n",
"\n",
"Main takeaways:\n",
"1. Generator and discriminator are arbitrary PyTorch modules.\n",
"2. training_step does both the generator and discriminator training.\n",
"\n",
"---\n",
"\n",
" - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n",
" - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n",
" - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kg2MKpRmybht",
"colab_type": "text"
},
"source": [
"### Setup\n",
"Lightning is easy to install. Simply `pip install pytorch-lightning`"
]
},
{
"cell_type": "code",
"metadata": {
"id": "LfrJLKPFyhsK",
"colab_type": "code",
"colab": {}
},
"source": [
"! pip install pytorch-lightning --quiet"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "BjEPuiVLyanw",
"colab_type": "code",
"colab": {}
},
"source": [
"import os\n",
"from argparse import ArgumentParser\n",
"from collections import OrderedDict\n",
"\n",
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torchvision\n",
"import torchvision.transforms as transforms\n",
"from torch.utils.data import DataLoader, random_split\n",
"from torchvision.datasets import MNIST\n",
"\n",
"import pytorch_lightning as pl"
],
"execution_count": 2,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "OuXJzr4G2uHV",
"colab_type": "text"
},
"source": [
"### MNIST DataModule\n",
"\n",
"Below, we define a DataModule for the MNIST Dataset. To learn more about DataModules, check out our tutorial on them or see the [latest docs](https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html)."
]
},
{
"cell_type": "code",
"metadata": {
"id": "DOY_nHu328g7",
"colab_type": "code",
"colab": {}
},
"source": [
"class MNISTDataModule(pl.LightningDataModule):\n",
"\n",
" def __init__(self, data_dir: str = './', batch_size: int = 64, num_workers: int = 8):\n",
" super().__init__()\n",
" self.data_dir = data_dir\n",
" self.batch_size = batch_size\n",
" self.num_workers = num_workers\n",
"\n",
" self.transform = transforms.Compose([\n",
" transforms.ToTensor(),\n",
" transforms.Normalize((0.1307,), (0.3081,))\n",
" ])\n",
"\n",
" # self.dims is returned when you call dm.size()\n",
" # Setting default dims here because we know them.\n",
" # Could optionally be assigned dynamically in dm.setup()\n",
" self.dims = (1, 28, 28)\n",
" self.num_classes = 10\n",
"\n",
" def prepare_data(self):\n",
" # download\n",
" MNIST(self.data_dir, train=True, download=True)\n",
" MNIST(self.data_dir, train=False, download=True)\n",
"\n",
" def setup(self, stage=None):\n",
"\n",
" # Assign train/val datasets for use in dataloaders\n",
" if stage == 'fit' or stage is None:\n",
" mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n",
" self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n",
"\n",
" # Assign test dataset for use in dataloader(s)\n",
" if stage == 'test' or stage is None:\n",
" self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n",
"\n",
" def train_dataloader(self):\n",
" return DataLoader(self.mnist_train, batch_size=self.batch_size, num_workers=self.num_workers)\n",
"\n",
" def val_dataloader(self):\n",
" return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=self.num_workers)\n",
"\n",
" def test_dataloader(self):\n",
" return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=self.num_workers)"
],
"execution_count": 3,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "tW3c0QrQyF9P",
"colab_type": "text"
},
"source": [
"### A. Generator"
]
},
{
"cell_type": "code",
"metadata": {
"id": "0E2QDjl5yWtz",
"colab_type": "code",
"colab": {}
},
"source": [
"class Generator(nn.Module):\n",
" def __init__(self, latent_dim, img_shape):\n",
" super().__init__()\n",
" self.img_shape = img_shape\n",
"\n",
" def block(in_feat, out_feat, normalize=True):\n",
" layers = [nn.Linear(in_feat, out_feat)]\n",
" if normalize:\n",
" layers.append(nn.BatchNorm1d(out_feat, 0.8))\n",
" layers.append(nn.LeakyReLU(0.2, inplace=True))\n",
" return layers\n",
"\n",
" self.model = nn.Sequential(\n",
" *block(latent_dim, 128, normalize=False),\n",
" *block(128, 256),\n",
" *block(256, 512),\n",
" *block(512, 1024),\n",
" nn.Linear(1024, int(np.prod(img_shape))),\n",
" nn.Tanh()\n",
" )\n",
"\n",
" def forward(self, z):\n",
" img = self.model(z)\n",
" img = img.view(img.size(0), *self.img_shape)\n",
" return img"
],
"execution_count": 4,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "uyrltsGvyaI3",
"colab_type": "text"
},
"source": [
"### B. Discriminator"
]
},
{
"cell_type": "code",
"metadata": {
"id": "Ed3MR3vnyxyW",
"colab_type": "code",
"colab": {}
},
"source": [
"class Discriminator(nn.Module):\n",
" def __init__(self, img_shape):\n",
" super().__init__()\n",
"\n",
" self.model = nn.Sequential(\n",
" nn.Linear(int(np.prod(img_shape)), 512),\n",
" nn.LeakyReLU(0.2, inplace=True),\n",
" nn.Linear(512, 256),\n",
" nn.LeakyReLU(0.2, inplace=True),\n",
" nn.Linear(256, 1),\n",
" nn.Sigmoid(),\n",
" )\n",
"\n",
" def forward(self, img):\n",
" img_flat = img.view(img.size(0), -1)\n",
" validity = self.model(img_flat)\n",
"\n",
" return validity"
],
"execution_count": 5,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "BwUMom3ryySK",
"colab_type": "text"
},
"source": [
"### C. GAN\n",
"\n",
"#### A couple of cool features to check out in this example...\n",
"\n",
" - We use `some_tensor.type_as(another_tensor)` to make sure we initialize new tensors on the right device (i.e. GPU, CPU).\n",
" - Lightning will put your dataloader data on the right device automatically\n",
" - In this example, we pull from latent dim on the fly, so we need to dynamically add tensors to the right device.\n",
" - `type_as` is the way we recommend to do this.\n",
" - This example shows how to use multiple dataloaders in your `LightningModule`."
]
},
{
"cell_type": "code",
"metadata": {
"id": "3vKszYf6y1Vv",
"colab_type": "code",
"colab": {}
},
"source": [
" class GAN(pl.LightningModule):\n",
"\n",
" def __init__(\n",
" self,\n",
" channels,\n",
" width,\n",
" height,\n",
" latent_dim: int = 100,\n",
" lr: float = 0.0002,\n",
" b1: float = 0.5,\n",
" b2: float = 0.999,\n",
" batch_size: int = 64,\n",
" **kwargs\n",
" ):\n",
" super().__init__()\n",
" self.save_hyperparameters()\n",
"\n",
" # networks\n",
" data_shape = (channels, width, height)\n",
" self.generator = Generator(latent_dim=self.hparams.latent_dim, img_shape=data_shape)\n",
" self.discriminator = Discriminator(img_shape=data_shape)\n",
"\n",
" self.validation_z = torch.randn(8, self.hparams.latent_dim)\n",
"\n",
" self.example_input_array = torch.zeros(2, self.hparams.latent_dim)\n",
"\n",
" def forward(self, z):\n",
" return self.generator(z)\n",
"\n",
" def adversarial_loss(self, y_hat, y):\n",
" return F.binary_cross_entropy(y_hat, y)\n",
"\n",
" def training_step(self, batch, batch_idx, optimizer_idx):\n",
" imgs, _ = batch\n",
"\n",
" # sample noise\n",
" z = torch.randn(imgs.shape[0], self.hparams.latent_dim)\n",
" z = z.type_as(imgs)\n",
"\n",
" # train generator\n",
" if optimizer_idx == 0:\n",
"\n",
" # generate images\n",
" self.generated_imgs = self(z)\n",
"\n",
" # log sampled images\n",
" sample_imgs = self.generated_imgs[:6]\n",
" grid = torchvision.utils.make_grid(sample_imgs)\n",
" self.logger.experiment.add_image('generated_images', grid, 0)\n",
"\n",
" # ground truth result (ie: all fake)\n",
" # put on GPU because we created this tensor inside training_loop\n",
" valid = torch.ones(imgs.size(0), 1)\n",
" valid = valid.type_as(imgs)\n",
"\n",
" # adversarial loss is binary cross-entropy\n",
" g_loss = self.adversarial_loss(self.discriminator(self(z)), valid)\n",
" tqdm_dict = {'g_loss': g_loss}\n",
" output = OrderedDict({\n",
" 'loss': g_loss,\n",
" 'progress_bar': tqdm_dict,\n",
" 'log': tqdm_dict\n",
" })\n",
" return output\n",
"\n",
" # train discriminator\n",
" if optimizer_idx == 1:\n",
" # Measure discriminator's ability to classify real from generated samples\n",
"\n",
" # how well can it label as real?\n",
" valid = torch.ones(imgs.size(0), 1)\n",
" valid = valid.type_as(imgs)\n",
"\n",
" real_loss = self.adversarial_loss(self.discriminator(imgs), valid)\n",
"\n",
" # how well can it label as fake?\n",
" fake = torch.zeros(imgs.size(0), 1)\n",
" fake = fake.type_as(imgs)\n",
"\n",
" fake_loss = self.adversarial_loss(\n",
" self.discriminator(self(z).detach()), fake)\n",
"\n",
" # discriminator loss is the average of these\n",
" d_loss = (real_loss + fake_loss) / 2\n",
" tqdm_dict = {'d_loss': d_loss}\n",
" output = OrderedDict({\n",
" 'loss': d_loss,\n",
" 'progress_bar': tqdm_dict,\n",
" 'log': tqdm_dict\n",
" })\n",
" return output\n",
"\n",
" def configure_optimizers(self):\n",
" lr = self.hparams.lr\n",
" b1 = self.hparams.b1\n",
" b2 = self.hparams.b2\n",
"\n",
" opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))\n",
" opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))\n",
" return [opt_g, opt_d], []\n",
"\n",
" def on_epoch_end(self):\n",
" z = self.validation_z.type_as(self.generator.model[0].weight)\n",
"\n",
" # log sampled images\n",
" sample_imgs = self(z)\n",
" grid = torchvision.utils.make_grid(sample_imgs)\n",
" self.logger.experiment.add_image('generated_images', grid, self.current_epoch)"
],
"execution_count": 6,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Ey5FmJPnzm_E",
"colab_type": "code",
"colab": {}
},
"source": [
"dm = MNISTDataModule()\n",
"model = GAN(*dm.size())\n",
"trainer = pl.Trainer(gpus=1, max_epochs=5, progress_bar_refresh_rate=20)\n",
"trainer.fit(model, dm)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "MlECc7cHzolp",
"colab_type": "code",
"colab": {}
},
"source": [
"# Start tensorboard.\n",
"%load_ext tensorboard\n",
"%tensorboard --logdir lightning_logs/"
],
"execution_count": null,
"outputs": []
}
]
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "J37PBnE_x7IW"
},
"source": [
"# PyTorch Lightning Basic GAN Tutorial ⚡\n",
"\n",
"How to train a GAN!\n",
"\n",
"Main takeaways:\n",
"1. Generator and discriminator are arbitrary PyTorch modules.\n",
"2. training_step does both the generator and discriminator training.\n",
"\n",
"---\n",
"\n",
" - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n",
" - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n",
" - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "kg2MKpRmybht"
},
"source": [
"### Setup\n",
"Lightning is easy to install. Simply `pip install pytorch-lightning`"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "LfrJLKPFyhsK"
},
"outputs": [],
"source": [
"! pip install pytorch-lightning --quiet"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "BjEPuiVLyanw"
},
"outputs": [],
"source": [
"import os\n",
"from argparse import ArgumentParser\n",
"from collections import OrderedDict\n",
"\n",
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torchvision\n",
"import torchvision.transforms as transforms\n",
"from torch.utils.data import DataLoader, random_split\n",
"from torchvision.datasets import MNIST\n",
"\n",
"import pytorch_lightning as pl"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "OuXJzr4G2uHV"
},
"source": [
"### MNIST DataModule\n",
"\n",
"Below, we define a DataModule for the MNIST Dataset. To learn more about DataModules, check out our tutorial on them or see the [latest docs](https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html)."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "DOY_nHu328g7"
},
"outputs": [],
"source": [
"class MNISTDataModule(pl.LightningDataModule):\n",
"\n",
" def __init__(self, data_dir: str = './', batch_size: int = 64, num_workers: int = 8):\n",
" super().__init__()\n",
" self.data_dir = data_dir\n",
" self.batch_size = batch_size\n",
" self.num_workers = num_workers\n",
"\n",
" self.transform = transforms.Compose([\n",
" transforms.ToTensor(),\n",
" transforms.Normalize((0.1307,), (0.3081,))\n",
" ])\n",
"\n",
" # self.dims is returned when you call dm.size()\n",
" # Setting default dims here because we know them.\n",
" # Could optionally be assigned dynamically in dm.setup()\n",
" self.dims = (1, 28, 28)\n",
" self.num_classes = 10\n",
"\n",
" def prepare_data(self):\n",
" # download\n",
" MNIST(self.data_dir, train=True, download=True)\n",
" MNIST(self.data_dir, train=False, download=True)\n",
"\n",
" def setup(self, stage=None):\n",
"\n",
" # Assign train/val datasets for use in dataloaders\n",
" if stage == 'fit' or stage is None:\n",
" mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n",
" self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n",
"\n",
" # Assign test dataset for use in dataloader(s)\n",
" if stage == 'test' or stage is None:\n",
" self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n",
"\n",
" def train_dataloader(self):\n",
" return DataLoader(self.mnist_train, batch_size=self.batch_size, num_workers=self.num_workers)\n",
"\n",
" def val_dataloader(self):\n",
" return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=self.num_workers)\n",
"\n",
" def test_dataloader(self):\n",
" return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=self.num_workers)"
]
},
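{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick added sanity-check sketch (not part of the original tutorial; the `Trainer` normally calls these hooks for you), the DataModule hooks can be invoked manually to inspect a batch:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Manual DataModule lifecycle, for illustration only.\n",
"dm = MNISTDataModule()\n",
"dm.prepare_data()  # downloads MNIST if needed\n",
"dm.setup('fit')    # builds the train/val splits\n",
"x, y = next(iter(dm.train_dataloader()))\n",
"print(x.shape, y.shape)  # torch.Size([64, 1, 28, 28]) torch.Size([64])"
]
},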
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "tW3c0QrQyF9P"
},
"source": [
"### A. Generator"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "0E2QDjl5yWtz"
},
"outputs": [],
"source": [
"class Generator(nn.Module):\n",
" def __init__(self, latent_dim, img_shape):\n",
" super().__init__()\n",
" self.img_shape = img_shape\n",
"\n",
" def block(in_feat, out_feat, normalize=True):\n",
" layers = [nn.Linear(in_feat, out_feat)]\n",
" if normalize:\n",
" layers.append(nn.BatchNorm1d(out_feat, 0.8))\n",
" layers.append(nn.LeakyReLU(0.2, inplace=True))\n",
" return layers\n",
"\n",
" self.model = nn.Sequential(\n",
" *block(latent_dim, 128, normalize=False),\n",
" *block(128, 256),\n",
" *block(256, 512),\n",
" *block(512, 1024),\n",
" nn.Linear(1024, int(np.prod(img_shape))),\n",
" nn.Tanh()\n",
" )\n",
"\n",
" def forward(self, z):\n",
" img = self.model(z)\n",
" img = img.view(img.size(0), *self.img_shape)\n",
" return img"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "uyrltsGvyaI3"
},
"source": [
"### B. Discriminator"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "Ed3MR3vnyxyW"
},
"outputs": [],
"source": [
"class Discriminator(nn.Module):\n",
" def __init__(self, img_shape):\n",
" super().__init__()\n",
"\n",
" self.model = nn.Sequential(\n",
" nn.Linear(int(np.prod(img_shape)), 512),\n",
" nn.LeakyReLU(0.2, inplace=True),\n",
" nn.Linear(512, 256),\n",
" nn.LeakyReLU(0.2, inplace=True),\n",
" nn.Linear(256, 1),\n",
" nn.Sigmoid(),\n",
" )\n",
"\n",
" def forward(self, img):\n",
" img_flat = img.view(img.size(0), -1)\n",
" validity = self.model(img_flat)\n",
"\n",
" return validity"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "BwUMom3ryySK"
},
"source": [
"### C. GAN\n",
"\n",
"#### A couple of cool features to check out in this example...\n",
"\n",
" - We use `some_tensor.type_as(another_tensor)` to make sure we initialize new tensors on the right device (i.e. GPU, CPU).\n",
" - Lightning will put your dataloader data on the right device automatically\n",
" - In this example, we pull from latent dim on the fly, so we need to dynamically add tensors to the right device.\n",
" - `type_as` is the way we recommend to do this.\n",
" - This example shows how to use multiple dataloaders in your `LightningModule`."
]
},
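{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here's a tiny added sketch of `type_as` (an illustration, not from the original tutorial): the new tensor adopts the reference tensor's dtype and device."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# type_as demo: `z` takes on imgs' dtype and device (CPU here, GPU during training).\n",
"imgs = torch.zeros(4, 1, 28, 28)  # stand-in for a dataloader batch\n",
"z = torch.randn(4, 100)\n",
"z = z.type_as(imgs)\n",
"print(z.dtype, z.device)"
]
},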
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "3vKszYf6y1Vv"
},
"outputs": [],
"source": [
" class GAN(pl.LightningModule):\n",
"\n",
" def __init__(\n",
" self,\n",
" channels,\n",
" width,\n",
" height,\n",
" latent_dim: int = 100,\n",
" lr: float = 0.0002,\n",
" b1: float = 0.5,\n",
" b2: float = 0.999,\n",
" batch_size: int = 64,\n",
" **kwargs\n",
" ):\n",
" super().__init__()\n",
" self.save_hyperparameters()\n",
"\n",
" # networks\n",
" data_shape = (channels, width, height)\n",
" self.generator = Generator(latent_dim=self.hparams.latent_dim, img_shape=data_shape)\n",
" self.discriminator = Discriminator(img_shape=data_shape)\n",
"\n",
" self.validation_z = torch.randn(8, self.hparams.latent_dim)\n",
"\n",
" self.example_input_array = torch.zeros(2, self.hparams.latent_dim)\n",
"\n",
" def forward(self, z):\n",
" return self.generator(z)\n",
"\n",
" def adversarial_loss(self, y_hat, y):\n",
" return F.binary_cross_entropy(y_hat, y)\n",
"\n",
" def training_step(self, batch, batch_idx, optimizer_idx):\n",
" imgs, _ = batch\n",
"\n",
" # sample noise\n",
" z = torch.randn(imgs.shape[0], self.hparams.latent_dim)\n",
" z = z.type_as(imgs)\n",
"\n",
" # train generator\n",
" if optimizer_idx == 0:\n",
"\n",
" # generate images\n",
" self.generated_imgs = self(z)\n",
"\n",
" # log sampled images\n",
" sample_imgs = self.generated_imgs[:6]\n",
" grid = torchvision.utils.make_grid(sample_imgs)\n",
" self.logger.experiment.add_image('generated_images', grid, 0)\n",
"\n",
" # ground truth result (ie: all fake)\n",
" # put on GPU because we created this tensor inside training_loop\n",
" valid = torch.ones(imgs.size(0), 1)\n",
" valid = valid.type_as(imgs)\n",
"\n",
" # adversarial loss is binary cross-entropy\n",
" g_loss = self.adversarial_loss(self.discriminator(self(z)), valid)\n",
" tqdm_dict = {'g_loss': g_loss}\n",
" output = OrderedDict({\n",
" 'loss': g_loss,\n",
" 'progress_bar': tqdm_dict,\n",
" 'log': tqdm_dict\n",
" })\n",
" return output\n",
"\n",
" # train discriminator\n",
" if optimizer_idx == 1:\n",
" # Measure discriminator's ability to classify real from generated samples\n",
"\n",
" # how well can it label as real?\n",
" valid = torch.ones(imgs.size(0), 1)\n",
" valid = valid.type_as(imgs)\n",
"\n",
" real_loss = self.adversarial_loss(self.discriminator(imgs), valid)\n",
"\n",
" # how well can it label as fake?\n",
" fake = torch.zeros(imgs.size(0), 1)\n",
" fake = fake.type_as(imgs)\n",
"\n",
" fake_loss = self.adversarial_loss(\n",
" self.discriminator(self(z).detach()), fake)\n",
"\n",
" # discriminator loss is the average of these\n",
" d_loss = (real_loss + fake_loss) / 2\n",
" tqdm_dict = {'d_loss': d_loss}\n",
" output = OrderedDict({\n",
" 'loss': d_loss,\n",
" 'progress_bar': tqdm_dict,\n",
" 'log': tqdm_dict\n",
" })\n",
" return output\n",
"\n",
" def configure_optimizers(self):\n",
" lr = self.hparams.lr\n",
" b1 = self.hparams.b1\n",
" b2 = self.hparams.b2\n",
"\n",
" opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))\n",
" opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))\n",
" return [opt_g, opt_d], []\n",
"\n",
" def on_epoch_end(self):\n",
" z = self.validation_z.type_as(self.generator.model[0].weight)\n",
"\n",
" # log sampled images\n",
" sample_imgs = self(z)\n",
" grid = torchvision.utils.make_grid(sample_imgs)\n",
" self.logger.experiment.add_image('generated_images', grid, self.current_epoch)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "Ey5FmJPnzm_E"
},
"outputs": [],
"source": [
"dm = MNISTDataModule()\n",
"model = GAN(*dm.size())\n",
"trainer = pl.Trainer(gpus=1, max_epochs=5, progress_bar_refresh_rate=20)\n",
"trainer.fit(model, dm)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "MlECc7cHzolp"
},
"outputs": [],
"source": [
"# Start tensorboard.\n",
"%load_ext tensorboard\n",
"%tensorboard --logdir lightning_logs/"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<code style=\"color:#792ee5;\">\n",
" <h1> <strong> Congratulations - Time to Join the Community! </strong> </h1>\n",
"</code>\n",
"\n",
"Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!\n",
"\n",
"### Star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) on GitHub\n",
"The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.\n",
"\n",
"* Please, star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning)\n",
"\n",
"### Join our [Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)!\n",
"The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in `#general` channel\n",
"\n",
"### Interested by SOTA AI models ! Check out [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n",
"Bolts has a collection of state-of-the-art models, all implemented in [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) and can be easily integrated within your own projects.\n",
"\n",
"* Please, star [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n",
"\n",
"### Contributions !\n",
"The best way to contribute to our community is to become a code contributor! At any time you can go to [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) or [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts) GitHub Issues page and filter for \"good first issue\". \n",
"\n",
"* [Lightning good first issue](https://github.com/PyTorchLightning/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
"* [Bolt good first issue](https://github.com/PyTorchLightning/pytorch-lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
"* You can also contribute your own notebooks with useful examples !\n",
"\n",
"### Great thanks from the entire Pytorch Lightning Team for your interest !\n",
"\n",
"<img src=\"https://github.com/PyTorchLightning/pytorch-lightning/blob/master/docs/source/_images/logos/lightning_logo-name.png?raw=true\" width=\"800\" height=\"200\" />"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"include_colab_link": true,
"name": "03-basic-gan.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

File diff suppressed because it is too large.

File diff suppressed because one or more lines are too long.