commit 49d000c0c9 (parent c89e482f85)
@@ -5,10 +5,14 @@ Lightning supports running on TPUs. At this moment, TPUs are only available
on Google Cloud (GCP). For more information on TPUs,
`watch this video <https://www.youtube.com/watch?v=kPMpmcl_Pyw>`_.

---------------

Live demo
----------
Check out this `Google Colab <https://colab.research.google.com/drive/1-_LKx4HwAxl5M6xPJmqAAu444LTDQoa3>`_ to see how to train MNIST on TPUs.

---------------

TPU Terminology
---------------
A TPU is a Tensor Processing Unit. Each TPU has 8 cores where each
@@ -19,6 +23,8 @@ A TPU pod hosts many TPUs on it. Currently, TPU pod v2 has 2048 cores!
You can request a full pod from Google Cloud or a "slice" which gives you
some subset of those 2048 cores.

---------------

How to access TPUs
-------------------
To access TPUs there are two main ways.
@@ -26,6 +32,8 @@ To access TPUs there are two main ways.

1. Using Google Colab.
2. Using Google Cloud (GCP).

---------------

Colab TPUs
-----------
Colab is like a Jupyter notebook with a free GPU or TPU
@@ -33,16 +41,16 @@ hosted on GCP.

To get a TPU on colab, follow these steps:

1. Go to https://colab.research.google.com/.

2. Click "new notebook" (bottom right of pop-up).

3. Click runtime > change runtime settings. Select Python 3,
   and hardware accelerator "TPU". This will give you a TPU with 8 cores.

4. Next, insert this code into the first cell and execute. This
   will install the xla library that interfaces between PyTorch and
   the TPU.

.. code-block:: python
@@ -86,7 +94,8 @@ the TPU.

    !pip install "$TORCHVISION_WHEEL"
    !sudo apt-get install libomp5
    update.join()

5. Once the above is done, install PyTorch Lightning (v 0.7.0+).

.. code-block::
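
    # (The body of this block is elided in the diff view; presumably it is
    # just the standard install command.)
    ! pip install pytorch-lightning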
@@ -94,8 +103,19 @@ the TPU.

6. Then set up your LightningModule as normal.

7. TPUs require a DistributedSampler; Lightning inserts the correct one for you
   automatically (see the next section).

---------------

DistributedSamplers
-------------------
Lightning automatically inserts the correct samplers - no need to do this yourself!

Usually, with TPUs (and DDP), you would need to define a DistributedSampler to move the right
chunk of data to the appropriate TPU. As mentioned, this is not needed in Lightning.

.. note:: Don't add DistributedSamplers. Lightning does this automatically.

If for some reason you still need to, this is how to construct the sampler
for TPU use:

.. code-block:: python
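
    # Assumed sketch (the block body is elided in this diff view): shard the
    # dataset across the TPU cores with a DistributedSampler built from
    # torch_xla's world size and ordinal.
    import torch
    import torch_xla.core.xla_model as xm
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.distributed import DistributedSampler

    def train_dataloader(self):
        # placeholder dataset; use your real Dataset here
        dataset = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
        sampler = DistributedSampler(
            dataset,
            num_replicas=xm.xrt_world_size(),  # total number of TPU cores
            rank=xm.get_ordinal(),             # index of the current core
        )
        return DataLoader(dataset, sampler=sampler, batch_size=32)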
@@ -140,6 +160,15 @@ train_dataloader (and val, train) code as follows.

That's it! Your model will train on all 8 TPU cores.

---------------

Distributed Backend with TPU
----------------------------
The ``distributed_backend`` option used for GPUs does not apply to TPUs.
TPUs work in DDP mode by default (distributing over each core).
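
As a minimal sketch (the exact flag is not shown in this diff; the 0.7-era
``num_tpu_cores`` Trainer argument is assumed), you only tell the Trainer how
many cores to use:

.. code-block:: python

    import pytorch_lightning as pl

    # no distributed_backend flag is needed for TPUs;
    # Lightning distributes over the 8 cores in DDP mode under the hood
    model = MyLightningModule()  # hypothetical LightningModule
    trainer = pl.Trainer(num_tpu_cores=8)
    trainer.fit(model)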

---------------

TPU Pod
--------
To train on more than 8 cores, your code actually doesn't change!
@@ -152,6 +181,8 @@ All you need to do is submit the following command:

    --conda-env=torch-xla-nightly
    -- python /usr/share/torch-xla-0.5/pytorch/xla/test/test_train_imagenet.py --fake_data

---------------

16 bit precision
-----------------
Lightning also supports training in 16-bit precision with TPUs.

@@ -168,6 +199,7 @@ set the 16-bit flag.

Under the hood the xla library will use the `bfloat16 type <https://en.wikipedia.org/wiki/Bfloat16_floating-point_format>`_.
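
As a hedged sketch (the actual example is elided in this diff; ``precision=16`` and
``num_tpu_cores`` are assumed from the 0.7-era Trainer API):

.. code-block:: python

    import pytorch_lightning as pl

    # 16-bit on TPU: xla maps this to bfloat16 under the hood
    trainer = pl.Trainer(num_tpu_cores=8, precision=16)
    trainer.fit(model)  # model is your LightningModule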

---------------

About XLA
----------
@@ -107,9 +107,10 @@ Training loop structure
-----------------------

The general pattern is that each loop (training, validation, test loop)
has 3 methods:

- ``___step``
- ``___step_end``
- ``___epoch_end``

To show how lightning calls these, let's use the validation loop as an example
@@ -126,6 +127,28 @@ To show how lightning calls these, let's use the validation loop as an example

    # like calculate validation set accuracy or loss
    validation_epoch_end(val_outs)

If we use dp or ddp2 mode, we can also define the ``XXX_step_end`` method to operate
on all parts of the batch:

.. code-block:: python

    val_outs = []
    for val_batch in val_data:
        batches = split_batch(val_batch)
        dp_outs = []
        for sub_batch in batches:
            dp_out = validation_step(sub_batch)
            dp_outs.append(dp_out)

        out = validation_step_end(dp_outs)
        val_outs.append(out)

    # do something with the outputs for all batches
    # like calculate validation set accuracy or loss
    validation_epoch_end(val_outs)

.. note:: ``training_step_end`` is not available yet but coming in the next release.
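
As a concrete sketch (assumed, not part of the original page), the three hooks are just
methods you define on your LightningModule:

.. code-block:: python

    import torch
    import torch.nn.functional as F
    import pytorch_lightning as pl

    class MyModel(pl.LightningModule):  # hypothetical module

        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(28 * 28, 10)

        def forward(self, x):
            return self.layer(x.view(x.size(0), -1))

        def validation_step(self, batch, batch_idx):
            # runs on each batch (or sub-batch in dp/ddp2)
            x, y = batch
            return {'val_loss': F.cross_entropy(self(x), y)}

        def validation_step_end(self, output):
            # in dp/ddp2, merge the pieces of one batch
            return {'val_loss': output['val_loss'].mean()}

        def validation_epoch_end(self, outputs):
            # aggregate over the whole validation set
            avg = torch.stack([o['val_loss'] for o in outputs]).mean()
            return {'val_loss': avg}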

Add validation loop
^^^^^^^^^^^^^^^^^^^
@@ -146,7 +146,8 @@ Example::

callbacks
^^^^^^^^^

Add a list of user defined callbacks. These callbacks DO NOT replace the explicit callbacks
(loggers, EarlyStopping or ModelCheckpoint).

.. note:: Only user defined callbacks (ie: Not EarlyStopping or ModelCheckpoint)
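
A minimal sketch of passing user-defined callbacks (the callback class and its name
here are made up for illustration):

.. code-block:: python

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import Callback

    class PrintOnTrainStart(Callback):  # hypothetical user-defined callback
        def on_train_start(self, trainer, pl_module):
            print('training is starting')

    # loggers, EarlyStopping and ModelCheckpoint keep their own Trainer arguments
    trainer = Trainer(callbacks=[PrintOnTrainStart()])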
@@ -239,6 +240,8 @@ Example::

    # ddp2 = DistributedDataParallel + dp
    trainer = Trainer(gpus=2, num_nodes=2, distributed_backend='ddp2')

.. note:: This option does not apply to TPU. TPUs use ``ddp`` by default (over each core).

early_stop_callback
^^^^^^^^^^^^^^^^^^^
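
A hedged sketch of typical usage (the section body is elided in this hunk; the 0.7-era
``early_stop_callback`` argument and the ``EarlyStopping`` callback are assumed):

.. code-block:: python

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import EarlyStopping

    # stop training when val_loss stops improving for 3 checks
    early_stop = EarlyStopping(monitor='val_loss', patience=3, mode='min')
    trainer = Trainer(early_stop_callback=early_stop)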