2020-02-17 21:01:20 +00:00
TPU support
Lightning supports running on TPUs. At this moment, TPUs are only available
on Google Cloud (GCP). For more information on TPUs,
`watch this video <https://www.youtube.com/watch?v=kPMpmcl_Pyw>`_.
Live demo
Check out this `Google Colab <https://colab.research.google.com/drive/1-_LKx4HwAxl5M6xPJmqAAu444LTDQoa3>`_ to see how to train MNIST on TPUs.
TPU Terminology
A TPU is a Tensor Processing Unit. Each TPU has 8 cores, and each
core is optimized for 128x128 matrix multiplies. In general, a single
TPU is about as fast as 5 V100 GPUs!
A TPU pod hosts many TPUs. Currently, a TPU pod v2 has 2048 cores!
You can request a full pod from Google Cloud or a "slice", which gives you
some subset of those 2048 cores.
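
To make the core counts concrete, here is a small arithmetic sketch. It uses only the figures quoted above (8 cores per TPU, 2048 cores in a v2 pod). The helper ``tpus_in_slice`` is hypothetical, and the assumption that a slice is a whole number of 8-core TPUs is ours, not a statement about how Google Cloud sizes slices.

```python
CORES_PER_TPU = 8      # from the text above
POD_V2_CORES = 2048    # from the text above


def tpus_in_slice(num_cores: int) -> int:
    """Return how many 8-core TPUs provide `num_cores` cores.

    Hypothetical helper: assumes a slice is a whole number of TPUs
    and cannot exceed a full v2 pod.
    """
    if num_cores <= 0 or num_cores % CORES_PER_TPU != 0:
        raise ValueError("num_cores must be a positive multiple of 8")
    if num_cores > POD_V2_CORES:
        raise ValueError("a v2 pod has at most 2048 cores")
    return num_cores // CORES_PER_TPU


print(tpus_in_slice(POD_V2_CORES))  # 256 TPUs in a full v2 pod
print(tpus_in_slice(32))            # 4 TPUs in a 32-core slice
```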
How to access TPUs
There are two main ways to access TPUs:
1. Using Google Colab.
2. Using Google Cloud (GCP).
Colab TPUs
Colab is like a Jupyter notebook with a free GPU or TPU
hosted on GCP.
To get a TPU on Colab, follow these steps:
1. Go to `https://colab.research.google.com/ <https://colab.research.google.com/>`_.
2. Click "new notebook" (bottom right of pop-up).
3. Click runtime > change runtime settings. Select Python 3, and hardware accelerator "TPU".
This will give you a TPU with 8 cores.
4. Next, insert this code into the first cell and execute.
This will install the xla library that interfaces between PyTorch and the TPU.
.. code-block:: python

    import collections
    from datetime import datetime, timedelta
    import os
    import requests
    import threading

    _VersionConfig = collections.namedtuple('_VersionConfig', 'wheels,server')
    VERSION = "xrt==1.15.0"  #@param ["xrt==1.15.0", "torch_xla==nightly"]
    # Map the selected version to its wheel suffix and server-side XRT version
    CONFIG = {
        'xrt==1.15.0': _VersionConfig('1.15', '1.15.0'),
        'torch_xla==nightly': _VersionConfig('nightly', 'XRT-dev{}'.format(
            (datetime.today() - timedelta(1)).strftime('%Y%m%d'))),
    }[VERSION]
    DIST_BUCKET = 'gs://tpu-pytorch/wheels'
    TORCH_WHEEL = 'torch-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels)
    TORCH_XLA_WHEEL = 'torch_xla-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels)
    TORCHVISION_WHEEL = 'torchvision-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels)

    # Update TPU XRT version
    def update_server_xrt():
        print('Updating server-side XRT to {} ...'.format(CONFIG.server))
        # COLAB_TPU_ADDR is set by the Colab TPU runtime as "host:port"
        url = 'http://{TPU_ADDRESS}:8475/requestversion/{XRT_VERSION}'.format(
            TPU_ADDRESS=os.environ['COLAB_TPU_ADDR'].split(':')[0],
            XRT_VERSION=CONFIG.server,
        )
        print('Done updating server-side XRT: {}'.format(requests.post(url)))

    # Run the server update in the background while the wheels install
    update = threading.Thread(target=update_server_xrt)
    update.start()

.. code-block::

    # Install Colab TPU compat PyTorch/TPU wheels and dependencies
    !pip uninstall -y torch torchvision
    !gsutil cp "$DIST_BUCKET/$TORCH_WHEEL" .
    !gsutil cp "$DIST_BUCKET/$TORCH_XLA_WHEEL" .
    !gsutil cp "$DIST_BUCKET/$TORCHVISION_WHEEL" .
    !pip install "$TORCH_WHEEL"
    !pip install "$TORCH_XLA_WHEEL"
    !pip install "$TORCHVISION_WHEEL"
    !sudo apt-get install libomp5
    # Wait for the background XRT update to finish
    update.join()

5. Once the above is done, install PyTorch Lightning (v 0.7.0+).
.. code-block::

    !pip install pytorch-lightning

6. Then set up your LightningModule as normal.
Lightning automatically inserts the correct samplers - no need to do this yourself!
Usually, with TPUs (and DDP), you would need to define a DistributedSampler to move the right
chunk of data to the appropriate TPU. As mentioned, this is not needed in Lightning.

.. note:: Don't add a DistributedSampler yourself. Lightning does this automatically.

If for some reason you still need to, this is how to construct the sampler
for TPU use:
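.. code-block:: python

    import os
    import torch_xla.core.xla_model as xm
    from torch.utils.data import DataLoader
    from torch.utils.data.distributed import DistributedSampler
    from torchvision import transforms
    from torchvision.datasets import MNIST

    def train_dataloader(self):
        # Dataset arguments and batch size are illustrative; use_tpu is a flag you define
        dataset = MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
        # required for TPU support
        sampler = None
        if use_tpu:
            # One replica per TPU core; each core reads its own shard
            sampler = DistributedSampler(
                dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(), shuffle=True)
        loader = DataLoader(dataset, sampler=sampler, batch_size=32)
        return loader
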
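
To illustrate what "moving the right chunk of data to the appropriate TPU" means, the sketch below splits dataset indices across replicas in pure Python. It is a simplified illustration in the spirit of a distributed sampler without shuffling, not the code Lightning or PyTorch actually runs. The function name ``shard_indices`` is hypothetical.

```python
import math


def shard_indices(dataset_len: int, num_replicas: int, rank: int) -> list:
    """Return the dataset indices that replica `rank` would read (no shuffling)."""
    if dataset_len <= 0:
        raise ValueError("dataset_len must be positive")
    if not 0 <= rank < num_replicas:
        raise ValueError("rank must be in [0, num_replicas)")
    # Every replica must see the same number of samples, so round up.
    num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas
    indices = list(range(dataset_len))
    # Pad by repeating indices from the start until the list divides evenly.
    padding = total_size - dataset_len
    indices += (indices * math.ceil(padding / dataset_len))[:padding]
    # Replica `rank` takes every num_replicas-th index, starting at `rank`.
    return indices[rank:total_size:num_replicas]


# 10 samples over the 8 cores of one TPU: each core gets 2 indices.
for core in range(8):
    print(core, shard_indices(10, num_replicas=8, rank=core))
```

Because the dataset is padded, a few samples are read twice per epoch when the dataset size is not a multiple of the number of cores.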
Configure the number of TPU cores in the trainer. You can only choose 1 or 8.
To use a full TPU pod, skip to the TPU pod section.
.. code-block:: python

    import pytorch_lightning as pl

    my_model = MyLightningModule()
    trainer = pl.Trainer(num_tpu_cores=8)

That's it! Your model will train on all 8 TPU cores.
Distributed Backend with TPU
The ``distributed_backend`` option used for GPUs does not apply to TPUs.
TPUs work in DDP mode by default (distributing over each core).
TPU Pod
To train on more than 8 cores, your code actually doesn't change!
All you need to do is submit the following command:
.. code-block:: bash

    $ python -m torch_xla.distributed.xla_dist \
        -- python /usr/share/torch-xla-0.5/pytorch/xla/test/test_train_imagenet.py --fake_data

16 bit precision
Lightning also supports training in 16-bit precision with TPUs.
By default, TPU training uses 32-bit precision. To enable 16-bit, also
set ``precision=16`` in the Trainer.
.. code-block:: python

    import pytorch_lightning as pl

    my_model = MyLightningModule()
    trainer = pl.Trainer(num_tpu_cores=8, precision=16)

Under the hood, the XLA library will use the `bfloat16 type <https://en.wikipedia.org/wiki/Bfloat16_floating-point_format>`_.
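
To get a feel for how much precision bfloat16 keeps, the sketch below emulates it in plain Python by keeping only the top 16 bits of a 32-bit float (the sign bit, all 8 exponent bits, and 7 mantissa bits). This is an illustration only. The helper ``to_bfloat16`` is hypothetical, and it truncates the value, whereas real hardware may round to nearest instead.

```python
import struct


def to_bfloat16(x: float) -> float:
    """Truncate `x` to bfloat16 precision and return it as a Python float."""
    # Reinterpret the float32 encoding of x as a 32-bit unsigned integer.
    (bits,) = struct.unpack(">I", struct.pack(">f", x))
    # Keep the top 16 bits: sign, 8 exponent bits, 7 mantissa bits.
    truncated = bits & 0xFFFF0000
    (value,) = struct.unpack(">f", struct.pack(">I", truncated))
    return value


print(to_bfloat16(1.0))      # 1.0
print(to_bfloat16(3.14159))  # 3.140625
```

The exponent range matches float32, so large and small magnitudes survive, but only about two to three significant decimal digits remain.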
About XLA
XLA is the library that interfaces PyTorch with the TPUs.
For more information check out `XLA <https://github.com/pytorch/xla>`_.
Guide for `troubleshooting XLA <https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md>`_