Add `TPUPrecisionPlugin` (#10020)
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
parent 4aaca17fce
commit e8beceb631
@@ -196,6 +196,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added `strategy` argument to Trainer ([#8597](https://github.com/PyTorchLightning/pytorch-lightning/pull/8597))
 
 
+- Added `TPUPrecisionPlugin` ([#10020](https://github.com/PyTorchLightning/pytorch-lightning/pull/10020))
+
+
 - Added `kfold` example for loop customization ([#9965](https://github.com/PyTorchLightning/pytorch-lightning/pull/9965))
 
@@ -170,12 +170,16 @@ Precision Plugins
     :template: classtemplate.rst
 
     PrecisionPlugin
+    MixedPrecisionPlugin
     NativeMixedPrecisionPlugin
     ShardedNativeMixedPrecisionPlugin
     ApexMixedPrecisionPlugin
     DeepSpeedPrecisionPlugin
+    TPUPrecisionPlugin
     TPUHalfPrecisionPlugin
     DoublePrecisionPlugin
+    FullyShardedNativeMixedPrecisionPlugin
+    IPUPrecisionPlugin
 
 Cluster Environments
 ^^^^^^^^^^^^^^^^^^^^
@@ -131,12 +131,16 @@ Precision Plugins
     :template: classtemplate.rst
 
    PrecisionPlugin
+    MixedPrecisionPlugin
     NativeMixedPrecisionPlugin
     ShardedNativeMixedPrecisionPlugin
     ApexMixedPrecisionPlugin
     DeepSpeedPrecisionPlugin
+    TPUPrecisionPlugin
     TPUHalfPrecisionPlugin
     DoublePrecisionPlugin
+    FullyShardedNativeMixedPrecisionPlugin
+    IPUPrecisionPlugin
 
 Cluster Environments
 ^^^^^^^^^^^^^^^^^^^^
@@ -18,12 +18,11 @@ from torch.optim import Optimizer
 
 import pytorch_lightning as pl
 from pytorch_lightning.accelerators.accelerator import Accelerator
-from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
+from pytorch_lightning.plugins.precision import TPUPrecisionPlugin
 from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin
 from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin
 from pytorch_lightning.utilities import _XLA_AVAILABLE
 from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 if _XLA_AVAILABLE:
     import torch_xla.core.xla_model as xm
@@ -35,18 +34,19 @@ class TPUAccelerator(Accelerator):
     def setup(self, trainer: "pl.Trainer") -> None:
         """
         Raises:
-            MisconfigurationException:
-                If AMP is used with TPU.
-            MisconfigurationException:
-                If TPUs are not using a single TPU core or TPU spawn training.
+            ValueError:
+                If the precision or training type plugin are unsupported.
         """
-        if isinstance(self.precision_plugin, MixedPrecisionPlugin):
-            raise MisconfigurationException(
-                "amp + tpu is not supported. Only bfloats are supported on TPU. Consider using TPUHalfPrecisionPlugin"
+        if not isinstance(self.precision_plugin, TPUPrecisionPlugin):
+            # this configuration should have been avoided in the accelerator connector
+            raise ValueError(
+                f"The `TPUAccelerator` can only be used with a `TPUPrecisionPlugin`, found: {self.precision_plugin}."
             )
 
         if not isinstance(self.training_type_plugin, (SingleTPUPlugin, TPUSpawnPlugin)):
-            raise MisconfigurationException("TPUs only support a single tpu core or tpu spawn training.")
+            raise ValueError(
+                "The `TPUAccelerator` can only be used with a `SingleTPUPlugin` or `TPUSpawnPlugin,"
+                f" found {self.training_type_plugin}."
+            )
         return super().setup(trainer)
 
     def run_optimizer_step(
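A minimal sketch of the stricter validation above, assuming pytorch-lightning at this revision is installed. The checks run before any XLA call, so no TPU is needed to observe the new `ValueError`; this is also what the new `test_tpu_invalid_raises` test further below relies on.

```python
# Sketch only: a generic PrecisionPlugin is no longer accepted by the TPU accelerator.
from pytorch_lightning.accelerators.tpu import TPUAccelerator
from pytorch_lightning.plugins import PrecisionPlugin, TPUSpawnPlugin

accelerator = TPUAccelerator(PrecisionPlugin(), TPUSpawnPlugin())
try:
    accelerator.setup(trainer=None)  # the trainer is not touched before the plugin checks fail
except ValueError as err:
    print(err)  # "The `TPUAccelerator` can only be used with a `TPUPrecisionPlugin`, found: ..."
```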
@@ -16,6 +16,7 @@ from pytorch_lightning.plugins.precision.ipu_precision import IPUPrecisionPlugin
 from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
 from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
+from pytorch_lightning.plugins.precision.tpu import TPUPrecisionPlugin
 from pytorch_lightning.plugins.precision.tpu_bfloat import TPUHalfPrecisionPlugin
 from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
 from pytorch_lightning.plugins.training_type.ddp2 import DDP2Plugin
@@ -57,6 +58,7 @@ __all__ = [
     "FullyShardedNativeMixedPrecisionPlugin",
     "SingleDevicePlugin",
     "SingleTPUPlugin",
+    "TPUPrecisionPlugin",
     "TPUHalfPrecisionPlugin",
     "TPUSpawnPlugin",
     "TrainingTypePlugin",
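Because the class is re-exported here and added to `__all__`, it becomes part of the public `pytorch_lightning.plugins` namespace. A quick sanity check, assuming pytorch-lightning at this revision:

```python
# Sketch: verify the new public export; no TPU hardware is required to instantiate the plugin.
import pytorch_lightning.plugins as plugins

assert "TPUPrecisionPlugin" in plugins.__all__
assert issubclass(plugins.TPUPrecisionPlugin, plugins.PrecisionPlugin)
print(plugins.TPUPrecisionPlugin())  # a plain PrecisionPlugin subclass
```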
@@ -4,8 +4,10 @@ from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin  #
 from pytorch_lightning.plugins.precision.fully_sharded_native_amp import (  # noqa: F401
     FullyShardedNativeMixedPrecisionPlugin,
 )
 from pytorch_lightning.plugins.precision.ipu_precision import IPUPrecisionPlugin  # noqa: F401
 from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin  # noqa: F401
 from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin  # noqa: F401
 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin  # noqa: F401
 from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin  # noqa: F401
+from pytorch_lightning.plugins.precision.tpu import TPUPrecisionPlugin  # noqa: F401
 from pytorch_lightning.plugins.precision.tpu_bfloat import TPUHalfPrecisionPlugin  # noqa: F401
@@ -0,0 +1,18 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
+
+
+class TPUPrecisionPlugin(PrecisionPlugin):
+    ...
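The new class body is intentionally empty: it appears to serve mainly as the `isinstance` anchor used by `TPUAccelerator.setup` and as a base for TPU-specific precision plugins. A hypothetical extension, modeled on the `TPUHalfPrecisionPlugin` change in the next hunk (the subclass name here is illustrative, not part of Lightning):

```python
# Hypothetical example of extending the new base class; `TPUFloat32Plugin` is not a Lightning API.
from pytorch_lightning.plugins.precision.tpu import TPUPrecisionPlugin


class TPUFloat32Plugin(TPUPrecisionPlugin):
    """Identical to the base class; shown only to illustrate the extension point."""

    precision: int = 32  # mirrors how `TPUHalfPrecisionPlugin` declares `precision: int = 16`
```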
@@ -17,10 +17,10 @@ from typing import Any, List, Tuple
 import torch.nn as nn
 from torch.optim import Optimizer
 
-from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
+from pytorch_lightning.plugins.precision import TPUPrecisionPlugin
 
 
-class TPUHalfPrecisionPlugin(PrecisionPlugin):
+class TPUHalfPrecisionPlugin(TPUPrecisionPlugin):
     """Plugin that enables bfloats on TPUs."""
 
     precision: int = 16
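The practical effect of re-parenting `TPUHalfPrecisionPlugin`: the bfloat16 plugin now satisfies the `isinstance(..., TPUPrecisionPlugin)` guard in `TPUAccelerator.setup`. A small check, assuming pytorch-lightning at this revision:

```python
# Sketch: the bfloat16 plugin is now also a TPUPrecisionPlugin; no TPU is involved in this check.
from pytorch_lightning.plugins import TPUHalfPrecisionPlugin, TPUPrecisionPlugin

assert isinstance(TPUHalfPrecisionPlugin(), TPUPrecisionPlugin)
assert TPUHalfPrecisionPlugin.precision == 16  # still reports 16-bit (bfloat16) precision
```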
@@ -47,6 +47,7 @@ from pytorch_lightning.plugins import (
     SingleDevicePlugin,
     SingleTPUPlugin,
     TPUHalfPrecisionPlugin,
+    TPUPrecisionPlugin,
     TPUSpawnPlugin,
     TrainingTypePlugin,
     TrainingTypePluginsRegistry,
@@ -592,6 +593,17 @@ class AcceleratorConnector:
 
         if self.use_ipu:
             return IPUPrecisionPlugin(self.precision)
+        if self.use_tpu:
+            if self.precision == 32:
+                return TPUPrecisionPlugin()
+            elif self.precision == 64:
+                raise MisconfigurationException(
+                    "`Trainer(accelerator='tpu', precision=64)` is not implemented."
+                    " Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`"
+                    " requesting this feature."
+                )
+            elif self.precision in (16, "bf16"):
+                return TPUHalfPrecisionPlugin()
 
         if self._distrib_type == DistributedType.DEEPSPEED or isinstance(self._training_type_plugin, DeepSpeedPlugin):
             return DeepSpeedPrecisionPlugin(self.precision)
@@ -601,9 +613,6 @@
         if self.precision == 64:
             return DoublePrecisionPlugin()
         if self.precision in (16, "bf16"):
-            if self.use_tpu:
-                return TPUHalfPrecisionPlugin()
-
             if self.amp_type == AMPType.NATIVE:
                 if self.amp_level is not None:
                     raise MisconfigurationException(
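Condensed, the TPU branch added to the accelerator connector above maps the user-facing `precision` value to a plugin as sketched below; `pick_tpu_precision_plugin` is an illustrative helper, not a Lightning API.

```python
# Illustrative helper (not part of Lightning) condensing the TPU branch added above.
from pytorch_lightning.plugins import TPUHalfPrecisionPlugin, TPUPrecisionPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException


def pick_tpu_precision_plugin(precision):
    """Map a user-facing `precision` value to the plugin the connector would now choose on TPU."""
    if precision == 32:
        return TPUPrecisionPlugin()
    if precision == 64:
        # double precision on TPU is rejected explicitly rather than silently downgraded
        raise MisconfigurationException("`Trainer(accelerator='tpu', precision=64)` is not implemented.")
    if precision in (16, "bf16"):
        return TPUHalfPrecisionPlugin()
    raise MisconfigurationException(f"Unsupported TPU precision: {precision!r}")  # assumption of this sketch


print(type(pick_tpu_precision_plugin(32)).__name__)      # TPUPrecisionPlugin
print(type(pick_tpu_precision_plugin("bf16")).__name__)  # TPUHalfPrecisionPlugin
```

With this change, a plain `precision=32` run on TPU receives a dedicated `TPUPrecisionPlugin` instead of falling through to the generic `PrecisionPlugin`.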
@@ -976,3 +976,13 @@ def test_strategy_choice_ddp_cpu_slurm(device_count_mock, setup_distributed_mock
 
     with pytest.raises(SystemExit):
         trainer.fit(model)
+
+
+def test_unsupported_tpu_choice(monkeypatch):
+    import pytorch_lightning.utilities.imports as imports
+    from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
+
+    monkeypatch.setattr(imports, "_XLA_AVAILABLE", True)
+    monkeypatch.setattr(AcceleratorConnector, "has_tpu", True)
+    with pytest.raises(MisconfigurationException, match=r"accelerator='tpu', precision=64\)` is not implemented"):
+        Trainer(accelerator="tpu", precision=64)
@@ -22,7 +22,7 @@ from torch import nn
 from pytorch_lightning import Trainer
 from pytorch_lightning.accelerators.cpu import CPUAccelerator
 from pytorch_lightning.accelerators.tpu import TPUAccelerator
-from pytorch_lightning.plugins import TPUSpawnPlugin
+from pytorch_lightning.plugins import TPUPrecisionPlugin, TPUSpawnPlugin
 from pytorch_lightning.utilities import find_shared_parameters
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers.boring_model import BoringModel
@@ -284,3 +284,13 @@ def test_auto_parameters_tying_tpus_nested_module(tmpdir):
     trainer.fit(model)
 
     assert torch.all(torch.eq(model.net_a.layer.weight, model.net_b.layer.weight))
+
+
+def test_tpu_invalid_raises():
+    accelerator = TPUAccelerator(object(), TPUSpawnPlugin())
+    with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `TPUPrecisionPlugin"):
+        accelerator.setup(object())
+
+    accelerator = TPUAccelerator(TPUPrecisionPlugin(), object())
+    with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `SingleTPUPlugin` or `TPUSpawnPlugi"):
+        accelerator.setup(object())