:orphan:

TPU training (Intermediate)
===========================

**Audience:** Users looking to use cloud TPUs.

.. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.

----

DistributedSamplers
-------------------

Lightning automatically inserts the correct samplers - no need to do this yourself!

Usually, with TPUs (and DDP), you would need to define a DistributedSampler to move the right
chunk of data to the appropriate TPU. As mentioned, this is not needed in Lightning.

.. note:: Don't add DistributedSamplers. Lightning does this automatically.
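
For example, returning a plain DataLoader is enough - Lightning wraps it in a
DistributedSampler for you when running on TPU. A minimal sketch (the MNIST
dataset and transforms mirror the manual example below):

.. code-block:: python

    import os

    from torch.utils.data import DataLoader
    from torchvision import transforms
    from torchvision.datasets import MNIST


    def train_dataloader(self):
        dataset = MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
        # no sampler here - Lightning injects a DistributedSampler automatically
        return DataLoader(dataset, batch_size=32)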

If for some reason you still need to, this is how to construct the sampler
for TPU use:

.. code-block:: python

    import os

    import torch
    import torch_xla.core.xla_model as xm
    from torch.utils.data import DataLoader
    from torchvision import transforms
    from torchvision.datasets import MNIST


    def train_dataloader(self):
        dataset = MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())

        # required for TPU support: shard the dataset so each TPU core
        # sees a distinct chunk (use_tpu is a flag you define yourself)
        sampler = None
        if use_tpu:
            sampler = torch.utils.data.distributed.DistributedSampler(
                dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(), shuffle=True
            )

        loader = DataLoader(dataset, sampler=sampler, batch_size=32)

        return loader

Configure the number of TPU cores in the Trainer. You can only choose 1 or 8.
To use a full TPU pod, skip to the TPU pod section.

.. code-block:: python

    import lightning as L

    my_model = MyLightningModule()
    trainer = L.Trainer(accelerator="tpu", devices=8)
    trainer.fit(my_model)

That's it! Your model will train on all 8 TPU cores.
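
Since ``devices`` only accepts 1 or 8 here, the other option is a single-core run -
a minimal sketch, useful for debugging before scaling up:

.. code-block:: python

    import lightning as L

    my_model = MyLightningModule()
    # devices=1 runs on a single TPU core
    trainer = L.Trainer(accelerator="tpu", devices=1)
    trainer.fit(my_model)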

----------------

16-bit precision
----------------

Lightning also supports training in 16-bit precision with TPUs.
By default, TPU training uses 32-bit precision. To enable 16-bit, do:

.. code-block:: python

    import lightning as L

    my_model = MyLightningModule()
    trainer = L.Trainer(accelerator="tpu", precision="16-true")
    trainer.fit(my_model)

Under the hood, the xla library will use the `bfloat16 type <https://en.wikipedia.org/wiki/Bfloat16_floating-point_format>`_.
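
To see why bfloat16 suits training, compare its numeric limits with float32 and
float16 in plain PyTorch - bfloat16 keeps float32's 8 exponent bits (so nearly the
same range) and trades away mantissa precision. A quick illustration, not TPU-specific:

.. code-block:: python

    import torch

    # bfloat16's range nearly matches float32; float16 overflows past ~65k
    print(torch.finfo(torch.bfloat16).max)  # ~3.39e+38
    print(torch.finfo(torch.float32).max)  # ~3.40e+38
    print(torch.finfo(torch.float16).max)  # 65504.0

    # the trade-off: fewer mantissa bits means coarser precision
    print(torch.finfo(torch.bfloat16).eps)  # 0.0078125
    print(torch.finfo(torch.float32).eps)  # ~1.19e-07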