Add `TPUPrecisionPlugin` (#10020)

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Carlos Mocholí 2021-10-19 19:48:57 +02:00 committed by GitHub
parent 4aaca17fce
commit e8beceb631
11 changed files with 79 additions and 17 deletions


@@ -196,6 +196,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `strategy` argument to Trainer ([#8597](https://github.com/PyTorchLightning/pytorch-lightning/pull/8597))
- Added `TPUPrecisionPlugin` ([#10020](https://github.com/PyTorchLightning/pytorch-lightning/pull/10020))
- Added `kfold` example for loop customization ([#9965](https://github.com/PyTorchLightning/pytorch-lightning/pull/9965))
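
As a usage sketch (not part of the commit), the new plugin can be handed to the Trainer directly instead of being resolved from the `precision` flag; the `tpu_cores=8` value below is an assumption about the available hardware, and the call requires a torch_xla-enabled TPU environment.

# Hypothetical sketch: explicit precision plugin selection on a TPU machine.
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import TPUPrecisionPlugin

# Equivalent to the default 32-bit behaviour the accelerator connector now picks for TPUs.
trainer = Trainer(tpu_cores=8, plugins=[TPUPrecisionPlugin()])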


@@ -170,12 +170,16 @@ Precision Plugins
:template: classtemplate.rst
PrecisionPlugin
MixedPrecisionPlugin
NativeMixedPrecisionPlugin
ShardedNativeMixedPrecisionPlugin
ApexMixedPrecisionPlugin
DeepSpeedPrecisionPlugin
TPUPrecisionPlugin
TPUHalfPrecisionPlugin
DoublePrecisionPlugin
FullyShardedNativeMixedPrecisionPlugin
IPUPrecisionPlugin
Cluster Environments
^^^^^^^^^^^^^^^^^^^^


@@ -131,12 +131,16 @@ Precision Plugins
:template: classtemplate.rst
PrecisionPlugin
MixedPrecisionPlugin
NativeMixedPrecisionPlugin
ShardedNativeMixedPrecisionPlugin
ApexMixedPrecisionPlugin
DeepSpeedPrecisionPlugin
TPUPrecisionPlugin
TPUHalfPrecisionPlugin
DoublePrecisionPlugin
FullyShardedNativeMixedPrecisionPlugin
IPUPrecisionPlugin
Cluster Environments


@@ -18,12 +18,11 @@ from torch.optim import Optimizer
import pytorch_lightning as pl
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
from pytorch_lightning.plugins.precision import TPUPrecisionPlugin
from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin
from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin
from pytorch_lightning.utilities import _XLA_AVAILABLE
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
from pytorch_lightning.utilities.exceptions import MisconfigurationException
if _XLA_AVAILABLE:
import torch_xla.core.xla_model as xm
@@ -35,18 +34,19 @@ class TPUAccelerator(Accelerator):
def setup(self, trainer: "pl.Trainer") -> None:
"""
Raises:
MisconfigurationException:
If AMP is used with TPU.
MisconfigurationException:
If TPUs are not using a single TPU core or TPU spawn training.
ValueError:
If the precision or training type plugin are unsupported.
"""
if isinstance(self.precision_plugin, MixedPrecisionPlugin):
raise MisconfigurationException(
"amp + tpu is not supported. Only bfloats are supported on TPU. Consider using TPUHalfPrecisionPlugin"
if not isinstance(self.precision_plugin, TPUPrecisionPlugin):
# this configuration should have been avoided in the accelerator connector
raise ValueError(
f"The `TPUAccelerator` can only be used with a `TPUPrecisionPlugin`, found: {self.precision_plugin}."
)
if not isinstance(self.training_type_plugin, (SingleTPUPlugin, TPUSpawnPlugin)):
raise MisconfigurationException("TPUs only support a single tpu core or tpu spawn training.")
raise ValueError(
"The `TPUAccelerator` can only be used with a `SingleTPUPlugin` or `TPUSpawnPlugin,"
f" found {self.training_type_plugin}."
)
return super().setup(trainer)
def run_optimizer_step(
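
These checks run before any XLA call, so the accepted pairing can be illustrated without TPU hardware; a sketch mirroring the tests added at the bottom of this commit (the keyword names assume the `Accelerator(precision_plugin, training_type_plugin)` constructor of this release):

# Sketch: the only combination TPUAccelerator.setup now accepts is a
# TPUPrecisionPlugin (or subclass) paired with SingleTPUPlugin or TPUSpawnPlugin.
from pytorch_lightning.accelerators.tpu import TPUAccelerator
from pytorch_lightning.plugins import TPUPrecisionPlugin, TPUSpawnPlugin

accelerator = TPUAccelerator(
    precision_plugin=TPUPrecisionPlugin(),
    training_type_plugin=TPUSpawnPlugin(),
)
# Any other plugin type now raises ValueError during setup, replacing the previous
# MisconfigurationException messages.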


@@ -16,6 +16,7 @@ from pytorch_lightning.plugins.precision.ipu_precision import IPUPrecisionPlugin
from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
from pytorch_lightning.plugins.precision.tpu import TPUPrecisionPlugin
from pytorch_lightning.plugins.precision.tpu_bfloat import TPUHalfPrecisionPlugin
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.plugins.training_type.ddp2 import DDP2Plugin
@@ -57,6 +58,7 @@ __all__ = [
"FullyShardedNativeMixedPrecisionPlugin",
"SingleDevicePlugin",
"SingleTPUPlugin",
"TPUPrecisionPlugin",
"TPUHalfPrecisionPlugin",
"TPUSpawnPlugin",
"TrainingTypePlugin",


@@ -4,8 +4,10 @@ from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin #
from pytorch_lightning.plugins.precision.fully_sharded_native_amp import ( # noqa: F401
FullyShardedNativeMixedPrecisionPlugin,
)
from pytorch_lightning.plugins.precision.ipu_precision import IPUPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.tpu import TPUPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.tpu_bfloat import TPUHalfPrecisionPlugin # noqa: F401


@@ -0,0 +1,18 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
class TPUPrecisionPlugin(PrecisionPlugin):
...
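
The new class is deliberately an empty subclass: 32-bit TPU training keeps the inherited `PrecisionPlugin` behaviour, and the class mainly acts as a marker type that `TPUAccelerator` and the accelerator connector can check against. A small sketch of what that buys (the `precision == 32` default is inherited from `PrecisionPlugin`):

from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.plugins.precision.tpu import TPUPrecisionPlugin

# Marker relationship that the rest of this commit relies on.
assert issubclass(TPUPrecisionPlugin, PrecisionPlugin)
# 32-bit remains the inherited default for the base TPU precision plugin.
assert TPUPrecisionPlugin().precision == 32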


@@ -17,10 +17,10 @@ from typing import Any, List, Tuple
import torch.nn as nn
from torch.optim import Optimizer
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.plugins.precision import TPUPrecisionPlugin
class TPUHalfPrecisionPlugin(PrecisionPlugin):
class TPUHalfPrecisionPlugin(TPUPrecisionPlugin):
"""Plugin that enables bfloats on TPUs."""
precision: int = 16
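
Re-parenting `TPUHalfPrecisionPlugin` onto `TPUPrecisionPlugin` is what lets the `isinstance` check in `TPUAccelerator.setup` accept both the 32-bit and the bfloat16 configuration; a brief sketch of the resulting hierarchy:

from pytorch_lightning.plugins import TPUHalfPrecisionPlugin, TPUPrecisionPlugin

plugin = TPUHalfPrecisionPlugin()
# The bfloat16 plugin still reports 16 as its precision value...
assert plugin.precision == 16
# ...and now also passes the accelerator's TPU precision gate.
assert isinstance(plugin, TPUPrecisionPlugin)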


@@ -47,6 +47,7 @@ from pytorch_lightning.plugins import (
SingleDevicePlugin,
SingleTPUPlugin,
TPUHalfPrecisionPlugin,
TPUPrecisionPlugin,
TPUSpawnPlugin,
TrainingTypePlugin,
TrainingTypePluginsRegistry,
@@ -592,6 +593,17 @@ class AcceleratorConnector:
if self.use_ipu:
return IPUPrecisionPlugin(self.precision)
if self.use_tpu:
if self.precision == 32:
return TPUPrecisionPlugin()
elif self.precision == 64:
raise MisconfigurationException(
"`Trainer(accelerator='tpu', precision=64)` is not implemented."
" Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`"
" requesting this feature."
)
elif self.precision in (16, "bf16"):
return TPUHalfPrecisionPlugin()
if self._distrib_type == DistributedType.DEEPSPEED or isinstance(self._training_type_plugin, DeepSpeedPlugin):
return DeepSpeedPrecisionPlugin(self.precision)
@@ -601,9 +613,6 @@
if self.precision == 64:
return DoublePrecisionPlugin()
if self.precision in (16, "bf16"):
if self.use_tpu:
return TPUHalfPrecisionPlugin()
if self.amp_type == AMPType.NATIVE:
if self.amp_level is not None:
raise MisconfigurationException(
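
Taken together, the connector changes resolve the `precision` flag to a TPU-specific plugin before the generic branches run. A summary sketch of the new mapping (the Trainer call is hypothetical and assumes a TPU machine with `tpu_cores=8`):

# New resolution order when the Trainer is configured for TPUs:
#   precision=32          -> TPUPrecisionPlugin()
#   precision=16 / "bf16" -> TPUHalfPrecisionPlugin()
#   precision=64          -> MisconfigurationException (not implemented for TPUs)
from pytorch_lightning import Trainer

trainer = Trainer(tpu_cores=8, precision="bf16")  # resolves to TPUHalfPrecisionPlugin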


@@ -976,3 +976,13 @@ def test_strategy_choice_ddp_cpu_slurm(device_count_mock, setup_distributed_mock
with pytest.raises(SystemExit):
trainer.fit(model)
def test_unsupported_tpu_choice(monkeypatch):
import pytorch_lightning.utilities.imports as imports
from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
monkeypatch.setattr(imports, "_XLA_AVAILABLE", True)
monkeypatch.setattr(AcceleratorConnector, "has_tpu", True)
with pytest.raises(MisconfigurationException, match=r"accelerator='tpu', precision=64\)` is not implemented"):
Trainer(accelerator="tpu", precision=64)


@@ -22,7 +22,7 @@ from torch import nn
from pytorch_lightning import Trainer
from pytorch_lightning.accelerators.cpu import CPUAccelerator
from pytorch_lightning.accelerators.tpu import TPUAccelerator
from pytorch_lightning.plugins import TPUSpawnPlugin
from pytorch_lightning.plugins import TPUPrecisionPlugin, TPUSpawnPlugin
from pytorch_lightning.utilities import find_shared_parameters
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel
@@ -284,3 +284,13 @@ def test_auto_parameters_tying_tpus_nested_module(tmpdir):
trainer.fit(model)
assert torch.all(torch.eq(model.net_a.layer.weight, model.net_b.layer.weight))
def test_tpu_invalid_raises():
accelerator = TPUAccelerator(object(), TPUSpawnPlugin())
with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `TPUPrecisionPlugin"):
accelerator.setup(object())
accelerator = TPUAccelerator(TPUPrecisionPlugin(), object())
with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `SingleTPUPlugin` or `TPUSpawnPlugi"):
accelerator.setup(object())