{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "06-mnist-tpu-training.ipynb",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "TPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "WsWdLFMVKqbi"
},
"source": [
"<a href=\"https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/06-mnist-tpu-training.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "qXO1QLkbRXl0"
|
|
},
|
|
"source": [
|
|
"# TPU training with PyTorch Lightning ⚡\n",
|
|
"\n",
|
|
"In this notebook, we'll train a model on TPUs. Changing one line of code is all you need to that.\n",
|
|
"\n",
|
|
"The most up to documentation related to TPU training can be found [here](https://pytorch-lightning.readthedocs.io/en/latest/advanced/tpu.html).\n",
|
|
"\n",
|
|
"---\n",
|
|
"\n",
|
|
" - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n",
|
|
" - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n",
|
|
" - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)\n",
|
|
" - Ask a question on our [GitHub Discussions](https://github.com/PyTorchLightning/pytorch-lightning/discussions/)"
|
|
]
|
|
},
{
"cell_type": "markdown",
"metadata": {
"id": "UmKX0Qa1RaLL"
},
"source": [
"### Setup\n",
"\n",
"Lightning is easy to install. Simply run `pip install pytorch-lightning`."
]
},
{
"cell_type": "code",
"metadata": {
"id": "vAWOr0FZRaIj"
},
"source": [
"! pip install pytorch-lightning -qU"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "zepCr1upT4Z3"
},
"source": [
"### Install Colab TPU-compatible PyTorch/XLA wheels and dependencies"
]
},
{
"cell_type": "code",
"metadata": {
"id": "AYGWh10lRaF1"
},
"source": [
"! pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.7-cp36-cp36m-linux_x86_64.whl"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "SNHa7DpmRZ-C"
},
"source": [
"import torch\n",
"from torch import nn\n",
"import torch.nn.functional as F\n",
"from torch.utils.data import random_split, DataLoader\n",
"\n",
"# Note - you must have torchvision installed for this example\n",
"from torchvision.datasets import MNIST\n",
"from torchvision import transforms\n",
"\n",
"import pytorch_lightning as pl\n",
"from pytorch_lightning.metrics.functional import accuracy"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "rjo1dqzGUxt6"
},
"source": [
"### Defining The `MNISTDataModule`\n",
"\n",
"Below we define `MNISTDataModule`. You can learn more about datamodules in the [docs](https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html) and in the [datamodule notebook](https://github.com/PyTorchLightning/pytorch-lightning/blob/master/notebooks/02-datamodules.ipynb)."
]
},
{
"cell_type": "code",
"metadata": {
"id": "pkbrm3YgUxlE"
},
"source": [
"class MNISTDataModule(pl.LightningDataModule):\n",
"\n",
"    def __init__(self, data_dir: str = './'):\n",
"        super().__init__()\n",
"        self.data_dir = data_dir\n",
"        self.transform = transforms.Compose([\n",
"            transforms.ToTensor(),\n",
"            transforms.Normalize((0.1307,), (0.3081,))\n",
"        ])\n",
"\n",
"        # self.dims is returned when you call dm.size()\n",
"        # Setting default dims here because we know them.\n",
"        # Could optionally be assigned dynamically in dm.setup()\n",
"        self.dims = (1, 28, 28)\n",
"        self.num_classes = 10\n",
"\n",
"    def prepare_data(self):\n",
"        # download\n",
"        MNIST(self.data_dir, train=True, download=True)\n",
"        MNIST(self.data_dir, train=False, download=True)\n",
"\n",
"    def setup(self, stage=None):\n",
"\n",
"        # Assign train/val datasets for use in dataloaders\n",
"        if stage == 'fit' or stage is None:\n",
"            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n",
"            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n",
"\n",
"        # Assign test dataset for use in dataloader(s)\n",
"        if stage == 'test' or stage is None:\n",
"            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n",
"\n",
"    def train_dataloader(self):\n",
"        return DataLoader(self.mnist_train, batch_size=32)\n",
"\n",
"    def val_dataloader(self):\n",
"        return DataLoader(self.mnist_val, batch_size=32)\n",
"\n",
"    def test_dataloader(self):\n",
"        return DataLoader(self.mnist_test, batch_size=32)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "nr9AqDWxUxdK"
},
"source": [
"### Defining the `LitModel`\n",
"\n",
"Below, we define the model `LitModel`."
]
},
{
"cell_type": "code",
"metadata": {
"id": "YKt0KZkOUxVY"
},
"source": [
"class LitModel(pl.LightningModule):\n",
"\n",
"    def __init__(self, channels, width, height, num_classes, hidden_size=64, learning_rate=2e-4):\n",
"\n",
"        super().__init__()\n",
"\n",
"        self.save_hyperparameters()\n",
"\n",
"        self.model = nn.Sequential(\n",
"            nn.Flatten(),\n",
"            nn.Linear(channels * width * height, hidden_size),\n",
"            nn.ReLU(),\n",
"            nn.Dropout(0.1),\n",
"            nn.Linear(hidden_size, hidden_size),\n",
"            nn.ReLU(),\n",
"            nn.Dropout(0.1),\n",
"            nn.Linear(hidden_size, num_classes)\n",
"        )\n",
"\n",
"    def forward(self, x):\n",
"        x = self.model(x)\n",
"        return F.log_softmax(x, dim=1)\n",
"\n",
"    def training_step(self, batch, batch_idx):\n",
"        x, y = batch\n",
"        logits = self(x)\n",
"        loss = F.nll_loss(logits, y)\n",
"        self.log('train_loss', loss, prog_bar=False)\n",
"        return loss\n",
"\n",
"    def validation_step(self, batch, batch_idx):\n",
"        x, y = batch\n",
"        logits = self(x)\n",
"        loss = F.nll_loss(logits, y)\n",
"        preds = torch.argmax(logits, dim=1)\n",
"        acc = accuracy(preds, y)\n",
"        self.log('val_loss', loss, prog_bar=True)\n",
"        self.log('val_acc', acc, prog_bar=True)\n",
"        return loss\n",
"\n",
"    def configure_optimizers(self):\n",
"        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)\n",
"        return optimizer"
],
"execution_count": null,
"outputs": []
},
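{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an optional sanity check, the cell below counts the trainable parameters of `LitModel` with its default `hidden_size=64`. It is a sketch that assumes MNIST's input shape of `(1, 28, 28)` and 10 classes. `sketch_model` is a throwaway name used only for this check. Each `nn.Linear(in, out)` layer contributes `in * out` weights plus `out` biases. The `Flatten`, `ReLU` and `Dropout` layers have no parameters, so the total should be `(784 * 64 + 64) + (64 * 64 + 64) + (64 * 10 + 10) = 55050`."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Optional sanity check (sketch): count the trainable parameters of the MLP defined above.\n",
"# Assumes MNIST inputs of shape (1, 28, 28) and 10 classes, as set in MNISTDataModule.\n",
"sketch_model = LitModel(1, 28, 28, 10)\n",
"n_params = sum(p.numel() for p in sketch_model.parameters() if p.requires_grad)\n",
"\n",
"# Each nn.Linear(in, out) has in * out weights plus out biases; Flatten/ReLU/Dropout have none.\n",
"expected = (28 * 28 * 64 + 64) + (64 * 64 + 64) + (64 * 10 + 10)\n",
"print(n_params, expected)\n",
"assert n_params == expected"
],
"execution_count": null,
"outputs": []
},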
{
"cell_type": "markdown",
"metadata": {
"id": "Uxl88z06cHyV"
},
"source": [
"### TPU Training\n",
"\n",
"Lightning supports training on a single TPU core or on 8 TPU cores.\n",
"\n",
"The Trainer parameter `tpu_cores` defines either how many TPU cores to train on (1 or 8) or which single TPU core to train on (a one-element list such as `[1]`).\n",
"\n",
"To train on a specific TPU core, pass its core ID (1-8) in a list. Setting `tpu_cores=[5]` will train on TPU core ID 5."
]
},
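{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell below restates, in plain Python, the three forms of `tpu_cores` described above. `describe_tpu_cores` is a hypothetical helper written only for this notebook. It is not part of the Lightning API, and Lightning's own argument parsing may accept or reject values differently."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Hypothetical helper (NOT part of Lightning) that mirrors the tpu_cores rules described above.\n",
"def describe_tpu_cores(tpu_cores):\n",
"    \"\"\"Return a human-readable description of a tpu_cores value.\"\"\"\n",
"    # An integer selects how many cores to use: 1 or 8 (bool is excluded on purpose).\n",
"    if isinstance(tpu_cores, int) and not isinstance(tpu_cores, bool):\n",
"        if tpu_cores == 1:\n",
"            return 'train on a single TPU core'\n",
"        if tpu_cores == 8:\n",
"            return 'train on all 8 TPU cores'\n",
"        raise ValueError('an integer tpu_cores must be 1 or 8')\n",
"    # A one-element list selects one specific core by its ID (1-8).\n",
"    if isinstance(tpu_cores, list) and len(tpu_cores) == 1 and tpu_cores[0] in range(1, 9):\n",
"        return f'train on TPU core ID {tpu_cores[0]}'\n",
"    raise ValueError('tpu_cores must be 1, 8, or a one-element list with a core ID from 1 to 8')\n",
"\n",
"for value in (1, 8, [5]):\n",
"    print(value, '->', describe_tpu_cores(value))"
],
"execution_count": null,
"outputs": []
},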
{
"cell_type": "markdown",
"metadata": {
"id": "UZ647Xg2gYng"
},
"source": [
"Train on TPU core ID 5 with `tpu_cores=[5]`."
]
},
{
"cell_type": "code",
"metadata": {
"id": "bzhJ8g_vUxN2"
},
"source": [
"# Init DataModule\n",
"dm = MNISTDataModule()\n",
"# Init model from datamodule's attributes\n",
"model = LitModel(*dm.size(), dm.num_classes)\n",
"# Init trainer\n",
"trainer = pl.Trainer(max_epochs=3, progress_bar_refresh_rate=20, tpu_cores=[5])\n",
"# Train\n",
"trainer.fit(model, dm)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "slMq_0XBglzC"
},
"source": [
"Train on a single TPU core with `tpu_cores=1`."
]
},
{
"cell_type": "code",
"metadata": {
"id": "31N5Scf2RZ61"
},
"source": [
"# Init DataModule\n",
"dm = MNISTDataModule()\n",
"# Init model from datamodule's attributes\n",
"model = LitModel(*dm.size(), dm.num_classes)\n",
"# Init trainer\n",
"trainer = pl.Trainer(max_epochs=3, progress_bar_refresh_rate=20, tpu_cores=1)\n",
"# Train\n",
"trainer.fit(model, dm)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "_v8xcU5Sf_Cv"
},
"source": [
"Train on 8 TPU cores with `tpu_cores=8`. You might have to restart the notebook runtime to train on 8 TPU cores after training on a single TPU core."
]
},
{
"cell_type": "code",
"metadata": {
"id": "EFEw7YpLf-gE"
},
"source": [
"# Init DataModule\n",
"dm = MNISTDataModule()\n",
"# Init model from datamodule's attributes\n",
"model = LitModel(*dm.size(), dm.num_classes)\n",
"# Init trainer\n",
"trainer = pl.Trainer(max_epochs=3, progress_bar_refresh_rate=20, tpu_cores=8)\n",
"# Train\n",
"trainer.fit(model, dm)"
],
"execution_count": null,
"outputs": []
},
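{
"cell_type": "markdown",
"metadata": {},
"source": [
"With 8 TPU cores, each core runs its own copy of the training loop. The cell below works through the batch-size arithmetic under two assumptions. First, each process keeps the `DataLoader` batch size of 32. Second, the 55,000 training samples are sharded evenly across cores, as a distributed sampler with `drop_last=False` would do. Under these assumptions the effective batch size is `8 * 32 = 256`, and an epoch takes about 215 optimizer steps instead of 1719 on a single core. These numbers are illustrative and are not read from the trainer."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import math\n",
"\n",
"# Worked arithmetic under stated assumptions (values are NOT read from the trainer):\n",
"# each of the 8 TPU processes uses batch_size=32 and receives an equal shard of the\n",
"# 55,000 training samples, as a DistributedSampler with drop_last=False would produce.\n",
"train_samples = 55000\n",
"per_core_batch_size = 32\n",
"num_cores = 8\n",
"\n",
"effective_batch_size = per_core_batch_size * num_cores\n",
"samples_per_core = math.ceil(train_samples / num_cores)\n",
"steps_per_epoch_8_cores = math.ceil(samples_per_core / per_core_batch_size)\n",
"steps_per_epoch_1_core = math.ceil(train_samples / per_core_batch_size)\n",
"\n",
"print('effective batch size on 8 cores:', effective_batch_size)\n",
"print('optimizer steps per epoch on 1 core:', steps_per_epoch_1_core)\n",
"print('optimizer steps per epoch on 8 cores:', steps_per_epoch_8_cores)"
],
"execution_count": null,
"outputs": []
},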
{
"cell_type": "markdown",
"metadata": {
"id": "m2mhgEgpRZ1g"
},
"source": [
"<code style=\"color:#792ee5;\">\n",
" <h1> <strong> Congratulations - Time to Join the Community! </strong> </h1>\n",
"</code>\n",
"\n",
"Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!\n",
"\n",
"### Star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) on GitHub\n",
"The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.\n",
"\n",
"* Please star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning)\n",
"\n",
"### Join our [Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)!\n",
"The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in the `#general` channel.\n",
"\n",
"### Interested in SOTA AI models? Check out [Bolts](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n",
"Bolts has a collection of state-of-the-art models, all implemented in [Lightning](https://github.com/PyTorchLightning/pytorch-lightning), that can be easily integrated into your own projects.\n",
"\n",
"* Please star [Bolts](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n",
"\n",
"### Contributions!\n",
"The best way to contribute to our community is to become a code contributor! At any time you can go to the [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) or [Bolts](https://github.com/PyTorchLightning/pytorch-lightning-bolts) GitHub Issues page and filter for \"good first issue\".\n",
"\n",
"* [Lightning good first issue](https://github.com/PyTorchLightning/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
"* [Bolts good first issue](https://github.com/PyTorchLightning/pytorch-lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
"* You can also contribute your own notebooks with useful examples!\n",
"\n",
"### Many thanks from the entire PyTorch Lightning team for your interest!\n",
"\n",
"<img src=\"https://github.com/PyTorchLightning/pytorch-lightning/blob/master/docs/source/_static/images/logo.png?raw=true\" width=\"800\" height=\"200\" />"
]
}
]
}