Update docs for devices flag (#10293)

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Kaushik B committed via GitHub on 2021-11-02 00:09:00 +05:30
parent 10edc6de6b
commit 1127b28bbf
1 changed file with 35 additions and 1 deletion


@@ -543,6 +543,40 @@ will need to be set up to use remote filepaths.
    # default used by the Trainer
    trainer = Trainer(default_root_dir=os.getcwd())

devices
^^^^^^^

Number of devices to train on (``int``), which devices to train on (``list`` or ``str``), or ``"auto"``.
It will be mapped to either ``gpus``, ``tpu_cores``, ``num_processes``, or ``ipus``,
based on the accelerator type (``"cpu"``, ``"gpu"``, ``"tpu"``, ``"ipu"``, ``"auto"``).

.. code-block:: python

    # Training with CPU Accelerator using 2 processes
    trainer = Trainer(devices=2, accelerator="cpu")

    # Training with GPU Accelerator using GPUs 1 and 3
    trainer = Trainer(devices=[1, 3], accelerator="gpu")

    # Training with TPU Accelerator using 8 TPU cores
    trainer = Trainer(devices=8, accelerator="tpu")
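
The ``str`` form is not shown above; as a minimal sketch (assuming the comma-separated index string selects the same devices as the equivalent ``list``), it would look like:

.. code-block:: python

    # Training with GPU Accelerator using GPUs 1 and 3
    # (string form, assumed equivalent to devices=[1, 3])
    trainer = Trainer(devices="1,3", accelerator="gpu")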

.. tip:: The ``"auto"`` option recognizes the devices to train on, depending on the ``Accelerator`` being used.

.. code-block:: python

    # If your machine has GPUs, it will use all the available GPUs for training
    trainer = Trainer(devices="auto", accelerator="auto")

    # Training with CPU Accelerator using 1 process
    trainer = Trainer(devices="auto", accelerator="cpu")

    # Training with TPU Accelerator using 8 TPU cores
    trainer = Trainer(devices="auto", accelerator="tpu")

    # Training with IPU Accelerator using 4 IPUs
    trainer = Trainer(devices="auto", accelerator="ipu")

enable_checkpointing
^^^^^^^^^^^^^^^^^^^^
@@ -1179,7 +1213,7 @@ Half precision, or mixed precision, is the combined use of 32 and 16 bit floating
pip install --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" git+https://github.com/NVIDIA/apex
- 2. Set the `precision` trainer flag to 16. You can customize the `Apex optimization level <https://nvidia.github.io/apex/amp.html#opt-levels>`_ by setting the `amp_level` flag.
+ 2. Set the ``precision`` trainer flag to 16. You can customize the `Apex optimization level <https://nvidia.github.io/apex/amp.html#opt-levels>`_ by setting the `amp_level` flag.
.. testcode::
:skipif: not _APEX_AVAILABLE or not torch.cuda.is_available()
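    # A minimal sketch, assuming this Lightning version's amp_backend/amp_level
    # Trainer flags select Apex and its optimization level
    trainer = Trainer(amp_backend="apex", amp_level="O2", precision=16)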