{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "06-mnist-tpu-training.ipynb", "provenance": [], "collapsed_sections": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "accelerator": "TPU" }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "WsWdLFMVKqbi" }, "source": [ "\"Open" ] }, { "cell_type": "markdown", "metadata": { "id": "qXO1QLkbRXl0" }, "source": [ "# TPU training with PyTorch Lightning ⚡\n", "\n", "In this notebook, we'll train a model on TPUs. Changing one line of code is all you need to that.\n", "\n", "The most up to documentation related to TPU training can be found [here](https://pytorch-lightning.readthedocs.io/en/latest/tpu.html).\n", "\n", "---\n", "\n", " - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n", " - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n", " - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)\n", " - Ask a question on our [official forum](https://forums.pytorchlightning.ai/)" ] }, { "cell_type": "markdown", "metadata": { "id": "UmKX0Qa1RaLL" }, "source": [ "### Setup\n", "\n", "Lightning is easy to install. Simply ```pip install pytorch-lightning```" ] }, { "cell_type": "code", "metadata": { "id": "vAWOr0FZRaIj" }, "source": [ "! pip install pytorch-lightning -qU" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "zepCr1upT4Z3" }, "source": [ "### Install Colab TPU compatible PyTorch/TPU wheels and dependencies" ] }, { "cell_type": "code", "metadata": { "id": "AYGWh10lRaF1" }, "source": [ "! pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.7-cp36-cp36m-linux_x86_64.whl" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "SNHa7DpmRZ-C" }, "source": [ "import torch\n", "from torch import nn\n", "import torch.nn.functional as F\n", "from torch.utils.data import random_split, DataLoader\n", "\n", "# Note - you must have torchvision installed for this example\n", "from torchvision.datasets import MNIST\n", "from torchvision import transforms\n", "\n", "import pytorch_lightning as pl\n", "from pytorch_lightning.metrics.functional import accuracy" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "rjo1dqzGUxt6" }, "source": [ "### Defining The `MNISTDataModule`\n", "\n", "Below we define `MNISTDataModule`. You can learn more about datamodules in [docs](https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html) and [datamodule notebook](https://github.com/PyTorchLightning/pytorch-lightning/blob/master/notebooks/02-datamodules.ipynb)." 
] }, { "cell_type": "code", "metadata": { "id": "pkbrm3YgUxlE" }, "source": [ "class MNISTDataModule(pl.LightningDataModule):\n", "\n", " def __init__(self, data_dir: str = './'):\n", " super().__init__()\n", " self.data_dir = data_dir\n", " self.transform = transforms.Compose([\n", " transforms.ToTensor(),\n", " transforms.Normalize((0.1307,), (0.3081,))\n", " ])\n", "\n", " # self.dims is returned when you call dm.size()\n", " # Setting default dims here because we know them.\n", " # Could optionally be assigned dynamically in dm.setup()\n", " self.dims = (1, 28, 28)\n", " self.num_classes = 10\n", "\n", " def prepare_data(self):\n", " # download\n", " MNIST(self.data_dir, train=True, download=True)\n", " MNIST(self.data_dir, train=False, download=True)\n", "\n", " def setup(self, stage=None):\n", "\n", " # Assign train/val datasets for use in dataloaders\n", " if stage == 'fit' or stage is None:\n", " mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n", " self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n", "\n", " # Assign test dataset for use in dataloader(s)\n", " if stage == 'test' or stage is None:\n", " self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n", "\n", " def train_dataloader(self):\n", " return DataLoader(self.mnist_train, batch_size=32)\n", "\n", " def val_dataloader(self):\n", " return DataLoader(self.mnist_val, batch_size=32)\n", "\n", " def test_dataloader(self):\n", " return DataLoader(self.mnist_test, batch_size=32)" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "nr9AqDWxUxdK" }, "source": [ "### Defining the `LitModel`\n", "\n", "Below, we define the model `LitMNIST`." ] }, { "cell_type": "code", "metadata": { "id": "YKt0KZkOUxVY" }, "source": [ "class LitModel(pl.LightningModule):\n", " \n", " def __init__(self, channels, width, height, num_classes, hidden_size=64, learning_rate=2e-4):\n", "\n", " super().__init__()\n", "\n", " self.save_hyperparameters()\n", "\n", " self.model = nn.Sequential(\n", " nn.Flatten(),\n", " nn.Linear(channels * width * height, hidden_size),\n", " nn.ReLU(),\n", " nn.Dropout(0.1),\n", " nn.Linear(hidden_size, hidden_size),\n", " nn.ReLU(),\n", " nn.Dropout(0.1),\n", " nn.Linear(hidden_size, num_classes)\n", " )\n", "\n", " def forward(self, x):\n", " x = self.model(x)\n", " return F.log_softmax(x, dim=1)\n", "\n", " def training_step(self, batch, batch_idx):\n", " x, y = batch\n", " logits = self(x)\n", " loss = F.nll_loss(logits, y)\n", " self.log('train_loss', loss, prog_bar=False)\n", " return loss\n", "\n", " def validation_step(self, batch, batch_idx):\n", " x, y = batch\n", " logits = self(x)\n", " loss = F.nll_loss(logits, y)\n", " preds = torch.argmax(logits, dim=1)\n", " acc = accuracy(preds, y)\n", " self.log('val_loss', loss, prog_bar=True)\n", " self.log('val_acc', acc, prog_bar=True)\n", " return loss\n", "\n", " def configure_optimizers(self):\n", " optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)\n", " return optimizer" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "Uxl88z06cHyV" }, "source": [ "### TPU Training\n", "\n", "Lightning supports training on a single TPU core or 8 TPU cores.\n", "\n", "The Trainer parameters `tpu_cores` defines how many TPU cores to train on (1 or 8) / Single TPU core to train on [1].\n", "\n", "For Single TPU training, Just pass the TPU core ID [1-8] in a list. 
{ "cell_type": "markdown", "metadata": { "id": "UZ647Xg2gYng" }, "source": [ "Train on TPU core ID 5 with `tpu_cores=[5]`." ] }, { "cell_type": "code", "metadata": { "id": "bzhJ8g_vUxN2" }, "source": [ "# Init DataModule\n", "dm = MNISTDataModule()\n", "# Init model from datamodule's attributes\n", "model = LitModel(*dm.size(), dm.num_classes)\n", "# Init trainer\n", "trainer = pl.Trainer(max_epochs=3, progress_bar_refresh_rate=20, tpu_cores=[5])\n", "# Train\n", "trainer.fit(model, dm)" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "slMq_0XBglzC" }, "source": [ "Train on a single TPU core with `tpu_cores=1`." ] }, { "cell_type": "code", "metadata": { "id": "31N5Scf2RZ61" }, "source": [ "# Init DataModule\n", "dm = MNISTDataModule()\n", "# Init model from datamodule's attributes\n", "model = LitModel(*dm.size(), dm.num_classes)\n", "# Init trainer\n", "trainer = pl.Trainer(max_epochs=3, progress_bar_refresh_rate=20, tpu_cores=1)\n", "# Train\n", "trainer.fit(model, dm)" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "_v8xcU5Sf_Cv" }, "source": [ "Train on 8 TPU cores with `tpu_cores=8`. You might have to restart the notebook to run on 8 TPU cores after training on a single TPU core." ] }, { "cell_type": "code", "metadata": { "id": "EFEw7YpLf-gE" }, "source": [ "# Init DataModule\n", "dm = MNISTDataModule()\n", "# Init model from datamodule's attributes\n", "model = LitModel(*dm.size(), dm.num_classes)\n", "# Init trainer\n", "trainer = pl.Trainer(max_epochs=3, progress_bar_refresh_rate=20, tpu_cores=8)\n", "# Train\n", "trainer.fit(model, dm)" ], "execution_count": null, "outputs": [] },
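{ "cell_type": "markdown", "metadata": {}, "source": [ "A note on batch size when using all 8 cores (a rough sketch of the arithmetic, not something this notebook tunes): each TPU core runs its own process and receives its own shard of the data, so the `batch_size=32` set in the DataLoaders above is per core, and the effective global batch size per optimizer step is larger:\n", "\n", "```python\n", "per_core_batch_size = 32  # batch_size used in the DataLoaders above\n", "tpu_cores = 8\n", "effective_batch_size = per_core_batch_size * tpu_cores  # 256 samples per optimizer step\n", "```" ] },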
{ "cell_type": "markdown", "metadata": { "id": "m2mhgEgpRZ1g" }, "source": [ "## Congratulations - Time to Join the Community!\n", "
\n", "\n", "Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!\n", "\n", "### Star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) on GitHub\n", "The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.\n", "\n", "* Please, star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning)\n", "\n", "### Join our [Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)!\n", "The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in `#general` channel\n", "\n", "### Interested by SOTA AI models ! Check out [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n", "Bolts has a collection of state-of-the-art models, all implemented in [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) and can be easily integrated within your own projects.\n", "\n", "* Please, star [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n", "\n", "### Contributions !\n", "The best way to contribute to our community is to become a code contributor! At any time you can go to [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) or [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts) GitHub Issues page and filter for \"good first issue\". \n", "\n", "* [Lightning good first issue](https://github.com/PyTorchLightning/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", "* [Bolt good first issue](https://github.com/PyTorchLightning/pytorch-lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", "* You can also contribute your own notebooks with useful examples !\n", "\n", "### Great thanks from the entire Pytorch Lightning Team for your interest !\n", "\n", "" ] } ] }