{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "06-mnist-tpu-training.ipynb",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "TPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "qXO1QLkbRXl0"
},
"source": [
"# TPU training with PyTorch Lightning ⚡\n",
"\n",
"In this notebook, we'll train a model on TPUs. Changing one line of code is all you need to do that.\n",
"\n",
"The most up-to-date documentation on TPU training can be found [here](https://pytorch-lightning.readthedocs.io/en/latest/tpu.html).\n",
"\n",
"---\n",
"\n",
" - Give us a ⭐ [on GitHub](https://www.github.com/PytorchLightning/pytorch-lightning/)\n",
" - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n",
" - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)\n",
" - Ask a question on our [official forum](https://forums.pytorchlightning.ai/)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UmKX0Qa1RaLL"
},
"source": [
"### Setup\n",
"\n",
"Lightning is easy to install. Simply run `pip install pytorch-lightning`."
]
},
{
"cell_type": "code",
"metadata": {
"id": "vAWOr0FZRaIj"
},
"source": [
"! pip install pytorch-lightning -qU"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "zepCr1upT4Z3"
},
"source": [
"### Install Colab TPU compatible PyTorch/TPU wheels and dependencies"
]
},
{
"cell_type": "code",
"metadata": {
"id": "AYGWh10lRaF1"
},
"source": [
"! pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.7-cp36-cp36m-linux_x86_64.whl"
],
"execution_count": null,
"outputs": []
},
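{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before moving on, it can help to verify that the XLA wheel installed correctly and a TPU is visible. This is an optional sanity check (a minimal sketch, assuming the install cell above succeeded):"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Optional sanity check: confirm the XLA runtime can see the TPU.\n",
"device = None\n",
"try:\n",
"    import torch_xla.core.xla_model as xm  # provided by the torch_xla wheel installed above\n",
"    device = xm.xla_device()  # the default TPU device on a Colab TPU runtime\n",
"    print('TPU device:', device)\n",
"except ImportError:\n",
"    print('torch_xla is not installed yet; run the install cell above first.')"
],
"execution_count": null,
"outputs": []
},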
{
"cell_type": "code",
"metadata": {
"id": "SNHa7DpmRZ-C"
},
"source": [
"import torch\n",
"from torch import nn\n",
"import torch.nn.functional as F\n",
"from torch.utils.data import random_split, DataLoader\n",
"\n",
"# Note - you must have torchvision installed for this example\n",
"from torchvision.datasets import MNIST\n",
"from torchvision import transforms\n",
"\n",
"import pytorch_lightning as pl\n",
"from pytorch_lightning.metrics.functional import accuracy"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "rjo1dqzGUxt6"
},
"source": [
"### Defining the `MNISTDataModule`\n",
"\n",
"Below we define `MNISTDataModule`. You can learn more about datamodules in the [docs](https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html) and the [datamodule notebook](https://github.com/PyTorchLightning/pytorch-lightning/blob/master/notebooks/02-datamodules.ipynb)."
]
},
{
"cell_type": "code",
"metadata": {
"id": "pkbrm3YgUxlE"
},
"source": [
"class MNISTDataModule(pl.LightningDataModule):\n",
"\n",
" def __init__(self, data_dir: str = './'):\n",
" super().__init__()\n",
" self.data_dir = data_dir\n",
" self.transform = transforms.Compose([\n",
" transforms.ToTensor(),\n",
" transforms.Normalize((0.1307,), (0.3081,))\n",
" ])\n",
"\n",
" # self.dims is returned when you call dm.size()\n",
" # Setting default dims here because we know them.\n",
" # Could optionally be assigned dynamically in dm.setup()\n",
" self.dims = (1, 28, 28)\n",
" self.num_classes = 10\n",
"\n",
" def prepare_data(self):\n",
" # download\n",
" MNIST(self.data_dir, train=True, download=True)\n",
" MNIST(self.data_dir, train=False, download=True)\n",
"\n",
" def setup(self, stage=None):\n",
"\n",
" # Assign train/val datasets for use in dataloaders\n",
" if stage == 'fit' or stage is None:\n",
" mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n",
" self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n",
"\n",
" # Assign test dataset for use in dataloader(s)\n",
" if stage == 'test' or stage is None:\n",
" self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n",
"\n",
" def train_dataloader(self):\n",
" return DataLoader(self.mnist_train, batch_size=32)\n",
"\n",
" def val_dataloader(self):\n",
" return DataLoader(self.mnist_val, batch_size=32)\n",
"\n",
" def test_dataloader(self):\n",
" return DataLoader(self.mnist_test, batch_size=32)"
],
"execution_count": null,
"outputs": []
},
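{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `transforms.Normalize((0.1307,), (0.3081,))` call above standardizes each pixel with MNIST's per-channel mean and standard deviation. A quick illustration of the arithmetic on random data (a sketch, not real MNIST):"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import torch\n",
"\n",
"x = torch.rand(1, 28, 28)      # a fake single-channel 'image' with values in [0, 1]\n",
"mean, std = 0.1307, 0.3081     # the MNIST statistics used in the transform above\n",
"x_norm = (x - mean) / std      # what transforms.Normalize applies per channel\n",
"x_back = x_norm * std + mean   # inverting the transform recovers the input\n",
"print(torch.allclose(x_back, x, atol=1e-6))"
],
"execution_count": null,
"outputs": []
},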
{
"cell_type": "markdown",
"metadata": {
"id": "nr9AqDWxUxdK"
},
"source": [
"### Defining the `LitModel`\n",
"\n",
"Below, we define the model `LitModel`."
]
},
{
"cell_type": "code",
"metadata": {
"id": "YKt0KZkOUxVY"
},
"source": [
"class LitModel(pl.LightningModule):\n",
" \n",
" def __init__(self, channels, width, height, num_classes, hidden_size=64, learning_rate=2e-4):\n",
"\n",
" super().__init__()\n",
"\n",
" self.save_hyperparameters()\n",
"\n",
" self.model = nn.Sequential(\n",
" nn.Flatten(),\n",
" nn.Linear(channels * width * height, hidden_size),\n",
" nn.ReLU(),\n",
" nn.Dropout(0.1),\n",
" nn.Linear(hidden_size, hidden_size),\n",
" nn.ReLU(),\n",
" nn.Dropout(0.1),\n",
" nn.Linear(hidden_size, num_classes)\n",
" )\n",
"\n",
" def forward(self, x):\n",
" x = self.model(x)\n",
" return F.log_softmax(x, dim=1)\n",
"\n",
" def training_step(self, batch, batch_idx):\n",
" x, y = batch\n",
" logits = self(x)\n",
" loss = F.nll_loss(logits, y)\n",
" self.log('train_loss', loss, prog_bar=False)\n",
" return loss\n",
"\n",
" def validation_step(self, batch, batch_idx):\n",
" x, y = batch\n",
" logits = self(x)\n",
" loss = F.nll_loss(logits, y)\n",
" preds = torch.argmax(logits, dim=1)\n",
" acc = accuracy(preds, y)\n",
" self.log('val_loss', loss, prog_bar=True)\n",
" self.log('val_acc', acc, prog_bar=True)\n",
" return loss\n",
"\n",
" def configure_optimizers(self):\n",
" optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)\n",
" return optimizer"
],
"execution_count": null,
"outputs": []
},
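{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that `forward` returns log-probabilities via `log_softmax`, which is why `training_step` and `validation_step` pair it with `F.nll_loss`. That combination is numerically equivalent to calling `F.cross_entropy` on the raw logits, as a quick check on random tensors shows:"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"logits = torch.randn(8, 10)                 # a fake batch of raw scores for 10 classes\n",
"targets = torch.randint(0, 10, (8,))        # fake integer class labels\n",
"loss_nll = F.nll_loss(F.log_softmax(logits, dim=1), targets)\n",
"loss_ce = F.cross_entropy(logits, targets)  # fuses log_softmax + nll_loss\n",
"print(torch.allclose(loss_nll, loss_ce))"
],
"execution_count": null,
"outputs": []
},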
{
"cell_type": "markdown",
"metadata": {
"id": "Uxl88z06cHyV"
},
"source": [
"### TPU Training\n",
"\n",
"Lightning supports training on a single TPU core or 8 TPU cores.\n",
"\n",
"The Trainer parameter `tpu_cores` defines how many TPU cores to train on (1 or 8).\n",
"\n",
"To train on one specific core, pass its ID [1-8] in a list. For example, setting `tpu_cores=[5]` will train on TPU core ID 5."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UZ647Xg2gYng"
},
"source": [
"Train on TPU core ID 5 with `tpu_cores=[5]`."
]
},
{
"cell_type": "code",
"metadata": {
"id": "bzhJ8g_vUxN2"
},
"source": [
"# Init DataModule\n",
"dm = MNISTDataModule()\n",
"# Init model from datamodule's attributes\n",
"model = LitModel(*dm.size(), dm.num_classes)\n",
"# Init trainer\n",
"trainer = pl.Trainer(max_epochs=3, progress_bar_refresh_rate=20, tpu_cores=[5])\n",
"# Train\n",
"trainer.fit(model, dm)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "slMq_0XBglzC"
},
"source": [
"Train on a single TPU core with `tpu_cores=1`."
]
},
{
"cell_type": "code",
"metadata": {
"id": "31N5Scf2RZ61"
},
"source": [
"# Init DataModule\n",
"dm = MNISTDataModule()\n",
"# Init model from datamodule's attributes\n",
"model = LitModel(*dm.size(), dm.num_classes)\n",
"# Init trainer\n",
"trainer = pl.Trainer(max_epochs=3, progress_bar_refresh_rate=20, tpu_cores=1)\n",
"# Train\n",
"trainer.fit(model, dm)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "_v8xcU5Sf_Cv"
},
"source": [
"Train on 8 TPU cores with `tpu_cores=8`. You might have to restart the notebook runtime to train on 8 TPU cores after training on a single TPU core."
]
},
{
"cell_type": "code",
"metadata": {
"id": "EFEw7YpLf-gE"
},
"source": [
"# Init DataModule\n",
"dm = MNISTDataModule()\n",
"# Init model from datamodule's attributes\n",
"model = LitModel(*dm.size(), dm.num_classes)\n",
"# Init trainer\n",
"trainer = pl.Trainer(max_epochs=3, progress_bar_refresh_rate=20, tpu_cores=8)\n",
"# Train\n",
"trainer.fit(model, dm)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "m2mhgEgpRZ1g"
},
"source": [
"\n",
"# Congratulations - Time to Join the Community!\n",
"\n",
"Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!\n",
"\n",
"### Star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) on GitHub\n",
"The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.\n",
"\n",
"* Please, star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning)\n",
"\n",
"### Join our [Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)!\n",
"The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in the `#general` channel.\n",
"\n",
"### Interested in SOTA AI models? Check out [Bolts](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n",
"Bolts has a collection of state-of-the-art models, all implemented in [Lightning](https://github.com/PyTorchLightning/pytorch-lightning), that can be easily integrated into your own projects.\n",
"\n",
"* Please, star [Bolts](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n",
"\n",
"### Contributions!\n",
"The best way to contribute to our community is to become a code contributor! At any time you can go to the [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) or [Bolts](https://github.com/PyTorchLightning/pytorch-lightning-bolts) GitHub Issues page and filter for \"good first issue\".\n",
"\n",
"* [Lightning good first issue](https://github.com/PyTorchLightning/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
"* [Bolts good first issue](https://github.com/PyTorchLightning/pytorch-lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
"* You can also contribute your own notebooks with useful examples!\n",
"\n",
"### Great thanks from the entire PyTorch Lightning team for your interest!\n",
"\n",
""
]
}
]
}