"cells": [
"# TPU training with PyTorch Lightning ⚡\n",
"In this notebook, we'll train a model on TPUs. Changing one line of code is all you need to do that.\n",
"The most up-to-date documentation related to TPU training can be found [here](\n",
" - Give us a ⭐ [on GitHub](\n",
" - Check out [the documentation](\n",
" - Join us [on Slack](\n",
" - Ask a question on our [GitHub Discussions]("
"### Setup\n",
"Lightning is easy to install. Simply run `pip install pytorch-lightning`."
"! pip install pytorch-lightning -qU"
### Install Colab TPU-compatible PyTorch/TPU wheels and dependencies
"! pip install cloud-tpu-client==0.10"
"import torch\n",
"from torch import nn\n",
"import torch.nn.functional as F\n",
"from torch.utils.data import random_split, DataLoader\n",
"# Note - you must have torchvision installed for this example\n",
"from torchvision.datasets import MNIST\n",
"from torchvision import transforms\n",
"import pytorch_lightning as pl\n",
"from pytorch_lightning.metrics.functional import accuracy"
"### Defining the `MNISTDataModule`\n",
"Below we define `MNISTDataModule`. You can learn more about datamodules in the [docs]( and the [datamodule notebook]("
"class MNISTDataModule(pl.LightningDataModule):\n",
"\n",
"    def __init__(self, data_dir: str = './'):\n",
"        super().__init__()\n",
"        self.data_dir = data_dir\n",
"        self.transform = transforms.Compose([\n",
"            transforms.ToTensor(),\n",
"            transforms.Normalize((0.1307,), (0.3081,))\n",
"        ])\n",
"\n",
"        # self.dims is returned when you call dm.size()\n",
"        # Setting default dims here because we know them.\n",
"        # Could optionally be assigned dynamically in dm.setup()\n",
"        self.dims = (1, 28, 28)\n",
"        self.num_classes = 10\n",
"\n",
"    def prepare_data(self):\n",
"        # download\n",
"        MNIST(self.data_dir, train=True, download=True)\n",
"        MNIST(self.data_dir, train=False, download=True)\n",
"\n",
"    def setup(self, stage=None):\n",
"        # Assign train/val datasets for use in dataloaders\n",
"        if stage == 'fit' or stage is None:\n",
"            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n",
"            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n",
"\n",
"        # Assign test dataset for use in dataloader(s)\n",
"        if stage == 'test' or stage is None:\n",
"            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n",
"\n",
"    def train_dataloader(self):\n",
"        return DataLoader(self.mnist_train, batch_size=32)\n",
"\n",
"    def val_dataloader(self):\n",
"        return DataLoader(self.mnist_val, batch_size=32)\n",
"\n",
"    def test_dataloader(self):\n",
"        return DataLoader(self.mnist_test, batch_size=32)"
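The two numeric steps in `MNISTDataModule` are easy to check by hand. `transforms.Normalize((0.1307,), (0.3081,))` maps each pixel `x` (already scaled to [0, 1] by `ToTensor`) to `(x - 0.1307) / 0.3081`, and `random_split(mnist_full, [55000, 5000])` partitions the 60,000 MNIST training images into two disjoint subsets whose sizes must add up to the dataset length. The plain-Python sketch below mirrors both ideas under stated assumptions: `normalize` and `split_indices` are hypothetical helpers of ours, not torchvision or PyTorch APIs, and the shuffle uses Python's `random` module rather than PyTorch's generator, so the actual indices will differ.

```python
import random
from typing import List

# Per-channel statistics passed to transforms.Normalize in MNISTDataModule
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081


def normalize(pixel: float, mean: float = MNIST_MEAN, std: float = MNIST_STD) -> float:
    """Hypothetical helper: (x - mean) / std for one channel, x already in [0, 1]."""
    return (pixel - mean) / std


def split_indices(n: int, lengths: List[int], seed: int = 0) -> List[List[int]]:
    """Hypothetical helper: shuffle range(n), then cut it into consecutive,
    non-overlapping chunks of the requested lengths (the idea behind random_split)."""
    if sum(lengths) != n:
        raise ValueError("lengths must sum to the dataset size")
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    chunks, start = [], 0
    for length in lengths:
        chunks.append(indices[start:start + length])
        start += length
    return chunks


# A black pixel (0.0) becomes roughly -0.4242, a white pixel (1.0) roughly 2.8215
print(round(normalize(0.0), 4), round(normalize(1.0), 4))

# 55,000 training + 5,000 validation indices out of the 60,000 MNIST training images
train_idx, val_idx = split_indices(60000, [55000, 5000])
print(len(train_idx), len(val_idx))
```

Normalizing with the dataset's mean and standard deviation centers the inputs near zero with roughly unit spread, which tends to make optimization better behaved.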
"### Defining the `LitModel`\n",
"Below, we define the model `LitModel`."
"class LitModel(pl.LightningModule):\n",
"\n",
"    def __init__(self, channels, width, height, num_classes, hidden_size=64, learning_rate=2e-4):\n",
"        super().__init__()\n",
"        self.save_hyperparameters()\n",
"        self.model = nn.Sequential(\n",
"            nn.Flatten(),\n",
"            nn.Linear(channels * width * height, hidden_size),\n",
"            nn.ReLU(),\n",
"            nn.Dropout(0.1),\n",
"            nn.Linear(hidden_size, hidden_size),\n",
"            nn.ReLU(),\n",
"            nn.Dropout(0.1),\n",
"            nn.Linear(hidden_size, num_classes)\n",
"        )\n",
"\n",
"    def forward(self, x):\n",
"        x = self.model(x)\n",
"        return F.log_softmax(x, dim=1)\n",
"\n",
"    def training_step(self, batch, batch_idx):\n",
"        x, y = batch\n",
"        logits = self(x)\n",
"        loss = F.nll_loss(logits, y)\n",
"        self.log('train_loss', loss, prog_bar=False)\n",
"        return loss\n",
"\n",
"    def validation_step(self, batch, batch_idx):\n",
"        x, y = batch\n",
"        logits = self(x)\n",
"        loss = F.nll_loss(logits, y)\n",
"        preds = torch.argmax(logits, dim=1)\n",
"        acc = accuracy(preds, y)\n",
"        self.log('val_loss', loss, prog_bar=True)\n",
"        self.log('val_acc', acc, prog_bar=True)\n",
"        return loss\n",
"\n",
"    def configure_optimizers(self):\n",
"        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)\n",
"        return optimizer"
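`LitModel.forward` returns `F.log_softmax` output (so the tensor named `logits` in the steps actually holds log-probabilities), and both steps feed it to `F.nll_loss`; together the two compute the usual cross-entropy loss. The plain-Python sketch below shows what these functions and the `argmax`-based accuracy compute for a toy batch. The `toy_*` names are hypothetical helpers of ours (named so they do not shadow the notebook's imported `accuracy`), not PyTorch's implementation.

```python
import math
from typing import List, Sequence


def toy_log_softmax(row: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one sample's scores (like dim=1 per row)."""
    m = max(row)
    log_sum = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - log_sum for v in row]


def toy_nll_loss(log_probs: Sequence[Sequence[float]], targets: Sequence[int]) -> float:
    """Mean negative log-probability of each target class (the default 'mean' reduction)."""
    return -sum(row[t] for row, t in zip(log_probs, targets)) / len(targets)


def toy_accuracy(log_probs: Sequence[Sequence[float]], targets: Sequence[int]) -> float:
    """Fraction of samples whose argmax class equals the target, as in validation_step."""
    preds = [max(range(len(row)), key=row.__getitem__) for row in log_probs]
    return sum(p == t for p, t in zip(preds, targets)) / len(targets)


# Two samples, three classes: raw scores before log-softmax
batch_scores = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]
targets = [0, 1]
log_probs = [toy_log_softmax(row) for row in batch_scores]

print(toy_nll_loss(log_probs, targets))
print(toy_accuracy(log_probs, targets))  # sample 0 predicts 0 (right), sample 1 predicts 2 (wrong): 0.5
```

Applying log-softmax in `forward` and `nll_loss` in the steps is equivalent to returning raw scores and using `F.cross_entropy`; keeping log-softmax in `forward` means the model's outputs are directly interpretable as log-probabilities.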
"### TPU Training\n",
"Lightning supports training on a single TPU core or on 8 TPU cores.\n",
"The Trainer parameter `tpu_cores` defines either how many TPU cores to train on (1 or 8) or, when given as a one-element list such as [1], which single TPU core to train on.\n",
"To train on one specific core, pass its TPU core ID (1-8) in a list. Setting `tpu_cores=[5]` will train on TPU core ID 5."
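The two forms of `tpu_cores` described above mean different things: an integer is a core *count*, while a one-element list is a core *ID*. The sketch below encodes that reading with a hypothetical `describe_tpu_cores` helper. It is only an illustration of the rule as stated in this notebook, not Lightning's own argument parsing, which may accept or reject other values; check the Trainer documentation for your installed version.

```python
from typing import List, Union


def describe_tpu_cores(tpu_cores: Union[int, List[int]]) -> str:
    """Hypothetical helper: interpret tpu_cores the way this notebook describes it."""
    if isinstance(tpu_cores, int):
        # An integer is the number of cores to use; the notebook lists 1 and 8.
        if tpu_cores not in (1, 8):
            raise ValueError("tpu_cores as an int must be 1 or 8")
        return f"train on {tpu_cores} TPU core(s)"
    if isinstance(tpu_cores, list) and len(tpu_cores) == 1:
        # A one-element list names a single core by its ID, from 1 to 8.
        core_id = tpu_cores[0]
        if not 1 <= core_id <= 8:
            raise ValueError("TPU core ID must be between 1 and 8")
        return f"train on TPU core ID {core_id}"
    raise ValueError("expected 1, 8, or a one-element list such as [5]")


print(describe_tpu_cores(8))    # train on 8 TPU core(s)
print(describe_tpu_cores(1))    # train on 1 TPU core(s)
print(describe_tpu_cores([5]))  # train on TPU core ID 5
```

Note that `tpu_cores=1` and `tpu_cores=[1]` both use a single core, but only the list form pins training to a particular core ID.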
Train on TPU core ID 5 with `tpu_cores=[5]`.
"# Init DataModule\n",
"dm = MNISTDataModule()\n",
"# Init model from datamodule's attributes\n",
"model = LitModel(*dm.size(), dm.num_classes)\n",
"# Init trainer\n",
"trainer = pl.Trainer(max_epochs=3, progress_bar_refresh_rate=20, tpu_cores=[5])\n",
"# Train\n",
"trainer.fit(model, dm)"
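In `LitModel(*dm.size(), dm.num_classes)`, the `*` unpacks `dm.size()`, which returns `self.dims = (1, 28, 28)`, into the `channels`, `width`, and `height` arguments, so the first `nn.Linear` receives 1 × 28 × 28 = 784 input features. (The tuple is ordered channels, height, width while the constructor names width first; for square 28 × 28 images this makes no difference.) The arithmetic sketch below counts the model's trainable parameters with the default `hidden_size=64`; `mlp_param_count` is a hypothetical helper of ours, and only the three `nn.Linear` layers are counted because `Flatten`, `ReLU`, and `Dropout` have no parameters.

```python
# Values set in MNISTDataModule.__init__
dims = (1, 28, 28)
num_classes = 10


def mlp_param_count(channels: int, width: int, height: int, num_classes: int,
                    hidden_size: int = 64) -> int:
    """Hypothetical helper: weights + biases of LitModel's three nn.Linear layers."""
    in_features = channels * width * height
    layers = [(in_features, hidden_size), (hidden_size, hidden_size), (hidden_size, num_classes)]
    # nn.Linear(i, o) stores an o x i weight matrix plus o biases
    return sum(i * o + o for i, o in layers)


# *dims unpacks to channels=1, width=28, height=28, as in LitModel(*dm.size(), dm.num_classes)
print(mlp_param_count(*dims, num_classes))  # 50240 + 4160 + 650 = 55050
```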
Train on single TPU core with `tpu_cores=1`.
"# Init DataModule\n",
"dm = MNISTDataModule()\n",
"# Init model from datamodule's attributes\n",
"model = LitModel(*dm.size(), dm.num_classes)\n",
"# Init trainer\n",
"trainer = pl.Trainer(max_epochs=3, progress_bar_refresh_rate=20, tpu_cores=1)\n",
"# Train\n",
"trainer.fit(model, dm)"
Train on 8 TPU cores with `tpu_cores=8`. You might have to restart the notebook to run it on 8 TPU cores after training on a single TPU core.
"# Init DataModule\n",
"dm = MNISTDataModule()\n",
"# Init model from datamodule's attributes\n",
"model = LitModel(*dm.size(), dm.num_classes)\n",
"# Init trainer\n",
"trainer = pl.Trainer(max_epochs=3, progress_bar_refresh_rate=20, tpu_cores=8)\n",
"# Train\n",
"trainer.fit(model, dm)"
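With `tpu_cores=8`, each of the 8 processes runs the same `train_dataloader` with `batch_size=32`. Assuming the 55,000 training samples are sharded evenly across the cores by a distributed sampler (which, to our understanding, Lightning sets up by default for multi-core TPU training), the arithmetic below sketches how steps per epoch and the global batch size change compared with a single core. `steps_per_epoch` is a hypothetical helper, and the round-up padding is an assumption modeled on `torch.utils.data.DistributedSampler` with `drop_last=False`.

```python
import math

TRAIN_SAMPLES = 55000  # size of mnist_train after random_split
BATCH_SIZE = 32        # per-process batch size in train_dataloader


def steps_per_epoch(num_samples: int, batch_size: int, num_cores: int) -> int:
    """Hypothetical helper: batches each core sees per epoch when samples are
    split evenly across cores (rounded up) and the last partial batch is kept."""
    per_core = math.ceil(num_samples / num_cores)
    return math.ceil(per_core / batch_size)


for cores in (1, 8):
    print(f"{cores} core(s): {steps_per_epoch(TRAIN_SAMPLES, BATCH_SIZE, cores)} steps/epoch, "
          f"global batch {BATCH_SIZE * cores}")
# 1 core(s): 1719 steps/epoch, global batch 32
# 8 core(s): 215 steps/epoch, global batch 256
```

Under these assumptions, each optimizer step on 8 cores sees 8 times as many samples, so one epoch takes far fewer steps; if you compare runs across core counts, you may want to revisit the learning rate accordingly.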
