add congratulations at the end of our notebooks (#4555)
* add congratulations at the end of our notebooks * udpate image
This commit is contained in:
parent
6e5f232f5c
commit
854c13673b
|
@ -1,400 +1,448 @@
|
|||
{
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"name": "01-mnist-hello-world.ipynb",
|
||||
"provenance": [],
|
||||
"collapsed_sections": [],
|
||||
"authorship_tag": "ABX9TyOtAKVa5POQ6Xg3UcTQqXDJ",
|
||||
"include_colab_link": true
|
||||
},
|
||||
"kernelspec": {
|
||||
"name": "python3",
|
||||
"display_name": "Python 3"
|
||||
},
|
||||
"accelerator": "GPU"
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "view-in-github"
|
||||
},
|
||||
"source": [
|
||||
"<a href=\"https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/01-mnist-hello-world.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
||||
]
|
||||
},
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "view-in-github",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"<a href=\"https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/01-mnist-hello-world.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "i7XbLCXGkll9",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"# Introduction to Pytorch Lightning ⚡\n",
|
||||
"\n",
|
||||
"In this notebook, we'll go over the basics of lightning by preparing models to train on the [MNIST Handwritten Digits dataset](https://en.wikipedia.org/wiki/MNIST_database).\n",
|
||||
"\n",
|
||||
"---\n",
|
||||
" - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n",
|
||||
" - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n",
|
||||
" - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "2LODD6w9ixlT",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"### Setup \n",
|
||||
"Lightning is easy to install. Simply ```pip install pytorch-lightning```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "zK7-Gg69kMnG",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"! pip install pytorch-lightning --quiet"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "w4_TYnt_keJi",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"import torch\n",
|
||||
"from torch import nn\n",
|
||||
"from torch.nn import functional as F\n",
|
||||
"from torch.utils.data import DataLoader, random_split\n",
|
||||
"from torchvision.datasets import MNIST\n",
|
||||
"from torchvision import transforms\n",
|
||||
"import pytorch_lightning as pl\n",
|
||||
"from pytorch_lightning.metrics.functional import accuracy"
|
||||
],
|
||||
"execution_count": 2,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "EHpyMPKFkVbZ",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"## Simplest example\n",
|
||||
"\n",
|
||||
"Here's the simplest most minimal example with just a training loop (no validation, no testing).\n",
|
||||
"\n",
|
||||
"**Keep in Mind** - A `LightningModule` *is* a PyTorch `nn.Module` - it just has a few more helpful features."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "V7ELesz1kVQo",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"class MNISTModel(pl.LightningModule):\n",
|
||||
"\n",
|
||||
" def __init__(self):\n",
|
||||
" super(MNISTModel, self).__init__()\n",
|
||||
" self.l1 = torch.nn.Linear(28 * 28, 10)\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" return torch.relu(self.l1(x.view(x.size(0), -1)))\n",
|
||||
"\n",
|
||||
" def training_step(self, batch, batch_nb):\n",
|
||||
" x, y = batch\n",
|
||||
" loss = F.cross_entropy(self(x), y)\n",
|
||||
" return loss\n",
|
||||
"\n",
|
||||
" def configure_optimizers(self):\n",
|
||||
" return torch.optim.Adam(self.parameters(), lr=0.02)"
|
||||
],
|
||||
"execution_count": 3,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "hIrtHg-Dv8TJ",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"By using the `Trainer` you automatically get:\n",
|
||||
"1. Tensorboard logging\n",
|
||||
"2. Model checkpointing\n",
|
||||
"3. Training and validation loop\n",
|
||||
"4. early-stopping"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "4Dk6Ykv8lI7X",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"# Init our model\n",
|
||||
"mnist_model = MNISTModel()\n",
|
||||
"\n",
|
||||
"# Init DataLoader from MNIST Dataset\n",
|
||||
"train_ds = MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())\n",
|
||||
"train_loader = DataLoader(train_ds, batch_size=32)\n",
|
||||
"\n",
|
||||
"# Initialize a trainer\n",
|
||||
"trainer = pl.Trainer(gpus=1, max_epochs=3, progress_bar_refresh_rate=20)\n",
|
||||
"\n",
|
||||
"# Train the model ⚡\n",
|
||||
"trainer.fit(mnist_model, train_loader)"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "KNpOoBeIjscS",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"## A more complete MNIST Lightning Module Example\n",
|
||||
"\n",
|
||||
"That wasn't so hard was it?\n",
|
||||
"\n",
|
||||
"Now that we've got our feet wet, let's dive in a bit deeper and write a more complete `LightningModule` for MNIST...\n",
|
||||
"\n",
|
||||
"This time, we'll bake in all the dataset specific pieces directly in the `LightningModule`. This way, we can avoid writing extra code at the beginning of our script every time we want to run it.\n",
|
||||
"\n",
|
||||
"---\n",
|
||||
"\n",
|
||||
"### Note what the following built-in functions are doing:\n",
|
||||
"\n",
|
||||
"1. [prepare_data()](https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.core.lightning.html#pytorch_lightning.core.lightning.LightningModule.prepare_data) 💾\n",
|
||||
" - This is where we can download the dataset. We point to our desired dataset and ask torchvision's `MNIST` dataset class to download if the dataset isn't found there.\n",
|
||||
" - **Note we do not make any state assignments in this function** (i.e. `self.something = ...`)\n",
|
||||
"\n",
|
||||
"2. [setup(stage)](https://pytorch-lightning.readthedocs.io/en/latest/lightning-module.html#setup) ⚙️\n",
|
||||
" - Loads in data from file and prepares PyTorch tensor datasets for each split (train, val, test). \n",
|
||||
" - Setup expects a 'stage' arg which is used to separate logic for 'fit' and 'test'.\n",
|
||||
" - If you don't mind loading all your datasets at once, you can set up a condition to allow for both 'fit' related setup and 'test' related setup to run whenever `None` is passed to `stage` (or ignore it altogether and exclude any conditionals).\n",
|
||||
" - **Note this runs across all GPUs and it *is* safe to make state assignments here**\n",
|
||||
"\n",
|
||||
"3. [x_dataloader()](https://pytorch-lightning.readthedocs.io/en/latest/lightning-module.html#data-hooks) ♻️\n",
|
||||
" - `train_dataloader()`, `val_dataloader()`, and `test_dataloader()` all return PyTorch `DataLoader` instances that are created by wrapping their respective datasets that we prepared in `setup()`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "4DNItffri95Q",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"class LitMNIST(pl.LightningModule):\n",
|
||||
" \n",
|
||||
" def __init__(self, data_dir='./', hidden_size=64, learning_rate=2e-4):\n",
|
||||
"\n",
|
||||
" super().__init__()\n",
|
||||
"\n",
|
||||
" # Set our init args as class attributes\n",
|
||||
" self.data_dir = data_dir\n",
|
||||
" self.hidden_size = hidden_size\n",
|
||||
" self.learning_rate = learning_rate\n",
|
||||
"\n",
|
||||
" # Hardcode some dataset specific attributes\n",
|
||||
" self.num_classes = 10\n",
|
||||
" self.dims = (1, 28, 28)\n",
|
||||
" channels, width, height = self.dims\n",
|
||||
" self.transform = transforms.Compose([\n",
|
||||
" transforms.ToTensor(),\n",
|
||||
" transforms.Normalize((0.1307,), (0.3081,))\n",
|
||||
" ])\n",
|
||||
"\n",
|
||||
" # Define PyTorch model\n",
|
||||
" self.model = nn.Sequential(\n",
|
||||
" nn.Flatten(),\n",
|
||||
" nn.Linear(channels * width * height, hidden_size),\n",
|
||||
" nn.ReLU(),\n",
|
||||
" nn.Dropout(0.1),\n",
|
||||
" nn.Linear(hidden_size, hidden_size),\n",
|
||||
" nn.ReLU(),\n",
|
||||
" nn.Dropout(0.1),\n",
|
||||
" nn.Linear(hidden_size, self.num_classes)\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" x = self.model(x)\n",
|
||||
" return F.log_softmax(x, dim=1)\n",
|
||||
"\n",
|
||||
" def training_step(self, batch, batch_idx):\n",
|
||||
" x, y = batch\n",
|
||||
" logits = self(x)\n",
|
||||
" loss = F.nll_loss(logits, y)\n",
|
||||
" return loss\n",
|
||||
"\n",
|
||||
" def validation_step(self, batch, batch_idx):\n",
|
||||
" x, y = batch\n",
|
||||
" logits = self(x)\n",
|
||||
" loss = F.nll_loss(logits, y)\n",
|
||||
" preds = torch.argmax(logits, dim=1)\n",
|
||||
" acc = accuracy(preds, y)\n",
|
||||
"\n",
|
||||
" # Calling self.log will surface up scalars for you in TensorBoard\n",
|
||||
" self.log('val_loss', loss, prog_bar=True)\n",
|
||||
" self.log('val_acc', acc, prog_bar=True)\n",
|
||||
" return loss\n",
|
||||
"\n",
|
||||
" def test_step(self, batch, batch_idx):\n",
|
||||
" # Here we just reuse the validation_step for testing\n",
|
||||
" return self.validation_step(batch, batch_idx)\n",
|
||||
"\n",
|
||||
" def configure_optimizers(self):\n",
|
||||
" optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
|
||||
" return optimizer\n",
|
||||
"\n",
|
||||
" ####################\n",
|
||||
" # DATA RELATED HOOKS\n",
|
||||
" ####################\n",
|
||||
"\n",
|
||||
" def prepare_data(self):\n",
|
||||
" # download\n",
|
||||
" MNIST(self.data_dir, train=True, download=True)\n",
|
||||
" MNIST(self.data_dir, train=False, download=True)\n",
|
||||
"\n",
|
||||
" def setup(self, stage=None):\n",
|
||||
"\n",
|
||||
" # Assign train/val datasets for use in dataloaders\n",
|
||||
" if stage == 'fit' or stage is None:\n",
|
||||
" mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n",
|
||||
" self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n",
|
||||
"\n",
|
||||
" # Assign test dataset for use in dataloader(s)\n",
|
||||
" if stage == 'test' or stage is None:\n",
|
||||
" self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n",
|
||||
"\n",
|
||||
" def train_dataloader(self):\n",
|
||||
" return DataLoader(self.mnist_train, batch_size=32)\n",
|
||||
"\n",
|
||||
" def val_dataloader(self):\n",
|
||||
" return DataLoader(self.mnist_val, batch_size=32)\n",
|
||||
"\n",
|
||||
" def test_dataloader(self):\n",
|
||||
" return DataLoader(self.mnist_test, batch_size=32)"
|
||||
],
|
||||
"execution_count": 5,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "Mb0U5Rk2kLBy",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"model = LitMNIST()\n",
|
||||
"trainer = pl.Trainer(gpus=1, max_epochs=3, progress_bar_refresh_rate=20)\n",
|
||||
"trainer.fit(model)"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "nht8AvMptY6I",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"### Testing\n",
|
||||
"\n",
|
||||
"To test a model, call `trainer.test(model)`.\n",
|
||||
"\n",
|
||||
"Or, if you've just trained a model, you can just call `trainer.test()` and Lightning will automatically test using the best saved checkpoint (conditioned on val_loss)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "PA151FkLtprO",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"trainer.test()"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "T3-3lbbNtr5T",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"### Bonus Tip\n",
|
||||
"\n",
|
||||
"You can keep calling `trainer.fit(model)` as many times as you'd like to continue training"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "IFBwCbLet2r6",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"trainer.fit(model)"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "8TRyS5CCt3n9",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"In Colab, you can use the TensorBoard magic function to view the logs that Lightning has created for you!"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "wizS-QiLuAYo",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"# Start tensorboard.\n",
|
||||
"%load_ext tensorboard\n",
|
||||
"%tensorboard --logdir lightning_logs/"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
}
|
||||
]
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "i7XbLCXGkll9"
|
||||
},
|
||||
"source": [
|
||||
"# Introduction to Pytorch Lightning ⚡\n",
|
||||
"\n",
|
||||
"In this notebook, we'll go over the basics of lightning by preparing models to train on the [MNIST Handwritten Digits dataset](https://en.wikipedia.org/wiki/MNIST_database).\n",
|
||||
"\n",
|
||||
"---\n",
|
||||
" - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n",
|
||||
" - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n",
|
||||
" - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "2LODD6w9ixlT"
|
||||
},
|
||||
"source": [
|
||||
"### Setup \n",
|
||||
"Lightning is easy to install. Simply ```pip install pytorch-lightning```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "zK7-Gg69kMnG"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"! pip install pytorch-lightning --quiet"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "w4_TYnt_keJi"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"import torch\n",
|
||||
"from torch import nn\n",
|
||||
"from torch.nn import functional as F\n",
|
||||
"from torch.utils.data import DataLoader, random_split\n",
|
||||
"from torchvision.datasets import MNIST\n",
|
||||
"from torchvision import transforms\n",
|
||||
"import pytorch_lightning as pl\n",
|
||||
"from pytorch_lightning.metrics.functional import accuracy"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "EHpyMPKFkVbZ"
|
||||
},
|
||||
"source": [
|
||||
"## Simplest example\n",
|
||||
"\n",
|
||||
"Here's the simplest most minimal example with just a training loop (no validation, no testing).\n",
|
||||
"\n",
|
||||
"**Keep in Mind** - A `LightningModule` *is* a PyTorch `nn.Module` - it just has a few more helpful features."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "V7ELesz1kVQo"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class MNISTModel(pl.LightningModule):\n",
|
||||
"\n",
|
||||
" def __init__(self):\n",
|
||||
" super(MNISTModel, self).__init__()\n",
|
||||
" self.l1 = torch.nn.Linear(28 * 28, 10)\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" return torch.relu(self.l1(x.view(x.size(0), -1)))\n",
|
||||
"\n",
|
||||
" def training_step(self, batch, batch_nb):\n",
|
||||
" x, y = batch\n",
|
||||
" loss = F.cross_entropy(self(x), y)\n",
|
||||
" return loss\n",
|
||||
"\n",
|
||||
" def configure_optimizers(self):\n",
|
||||
" return torch.optim.Adam(self.parameters(), lr=0.02)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "hIrtHg-Dv8TJ"
|
||||
},
|
||||
"source": [
|
||||
"By using the `Trainer` you automatically get:\n",
|
||||
"1. Tensorboard logging\n",
|
||||
"2. Model checkpointing\n",
|
||||
"3. Training and validation loop\n",
|
||||
"4. early-stopping"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "4Dk6Ykv8lI7X"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Init our model\n",
|
||||
"mnist_model = MNISTModel()\n",
|
||||
"\n",
|
||||
"# Init DataLoader from MNIST Dataset\n",
|
||||
"train_ds = MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())\n",
|
||||
"train_loader = DataLoader(train_ds, batch_size=32)\n",
|
||||
"\n",
|
||||
"# Initialize a trainer\n",
|
||||
"trainer = pl.Trainer(gpus=1, max_epochs=3, progress_bar_refresh_rate=20)\n",
|
||||
"\n",
|
||||
"# Train the model ⚡\n",
|
||||
"trainer.fit(mnist_model, train_loader)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "KNpOoBeIjscS"
|
||||
},
|
||||
"source": [
|
||||
"## A more complete MNIST Lightning Module Example\n",
|
||||
"\n",
|
||||
"That wasn't so hard was it?\n",
|
||||
"\n",
|
||||
"Now that we've got our feet wet, let's dive in a bit deeper and write a more complete `LightningModule` for MNIST...\n",
|
||||
"\n",
|
||||
"This time, we'll bake in all the dataset specific pieces directly in the `LightningModule`. This way, we can avoid writing extra code at the beginning of our script every time we want to run it.\n",
|
||||
"\n",
|
||||
"---\n",
|
||||
"\n",
|
||||
"### Note what the following built-in functions are doing:\n",
|
||||
"\n",
|
||||
"1. [prepare_data()](https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.core.lightning.html#pytorch_lightning.core.lightning.LightningModule.prepare_data) 💾\n",
|
||||
" - This is where we can download the dataset. We point to our desired dataset and ask torchvision's `MNIST` dataset class to download if the dataset isn't found there.\n",
|
||||
" - **Note we do not make any state assignments in this function** (i.e. `self.something = ...`)\n",
|
||||
"\n",
|
||||
"2. [setup(stage)](https://pytorch-lightning.readthedocs.io/en/latest/lightning-module.html#setup) ⚙️\n",
|
||||
" - Loads in data from file and prepares PyTorch tensor datasets for each split (train, val, test). \n",
|
||||
" - Setup expects a 'stage' arg which is used to separate logic for 'fit' and 'test'.\n",
|
||||
" - If you don't mind loading all your datasets at once, you can set up a condition to allow for both 'fit' related setup and 'test' related setup to run whenever `None` is passed to `stage` (or ignore it altogether and exclude any conditionals).\n",
|
||||
" - **Note this runs across all GPUs and it *is* safe to make state assignments here**\n",
|
||||
"\n",
|
||||
"3. [x_dataloader()](https://pytorch-lightning.readthedocs.io/en/latest/lightning-module.html#data-hooks) ♻️\n",
|
||||
" - `train_dataloader()`, `val_dataloader()`, and `test_dataloader()` all return PyTorch `DataLoader` instances that are created by wrapping their respective datasets that we prepared in `setup()`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "4DNItffri95Q"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class LitMNIST(pl.LightningModule):\n",
|
||||
" \n",
|
||||
" def __init__(self, data_dir='./', hidden_size=64, learning_rate=2e-4):\n",
|
||||
"\n",
|
||||
" super().__init__()\n",
|
||||
"\n",
|
||||
" # Set our init args as class attributes\n",
|
||||
" self.data_dir = data_dir\n",
|
||||
" self.hidden_size = hidden_size\n",
|
||||
" self.learning_rate = learning_rate\n",
|
||||
"\n",
|
||||
" # Hardcode some dataset specific attributes\n",
|
||||
" self.num_classes = 10\n",
|
||||
" self.dims = (1, 28, 28)\n",
|
||||
" channels, width, height = self.dims\n",
|
||||
" self.transform = transforms.Compose([\n",
|
||||
" transforms.ToTensor(),\n",
|
||||
" transforms.Normalize((0.1307,), (0.3081,))\n",
|
||||
" ])\n",
|
||||
"\n",
|
||||
" # Define PyTorch model\n",
|
||||
" self.model = nn.Sequential(\n",
|
||||
" nn.Flatten(),\n",
|
||||
" nn.Linear(channels * width * height, hidden_size),\n",
|
||||
" nn.ReLU(),\n",
|
||||
" nn.Dropout(0.1),\n",
|
||||
" nn.Linear(hidden_size, hidden_size),\n",
|
||||
" nn.ReLU(),\n",
|
||||
" nn.Dropout(0.1),\n",
|
||||
" nn.Linear(hidden_size, self.num_classes)\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" x = self.model(x)\n",
|
||||
" return F.log_softmax(x, dim=1)\n",
|
||||
"\n",
|
||||
" def training_step(self, batch, batch_idx):\n",
|
||||
" x, y = batch\n",
|
||||
" logits = self(x)\n",
|
||||
" loss = F.nll_loss(logits, y)\n",
|
||||
" return loss\n",
|
||||
"\n",
|
||||
" def validation_step(self, batch, batch_idx):\n",
|
||||
" x, y = batch\n",
|
||||
" logits = self(x)\n",
|
||||
" loss = F.nll_loss(logits, y)\n",
|
||||
" preds = torch.argmax(logits, dim=1)\n",
|
||||
" acc = accuracy(preds, y)\n",
|
||||
"\n",
|
||||
" # Calling self.log will surface up scalars for you in TensorBoard\n",
|
||||
" self.log('val_loss', loss, prog_bar=True)\n",
|
||||
" self.log('val_acc', acc, prog_bar=True)\n",
|
||||
" return loss\n",
|
||||
"\n",
|
||||
" def test_step(self, batch, batch_idx):\n",
|
||||
" # Here we just reuse the validation_step for testing\n",
|
||||
" return self.validation_step(batch, batch_idx)\n",
|
||||
"\n",
|
||||
" def configure_optimizers(self):\n",
|
||||
" optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
|
||||
" return optimizer\n",
|
||||
"\n",
|
||||
" ####################\n",
|
||||
" # DATA RELATED HOOKS\n",
|
||||
" ####################\n",
|
||||
"\n",
|
||||
" def prepare_data(self):\n",
|
||||
" # download\n",
|
||||
" MNIST(self.data_dir, train=True, download=True)\n",
|
||||
" MNIST(self.data_dir, train=False, download=True)\n",
|
||||
"\n",
|
||||
" def setup(self, stage=None):\n",
|
||||
"\n",
|
||||
" # Assign train/val datasets for use in dataloaders\n",
|
||||
" if stage == 'fit' or stage is None:\n",
|
||||
" mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n",
|
||||
" self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n",
|
||||
"\n",
|
||||
" # Assign test dataset for use in dataloader(s)\n",
|
||||
" if stage == 'test' or stage is None:\n",
|
||||
" self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n",
|
||||
"\n",
|
||||
" def train_dataloader(self):\n",
|
||||
" return DataLoader(self.mnist_train, batch_size=32)\n",
|
||||
"\n",
|
||||
" def val_dataloader(self):\n",
|
||||
" return DataLoader(self.mnist_val, batch_size=32)\n",
|
||||
"\n",
|
||||
" def test_dataloader(self):\n",
|
||||
" return DataLoader(self.mnist_test, batch_size=32)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "Mb0U5Rk2kLBy"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model = LitMNIST()\n",
|
||||
"trainer = pl.Trainer(gpus=1, max_epochs=3, progress_bar_refresh_rate=20)\n",
|
||||
"trainer.fit(model)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "nht8AvMptY6I"
|
||||
},
|
||||
"source": [
|
||||
"### Testing\n",
|
||||
"\n",
|
||||
"To test a model, call `trainer.test(model)`.\n",
|
||||
"\n",
|
||||
"Or, if you've just trained a model, you can just call `trainer.test()` and Lightning will automatically test using the best saved checkpoint (conditioned on val_loss)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "PA151FkLtprO"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"trainer.test()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "T3-3lbbNtr5T"
|
||||
},
|
||||
"source": [
|
||||
"### Bonus Tip\n",
|
||||
"\n",
|
||||
"You can keep calling `trainer.fit(model)` as many times as you'd like to continue training"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "IFBwCbLet2r6"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"trainer.fit(model)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "8TRyS5CCt3n9"
|
||||
},
|
||||
"source": [
|
||||
"In Colab, you can use the TensorBoard magic function to view the logs that Lightning has created for you!"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "wizS-QiLuAYo"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Start tensorboard.\n",
|
||||
"%load_ext tensorboard\n",
|
||||
"%tensorboard --logdir lightning_logs/"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"<code style=\"color:#792ee5;\">\n",
|
||||
" <h1> <strong> Congratulations - Time to Join the Community! </strong> </h1>\n",
|
||||
"</code>\n",
|
||||
"\n",
|
||||
"Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!\n",
|
||||
"\n",
|
||||
"### Star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) on GitHub\n",
|
||||
"The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.\n",
|
||||
"\n",
|
||||
"* Please, star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning)\n",
|
||||
"\n",
|
||||
"### Join our [Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)!\n",
|
||||
"The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in `#general` channel\n",
|
||||
"\n",
|
||||
"### Interested by SOTA AI models ! Check out [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n",
|
||||
"Bolts has a collection of state-of-the-art models, all implemented in [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) and can be easily integrated within your own projects.\n",
|
||||
"\n",
|
||||
"* Please, star [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n",
|
||||
"\n",
|
||||
"### Contributions !\n",
|
||||
"The best way to contribute to our community is to become a code contributor! At any time you can go to [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) or [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts) GitHub Issues page and filter for \"good first issue\". \n",
|
||||
"\n",
|
||||
"* [Lightning good first issue](https://github.com/PyTorchLightning/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
|
||||
"* [Bolt good first issue](https://github.com/PyTorchLightning/pytorch-lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
|
||||
"* You can also contribute your own notebooks with useful examples !\n",
|
||||
"\n",
|
||||
"### Great thanks from the entire Pytorch Lightning Team for your interest !\n",
|
||||
"\n",
|
||||
"<img src=\"https://github.com/PyTorchLightning/pytorch-lightning/blob/master/docs/source/_images/logos/lightning_logo-name.png?raw=true\" width=\"800\" height=\"200\" />"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"accelerator": "GPU",
|
||||
"colab": {
|
||||
"authorship_tag": "ABX9TyOtAKVa5POQ6Xg3UcTQqXDJ",
|
||||
"collapsed_sections": [],
|
||||
"include_colab_link": true,
|
||||
"name": "01-mnist-hello-world.ipynb",
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -1,424 +1,472 @@
|
|||
{
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"name": "03-basic-gan.ipynb",
|
||||
"provenance": [],
|
||||
"collapsed_sections": [],
|
||||
"include_colab_link": true
|
||||
},
|
||||
"kernelspec": {
|
||||
"name": "python3",
|
||||
"display_name": "Python 3"
|
||||
},
|
||||
"accelerator": "GPU"
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "view-in-github"
|
||||
},
|
||||
"source": [
|
||||
"<a href=\"https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/03-basic-gan.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
||||
]
|
||||
},
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "view-in-github",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"<a href=\"https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/03-basic-gan.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "J37PBnE_x7IW",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"# PyTorch Lightning Basic GAN Tutorial ⚡\n",
|
||||
"\n",
|
||||
"How to train a GAN!\n",
|
||||
"\n",
|
||||
"Main takeaways:\n",
|
||||
"1. Generator and discriminator are arbitrary PyTorch modules.\n",
|
||||
"2. training_step does both the generator and discriminator training.\n",
|
||||
"\n",
|
||||
"---\n",
|
||||
"\n",
|
||||
" - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n",
|
||||
" - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n",
|
||||
" - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "kg2MKpRmybht",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"### Setup\n",
|
||||
"Lightning is easy to install. Simply `pip install pytorch-lightning`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "LfrJLKPFyhsK",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"! pip install pytorch-lightning --quiet"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "BjEPuiVLyanw",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"import os\n",
|
||||
"from argparse import ArgumentParser\n",
|
||||
"from collections import OrderedDict\n",
|
||||
"\n",
|
||||
"import numpy as np\n",
|
||||
"import torch\n",
|
||||
"import torch.nn as nn\n",
|
||||
"import torch.nn.functional as F\n",
|
||||
"import torchvision\n",
|
||||
"import torchvision.transforms as transforms\n",
|
||||
"from torch.utils.data import DataLoader, random_split\n",
|
||||
"from torchvision.datasets import MNIST\n",
|
||||
"\n",
|
||||
"import pytorch_lightning as pl"
|
||||
],
|
||||
"execution_count": 2,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "OuXJzr4G2uHV",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"### MNIST DataModule\n",
|
||||
"\n",
|
||||
"Below, we define a DataModule for the MNIST Dataset. To learn more about DataModules, check out our tutorial on them or see the [latest docs](https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "DOY_nHu328g7",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"class MNISTDataModule(pl.LightningDataModule):\n",
|
||||
"\n",
|
||||
" def __init__(self, data_dir: str = './', batch_size: int = 64, num_workers: int = 8):\n",
|
||||
" super().__init__()\n",
|
||||
" self.data_dir = data_dir\n",
|
||||
" self.batch_size = batch_size\n",
|
||||
" self.num_workers = num_workers\n",
|
||||
"\n",
|
||||
" self.transform = transforms.Compose([\n",
|
||||
" transforms.ToTensor(),\n",
|
||||
" transforms.Normalize((0.1307,), (0.3081,))\n",
|
||||
" ])\n",
|
||||
"\n",
|
||||
" # self.dims is returned when you call dm.size()\n",
|
||||
" # Setting default dims here because we know them.\n",
|
||||
" # Could optionally be assigned dynamically in dm.setup()\n",
|
||||
" self.dims = (1, 28, 28)\n",
|
||||
" self.num_classes = 10\n",
|
||||
"\n",
|
||||
" def prepare_data(self):\n",
|
||||
" # download\n",
|
||||
" MNIST(self.data_dir, train=True, download=True)\n",
|
||||
" MNIST(self.data_dir, train=False, download=True)\n",
|
||||
"\n",
|
||||
" def setup(self, stage=None):\n",
|
||||
"\n",
|
||||
" # Assign train/val datasets for use in dataloaders\n",
|
||||
" if stage == 'fit' or stage is None:\n",
|
||||
" mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n",
|
||||
" self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n",
|
||||
"\n",
|
||||
" # Assign test dataset for use in dataloader(s)\n",
|
||||
" if stage == 'test' or stage is None:\n",
|
||||
" self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n",
|
||||
"\n",
|
||||
" def train_dataloader(self):\n",
|
||||
" return DataLoader(self.mnist_train, batch_size=self.batch_size, num_workers=self.num_workers)\n",
|
||||
"\n",
|
||||
" def val_dataloader(self):\n",
|
||||
" return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=self.num_workers)\n",
|
||||
"\n",
|
||||
" def test_dataloader(self):\n",
|
||||
" return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=self.num_workers)"
|
||||
],
|
||||
"execution_count": 3,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "tW3c0QrQyF9P",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"### A. Generator"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "0E2QDjl5yWtz",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"class Generator(nn.Module):\n",
|
||||
" def __init__(self, latent_dim, img_shape):\n",
|
||||
" super().__init__()\n",
|
||||
" self.img_shape = img_shape\n",
|
||||
"\n",
|
||||
" def block(in_feat, out_feat, normalize=True):\n",
|
||||
" layers = [nn.Linear(in_feat, out_feat)]\n",
|
||||
" if normalize:\n",
|
||||
" layers.append(nn.BatchNorm1d(out_feat, 0.8))\n",
|
||||
" layers.append(nn.LeakyReLU(0.2, inplace=True))\n",
|
||||
" return layers\n",
|
||||
"\n",
|
||||
" self.model = nn.Sequential(\n",
|
||||
" *block(latent_dim, 128, normalize=False),\n",
|
||||
" *block(128, 256),\n",
|
||||
" *block(256, 512),\n",
|
||||
" *block(512, 1024),\n",
|
||||
" nn.Linear(1024, int(np.prod(img_shape))),\n",
|
||||
" nn.Tanh()\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" def forward(self, z):\n",
|
||||
" img = self.model(z)\n",
|
||||
" img = img.view(img.size(0), *self.img_shape)\n",
|
||||
" return img"
|
||||
],
|
||||
"execution_count": 4,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "uyrltsGvyaI3",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"### B. Discriminator"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "Ed3MR3vnyxyW",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"class Discriminator(nn.Module):\n",
|
||||
" def __init__(self, img_shape):\n",
|
||||
" super().__init__()\n",
|
||||
"\n",
|
||||
" self.model = nn.Sequential(\n",
|
||||
" nn.Linear(int(np.prod(img_shape)), 512),\n",
|
||||
" nn.LeakyReLU(0.2, inplace=True),\n",
|
||||
" nn.Linear(512, 256),\n",
|
||||
" nn.LeakyReLU(0.2, inplace=True),\n",
|
||||
" nn.Linear(256, 1),\n",
|
||||
" nn.Sigmoid(),\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" def forward(self, img):\n",
|
||||
" img_flat = img.view(img.size(0), -1)\n",
|
||||
" validity = self.model(img_flat)\n",
|
||||
"\n",
|
||||
" return validity"
|
||||
],
|
||||
"execution_count": 5,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "BwUMom3ryySK",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"### C. GAN\n",
|
||||
"\n",
|
||||
"#### A couple of cool features to check out in this example...\n",
|
||||
"\n",
|
||||
" - We use `some_tensor.type_as(another_tensor)` to make sure we initialize new tensors on the right device (i.e. GPU, CPU).\n",
|
||||
" - Lightning will put your dataloader data on the right device automatically\n",
|
||||
" - In this example, we pull from latent dim on the fly, so we need to dynamically add tensors to the right device.\n",
|
||||
" - `type_as` is the way we recommend to do this.\n",
|
||||
" - This example shows how to use multiple dataloaders in your `LightningModule`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "3vKszYf6y1Vv",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
" class GAN(pl.LightningModule):\n",
|
||||
"\n",
|
||||
" def __init__(\n",
|
||||
" self,\n",
|
||||
" channels,\n",
|
||||
" width,\n",
|
||||
" height,\n",
|
||||
" latent_dim: int = 100,\n",
|
||||
" lr: float = 0.0002,\n",
|
||||
" b1: float = 0.5,\n",
|
||||
" b2: float = 0.999,\n",
|
||||
" batch_size: int = 64,\n",
|
||||
" **kwargs\n",
|
||||
" ):\n",
|
||||
" super().__init__()\n",
|
||||
" self.save_hyperparameters()\n",
|
||||
"\n",
|
||||
" # networks\n",
|
||||
" data_shape = (channels, width, height)\n",
|
||||
" self.generator = Generator(latent_dim=self.hparams.latent_dim, img_shape=data_shape)\n",
|
||||
" self.discriminator = Discriminator(img_shape=data_shape)\n",
|
||||
"\n",
|
||||
" self.validation_z = torch.randn(8, self.hparams.latent_dim)\n",
|
||||
"\n",
|
||||
" self.example_input_array = torch.zeros(2, self.hparams.latent_dim)\n",
|
||||
"\n",
|
||||
" def forward(self, z):\n",
|
||||
" return self.generator(z)\n",
|
||||
"\n",
|
||||
" def adversarial_loss(self, y_hat, y):\n",
|
||||
" return F.binary_cross_entropy(y_hat, y)\n",
|
||||
"\n",
|
||||
" def training_step(self, batch, batch_idx, optimizer_idx):\n",
|
||||
" imgs, _ = batch\n",
|
||||
"\n",
|
||||
" # sample noise\n",
|
||||
" z = torch.randn(imgs.shape[0], self.hparams.latent_dim)\n",
|
||||
" z = z.type_as(imgs)\n",
|
||||
"\n",
|
||||
" # train generator\n",
|
||||
" if optimizer_idx == 0:\n",
|
||||
"\n",
|
||||
" # generate images\n",
|
||||
" self.generated_imgs = self(z)\n",
|
||||
"\n",
|
||||
" # log sampled images\n",
|
||||
" sample_imgs = self.generated_imgs[:6]\n",
|
||||
" grid = torchvision.utils.make_grid(sample_imgs)\n",
|
||||
" self.logger.experiment.add_image('generated_images', grid, 0)\n",
|
||||
"\n",
|
||||
" # ground truth result (ie: all fake)\n",
|
||||
" # put on GPU because we created this tensor inside training_loop\n",
|
||||
" valid = torch.ones(imgs.size(0), 1)\n",
|
||||
" valid = valid.type_as(imgs)\n",
|
||||
"\n",
|
||||
" # adversarial loss is binary cross-entropy\n",
|
||||
" g_loss = self.adversarial_loss(self.discriminator(self(z)), valid)\n",
|
||||
" tqdm_dict = {'g_loss': g_loss}\n",
|
||||
" output = OrderedDict({\n",
|
||||
" 'loss': g_loss,\n",
|
||||
" 'progress_bar': tqdm_dict,\n",
|
||||
" 'log': tqdm_dict\n",
|
||||
" })\n",
|
||||
" return output\n",
|
||||
"\n",
|
||||
" # train discriminator\n",
|
||||
" if optimizer_idx == 1:\n",
|
||||
" # Measure discriminator's ability to classify real from generated samples\n",
|
||||
"\n",
|
||||
" # how well can it label as real?\n",
|
||||
" valid = torch.ones(imgs.size(0), 1)\n",
|
||||
" valid = valid.type_as(imgs)\n",
|
||||
"\n",
|
||||
" real_loss = self.adversarial_loss(self.discriminator(imgs), valid)\n",
|
||||
"\n",
|
||||
" # how well can it label as fake?\n",
|
||||
" fake = torch.zeros(imgs.size(0), 1)\n",
|
||||
" fake = fake.type_as(imgs)\n",
|
||||
"\n",
|
||||
" fake_loss = self.adversarial_loss(\n",
|
||||
" self.discriminator(self(z).detach()), fake)\n",
|
||||
"\n",
|
||||
" # discriminator loss is the average of these\n",
|
||||
" d_loss = (real_loss + fake_loss) / 2\n",
|
||||
" tqdm_dict = {'d_loss': d_loss}\n",
|
||||
" output = OrderedDict({\n",
|
||||
" 'loss': d_loss,\n",
|
||||
" 'progress_bar': tqdm_dict,\n",
|
||||
" 'log': tqdm_dict\n",
|
||||
" })\n",
|
||||
" return output\n",
|
||||
"\n",
|
||||
" def configure_optimizers(self):\n",
|
||||
" lr = self.hparams.lr\n",
|
||||
" b1 = self.hparams.b1\n",
|
||||
" b2 = self.hparams.b2\n",
|
||||
"\n",
|
||||
" opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))\n",
|
||||
" opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))\n",
|
||||
" return [opt_g, opt_d], []\n",
|
||||
"\n",
|
||||
" def on_epoch_end(self):\n",
|
||||
" z = self.validation_z.type_as(self.generator.model[0].weight)\n",
|
||||
"\n",
|
||||
" # log sampled images\n",
|
||||
" sample_imgs = self(z)\n",
|
||||
" grid = torchvision.utils.make_grid(sample_imgs)\n",
|
||||
" self.logger.experiment.add_image('generated_images', grid, self.current_epoch)"
|
||||
],
|
||||
"execution_count": 6,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "Ey5FmJPnzm_E",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"dm = MNISTDataModule()\n",
|
||||
"model = GAN(*dm.size())\n",
|
||||
"trainer = pl.Trainer(gpus=1, max_epochs=5, progress_bar_refresh_rate=20)\n",
|
||||
"trainer.fit(model, dm)"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "MlECc7cHzolp",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"# Start tensorboard.\n",
|
||||
"%load_ext tensorboard\n",
|
||||
"%tensorboard --logdir lightning_logs/"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
}
|
||||
]
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "J37PBnE_x7IW"
|
||||
},
|
||||
"source": [
|
||||
"# PyTorch Lightning Basic GAN Tutorial ⚡\n",
|
||||
"\n",
|
||||
"How to train a GAN!\n",
|
||||
"\n",
|
||||
"Main takeaways:\n",
|
||||
"1. Generator and discriminator are arbitrary PyTorch modules.\n",
|
||||
"2. training_step does both the generator and discriminator training.\n",
|
||||
"\n",
|
||||
"---\n",
|
||||
"\n",
|
||||
" - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n",
|
||||
" - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n",
|
||||
" - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "kg2MKpRmybht"
|
||||
},
|
||||
"source": [
|
||||
"### Setup\n",
|
||||
"Lightning is easy to install. Simply `pip install pytorch-lightning`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "LfrJLKPFyhsK"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"! pip install pytorch-lightning --quiet"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "BjEPuiVLyanw"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"from argparse import ArgumentParser\n",
|
||||
"from collections import OrderedDict\n",
|
||||
"\n",
|
||||
"import numpy as np\n",
|
||||
"import torch\n",
|
||||
"import torch.nn as nn\n",
|
||||
"import torch.nn.functional as F\n",
|
||||
"import torchvision\n",
|
||||
"import torchvision.transforms as transforms\n",
|
||||
"from torch.utils.data import DataLoader, random_split\n",
|
||||
"from torchvision.datasets import MNIST\n",
|
||||
"\n",
|
||||
"import pytorch_lightning as pl"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "OuXJzr4G2uHV"
|
||||
},
|
||||
"source": [
|
||||
"### MNIST DataModule\n",
|
||||
"\n",
|
||||
"Below, we define a DataModule for the MNIST Dataset. To learn more about DataModules, check out our tutorial on them or see the [latest docs](https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "DOY_nHu328g7"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class MNISTDataModule(pl.LightningDataModule):\n",
|
||||
"\n",
|
||||
" def __init__(self, data_dir: str = './', batch_size: int = 64, num_workers: int = 8):\n",
|
||||
" super().__init__()\n",
|
||||
" self.data_dir = data_dir\n",
|
||||
" self.batch_size = batch_size\n",
|
||||
" self.num_workers = num_workers\n",
|
||||
"\n",
|
||||
" self.transform = transforms.Compose([\n",
|
||||
" transforms.ToTensor(),\n",
|
||||
" transforms.Normalize((0.1307,), (0.3081,))\n",
|
||||
" ])\n",
|
||||
"\n",
|
||||
" # self.dims is returned when you call dm.size()\n",
|
||||
" # Setting default dims here because we know them.\n",
|
||||
" # Could optionally be assigned dynamically in dm.setup()\n",
|
||||
" self.dims = (1, 28, 28)\n",
|
||||
" self.num_classes = 10\n",
|
||||
"\n",
|
||||
" def prepare_data(self):\n",
|
||||
" # download\n",
|
||||
" MNIST(self.data_dir, train=True, download=True)\n",
|
||||
" MNIST(self.data_dir, train=False, download=True)\n",
|
||||
"\n",
|
||||
" def setup(self, stage=None):\n",
|
||||
"\n",
|
||||
" # Assign train/val datasets for use in dataloaders\n",
|
||||
" if stage == 'fit' or stage is None:\n",
|
||||
" mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n",
|
||||
" self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n",
|
||||
"\n",
|
||||
" # Assign test dataset for use in dataloader(s)\n",
|
||||
" if stage == 'test' or stage is None:\n",
|
||||
" self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n",
|
||||
"\n",
|
||||
" def train_dataloader(self):\n",
|
||||
" return DataLoader(self.mnist_train, batch_size=self.batch_size, num_workers=self.num_workers)\n",
|
||||
"\n",
|
||||
" def val_dataloader(self):\n",
|
||||
" return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=self.num_workers)\n",
|
||||
"\n",
|
||||
" def test_dataloader(self):\n",
|
||||
" return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=self.num_workers)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "tW3c0QrQyF9P"
|
||||
},
|
||||
"source": [
|
||||
"### A. Generator"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "0E2QDjl5yWtz"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class Generator(nn.Module):\n",
|
||||
" def __init__(self, latent_dim, img_shape):\n",
|
||||
" super().__init__()\n",
|
||||
" self.img_shape = img_shape\n",
|
||||
"\n",
|
||||
" def block(in_feat, out_feat, normalize=True):\n",
|
||||
" layers = [nn.Linear(in_feat, out_feat)]\n",
|
||||
" if normalize:\n",
|
||||
" layers.append(nn.BatchNorm1d(out_feat, 0.8))\n",
|
||||
" layers.append(nn.LeakyReLU(0.2, inplace=True))\n",
|
||||
" return layers\n",
|
||||
"\n",
|
||||
" self.model = nn.Sequential(\n",
|
||||
" *block(latent_dim, 128, normalize=False),\n",
|
||||
" *block(128, 256),\n",
|
||||
" *block(256, 512),\n",
|
||||
" *block(512, 1024),\n",
|
||||
" nn.Linear(1024, int(np.prod(img_shape))),\n",
|
||||
" nn.Tanh()\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" def forward(self, z):\n",
|
||||
" img = self.model(z)\n",
|
||||
" img = img.view(img.size(0), *self.img_shape)\n",
|
||||
" return img"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "uyrltsGvyaI3"
|
||||
},
|
||||
"source": [
|
||||
"### B. Discriminator"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "Ed3MR3vnyxyW"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class Discriminator(nn.Module):\n",
|
||||
" def __init__(self, img_shape):\n",
|
||||
" super().__init__()\n",
|
||||
"\n",
|
||||
" self.model = nn.Sequential(\n",
|
||||
" nn.Linear(int(np.prod(img_shape)), 512),\n",
|
||||
" nn.LeakyReLU(0.2, inplace=True),\n",
|
||||
" nn.Linear(512, 256),\n",
|
||||
" nn.LeakyReLU(0.2, inplace=True),\n",
|
||||
" nn.Linear(256, 1),\n",
|
||||
" nn.Sigmoid(),\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" def forward(self, img):\n",
|
||||
" img_flat = img.view(img.size(0), -1)\n",
|
||||
" validity = self.model(img_flat)\n",
|
||||
"\n",
|
||||
" return validity"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "BwUMom3ryySK"
|
||||
},
|
||||
"source": [
|
||||
"### C. GAN\n",
|
||||
"\n",
|
||||
"#### A couple of cool features to check out in this example...\n",
|
||||
"\n",
|
||||
" - We use `some_tensor.type_as(another_tensor)` to make sure we initialize new tensors on the right device (i.e. GPU, CPU).\n",
|
||||
" - Lightning will put your dataloader data on the right device automatically\n",
|
||||
" - In this example, we pull from latent dim on the fly, so we need to dynamically add tensors to the right device.\n",
|
||||
" - `type_as` is the way we recommend to do this.\n",
|
||||
" - This example shows how to use multiple dataloaders in your `LightningModule`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "3vKszYf6y1Vv"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
" class GAN(pl.LightningModule):\n",
|
||||
"\n",
|
||||
" def __init__(\n",
|
||||
" self,\n",
|
||||
" channels,\n",
|
||||
" width,\n",
|
||||
" height,\n",
|
||||
" latent_dim: int = 100,\n",
|
||||
" lr: float = 0.0002,\n",
|
||||
" b1: float = 0.5,\n",
|
||||
" b2: float = 0.999,\n",
|
||||
" batch_size: int = 64,\n",
|
||||
" **kwargs\n",
|
||||
" ):\n",
|
||||
" super().__init__()\n",
|
||||
" self.save_hyperparameters()\n",
|
||||
"\n",
|
||||
" # networks\n",
|
||||
" data_shape = (channels, width, height)\n",
|
||||
" self.generator = Generator(latent_dim=self.hparams.latent_dim, img_shape=data_shape)\n",
|
||||
" self.discriminator = Discriminator(img_shape=data_shape)\n",
|
||||
"\n",
|
||||
" self.validation_z = torch.randn(8, self.hparams.latent_dim)\n",
|
||||
"\n",
|
||||
" self.example_input_array = torch.zeros(2, self.hparams.latent_dim)\n",
|
||||
"\n",
|
||||
" def forward(self, z):\n",
|
||||
" return self.generator(z)\n",
|
||||
"\n",
|
||||
" def adversarial_loss(self, y_hat, y):\n",
|
||||
" return F.binary_cross_entropy(y_hat, y)\n",
|
||||
"\n",
|
||||
" def training_step(self, batch, batch_idx, optimizer_idx):\n",
|
||||
" imgs, _ = batch\n",
|
||||
"\n",
|
||||
" # sample noise\n",
|
||||
" z = torch.randn(imgs.shape[0], self.hparams.latent_dim)\n",
|
||||
" z = z.type_as(imgs)\n",
|
||||
"\n",
|
||||
" # train generator\n",
|
||||
" if optimizer_idx == 0:\n",
|
||||
"\n",
|
||||
" # generate images\n",
|
||||
" self.generated_imgs = self(z)\n",
|
||||
"\n",
|
||||
" # log sampled images\n",
|
||||
" sample_imgs = self.generated_imgs[:6]\n",
|
||||
" grid = torchvision.utils.make_grid(sample_imgs)\n",
|
||||
" self.logger.experiment.add_image('generated_images', grid, 0)\n",
|
||||
"\n",
|
||||
" # ground truth result (ie: all fake)\n",
|
||||
" # put on GPU because we created this tensor inside training_loop\n",
|
||||
" valid = torch.ones(imgs.size(0), 1)\n",
|
||||
" valid = valid.type_as(imgs)\n",
|
||||
"\n",
|
||||
" # adversarial loss is binary cross-entropy\n",
|
||||
" g_loss = self.adversarial_loss(self.discriminator(self(z)), valid)\n",
|
||||
" tqdm_dict = {'g_loss': g_loss}\n",
|
||||
" output = OrderedDict({\n",
|
||||
" 'loss': g_loss,\n",
|
||||
" 'progress_bar': tqdm_dict,\n",
|
||||
" 'log': tqdm_dict\n",
|
||||
" })\n",
|
||||
" return output\n",
|
||||
"\n",
|
||||
" # train discriminator\n",
|
||||
" if optimizer_idx == 1:\n",
|
||||
" # Measure discriminator's ability to classify real from generated samples\n",
|
||||
"\n",
|
||||
" # how well can it label as real?\n",
|
||||
" valid = torch.ones(imgs.size(0), 1)\n",
|
||||
" valid = valid.type_as(imgs)\n",
|
||||
"\n",
|
||||
" real_loss = self.adversarial_loss(self.discriminator(imgs), valid)\n",
|
||||
"\n",
|
||||
" # how well can it label as fake?\n",
|
||||
" fake = torch.zeros(imgs.size(0), 1)\n",
|
||||
" fake = fake.type_as(imgs)\n",
|
||||
"\n",
|
||||
" fake_loss = self.adversarial_loss(\n",
|
||||
" self.discriminator(self(z).detach()), fake)\n",
|
||||
"\n",
|
||||
" # discriminator loss is the average of these\n",
|
||||
" d_loss = (real_loss + fake_loss) / 2\n",
|
||||
" tqdm_dict = {'d_loss': d_loss}\n",
|
||||
" output = OrderedDict({\n",
|
||||
" 'loss': d_loss,\n",
|
||||
" 'progress_bar': tqdm_dict,\n",
|
||||
" 'log': tqdm_dict\n",
|
||||
" })\n",
|
||||
" return output\n",
|
||||
"\n",
|
||||
" def configure_optimizers(self):\n",
|
||||
" lr = self.hparams.lr\n",
|
||||
" b1 = self.hparams.b1\n",
|
||||
" b2 = self.hparams.b2\n",
|
||||
"\n",
|
||||
" opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))\n",
|
||||
" opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))\n",
|
||||
" return [opt_g, opt_d], []\n",
|
||||
"\n",
|
||||
" def on_epoch_end(self):\n",
|
||||
" z = self.validation_z.type_as(self.generator.model[0].weight)\n",
|
||||
"\n",
|
||||
" # log sampled images\n",
|
||||
" sample_imgs = self(z)\n",
|
||||
" grid = torchvision.utils.make_grid(sample_imgs)\n",
|
||||
" self.logger.experiment.add_image('generated_images', grid, self.current_epoch)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "Ey5FmJPnzm_E"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"dm = MNISTDataModule()\n",
|
||||
"model = GAN(*dm.size())\n",
|
||||
"trainer = pl.Trainer(gpus=1, max_epochs=5, progress_bar_refresh_rate=20)\n",
|
||||
"trainer.fit(model, dm)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "MlECc7cHzolp"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Start tensorboard.\n",
|
||||
"%load_ext tensorboard\n",
|
||||
"%tensorboard --logdir lightning_logs/"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"<code style=\"color:#792ee5;\">\n",
|
||||
" <h1> <strong> Congratulations - Time to Join the Community! </strong> </h1>\n",
|
||||
"</code>\n",
|
||||
"\n",
|
||||
"Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!\n",
|
||||
"\n",
|
||||
"### Star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) on GitHub\n",
|
||||
"The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.\n",
|
||||
"\n",
|
||||
"* Please, star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning)\n",
|
||||
"\n",
|
||||
"### Join our [Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)!\n",
|
||||
"The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in `#general` channel\n",
|
||||
"\n",
|
||||
"### Interested by SOTA AI models ! Check out [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n",
|
||||
"Bolts has a collection of state-of-the-art models, all implemented in [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) and can be easily integrated within your own projects.\n",
|
||||
"\n",
|
||||
"* Please, star [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n",
|
||||
"\n",
|
||||
"### Contributions !\n",
|
||||
"The best way to contribute to our community is to become a code contributor! At any time you can go to [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) or [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts) GitHub Issues page and filter for \"good first issue\". \n",
|
||||
"\n",
|
||||
"* [Lightning good first issue](https://github.com/PyTorchLightning/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
|
||||
"* [Bolt good first issue](https://github.com/PyTorchLightning/pytorch-lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
|
||||
"* You can also contribute your own notebooks with useful examples !\n",
|
||||
"\n",
|
||||
"### Great thanks from the entire Pytorch Lightning Team for your interest !\n",
|
||||
"\n",
|
||||
"<img src=\"https://github.com/PyTorchLightning/pytorch-lightning/blob/master/docs/source/_images/logos/lightning_logo-name.png?raw=true\" width=\"800\" height=\"200\" />"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"accelerator": "GPU",
|
||||
"colab": {
|
||||
"collapsed_sections": [],
|
||||
"include_colab_link": true,
|
||||
"name": "03-basic-gan.ipynb",
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
|
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue