diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8c09e4c9d8..7ec2c531a0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -196,6 +196,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added `strategy` argument to Trainer ([#8597](https://github.com/PyTorchLightning/pytorch-lightning/pull/8597))
 
+- Added `TPUPrecisionPlugin` ([#10020](https://github.com/PyTorchLightning/pytorch-lightning/pull/10020))
+
+
 - Added `kfold` example for loop customization ([#9965](https://github.com/PyTorchLightning/pytorch-lightning/pull/9965))
 
 
diff --git a/docs/source/api_references.rst b/docs/source/api_references.rst
index 7bc4d8b460..931ca4fb78 100644
--- a/docs/source/api_references.rst
+++ b/docs/source/api_references.rst
@@ -170,12 +170,16 @@ Precision Plugins
     :template: classtemplate.rst
 
     PrecisionPlugin
+    MixedPrecisionPlugin
     NativeMixedPrecisionPlugin
     ShardedNativeMixedPrecisionPlugin
     ApexMixedPrecisionPlugin
     DeepSpeedPrecisionPlugin
+    TPUPrecisionPlugin
     TPUHalfPrecisionPlugin
     DoublePrecisionPlugin
+    FullyShardedNativeMixedPrecisionPlugin
+    IPUPrecisionPlugin
 
 Cluster Environments
 ^^^^^^^^^^^^^^^^^^^^
diff --git a/docs/source/extensions/plugins.rst b/docs/source/extensions/plugins.rst
index a7d8850520..38e5ed76f0 100644
--- a/docs/source/extensions/plugins.rst
+++ b/docs/source/extensions/plugins.rst
@@ -131,12 +131,16 @@ Precision Plugins
     :template: classtemplate.rst
 
     PrecisionPlugin
+    MixedPrecisionPlugin
     NativeMixedPrecisionPlugin
     ShardedNativeMixedPrecisionPlugin
     ApexMixedPrecisionPlugin
     DeepSpeedPrecisionPlugin
+    TPUPrecisionPlugin
     TPUHalfPrecisionPlugin
     DoublePrecisionPlugin
+    FullyShardedNativeMixedPrecisionPlugin
+    IPUPrecisionPlugin
 
 
 Cluster Environments
diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py
index 68925ab67a..b85a92794e 100644
--- a/pytorch_lightning/accelerators/tpu.py
+++ b/pytorch_lightning/accelerators/tpu.py
@@ -18,12 +18,11 @@ from torch.optim import Optimizer
 
 import pytorch_lightning as pl
 from pytorch_lightning.accelerators.accelerator import Accelerator
-from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
+from pytorch_lightning.plugins.precision import TPUPrecisionPlugin
 from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin
 from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin
 from pytorch_lightning.utilities import _XLA_AVAILABLE
 from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 if _XLA_AVAILABLE:
     import torch_xla.core.xla_model as xm
@@ -35,18 +34,19 @@ class TPUAccelerator(Accelerator):
     def setup(self, trainer: "pl.Trainer") -> None:
         """
         Raises:
-            MisconfigurationException:
-                If AMP is used with TPU.
-            MisconfigurationException:
-                If TPUs are not using a single TPU core or TPU spawn training.
+            ValueError:
+                If the precision or training type plugin is unsupported.
         """
-        if isinstance(self.precision_plugin, MixedPrecisionPlugin):
-            raise MisconfigurationException(
-                "amp + tpu is not supported. Only bfloats are supported on TPU. Consider using TPUHalfPrecisionPlugin"
+        if not isinstance(self.precision_plugin, TPUPrecisionPlugin):
+            # this configuration should have been avoided in the accelerator connector
+            raise ValueError(
+                f"The `TPUAccelerator` can only be used with a `TPUPrecisionPlugin`, found: {self.precision_plugin}."
             )
-
         if not isinstance(self.training_type_plugin, (SingleTPUPlugin, TPUSpawnPlugin)):
-            raise MisconfigurationException("TPUs only support a single tpu core or tpu spawn training.")
+            raise ValueError(
+                "The `TPUAccelerator` can only be used with a `SingleTPUPlugin` or `TPUSpawnPlugin`,"
+                f" found {self.training_type_plugin}."
+            )
         return super().setup(trainer)
 
     def run_optimizer_step(
diff --git a/pytorch_lightning/plugins/__init__.py b/pytorch_lightning/plugins/__init__.py
index 13f8c7404b..4ae1c9a52a 100644
--- a/pytorch_lightning/plugins/__init__.py
+++ b/pytorch_lightning/plugins/__init__.py
@@ -16,6 +16,7 @@ from pytorch_lightning.plugins.precision.ipu_precision import IPUPrecisionPlugin
 from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
 from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
+from pytorch_lightning.plugins.precision.tpu import TPUPrecisionPlugin
 from pytorch_lightning.plugins.precision.tpu_bfloat import TPUHalfPrecisionPlugin
 from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
 from pytorch_lightning.plugins.training_type.ddp2 import DDP2Plugin
@@ -57,6 +58,7 @@ __all__ = [
     "FullyShardedNativeMixedPrecisionPlugin",
     "SingleDevicePlugin",
     "SingleTPUPlugin",
+    "TPUPrecisionPlugin",
     "TPUHalfPrecisionPlugin",
     "TPUSpawnPlugin",
     "TrainingTypePlugin",
diff --git a/pytorch_lightning/plugins/precision/__init__.py b/pytorch_lightning/plugins/precision/__init__.py
index 904e5f9f44..05df370418 100644
--- a/pytorch_lightning/plugins/precision/__init__.py
+++ b/pytorch_lightning/plugins/precision/__init__.py
@@ -4,8 +4,10 @@ from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin  #
 from pytorch_lightning.plugins.precision.fully_sharded_native_amp import (  # noqa: F401
     FullyShardedNativeMixedPrecisionPlugin,
 )
+from pytorch_lightning.plugins.precision.ipu_precision import IPUPrecisionPlugin  # noqa: F401
 from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin  # noqa: F401
 from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin  # noqa: F401
 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin  # noqa: F401
 from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin  # noqa: F401
+from pytorch_lightning.plugins.precision.tpu import TPUPrecisionPlugin  # noqa: F401
 from pytorch_lightning.plugins.precision.tpu_bfloat import TPUHalfPrecisionPlugin  # noqa: F401
diff --git a/pytorch_lightning/plugins/precision/tpu.py b/pytorch_lightning/plugins/precision/tpu.py
new file mode 100644
index 0000000000..6df9404d82
--- /dev/null
+++ b/pytorch_lightning/plugins/precision/tpu.py
@@ -0,0 +1,18 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
+
+
+class TPUPrecisionPlugin(PrecisionPlugin):
+    ...
diff --git a/pytorch_lightning/plugins/precision/tpu_bfloat.py b/pytorch_lightning/plugins/precision/tpu_bfloat.py
index 4e1db6210e..ecc5742ba4 100644
--- a/pytorch_lightning/plugins/precision/tpu_bfloat.py
+++ b/pytorch_lightning/plugins/precision/tpu_bfloat.py
@@ -17,10 +17,10 @@ from typing import Any, List, Tuple
 import torch.nn as nn
 from torch.optim import Optimizer
 
-from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
+from pytorch_lightning.plugins.precision import TPUPrecisionPlugin
 
 
-class TPUHalfPrecisionPlugin(PrecisionPlugin):
+class TPUHalfPrecisionPlugin(TPUPrecisionPlugin):
     """Plugin that enables bfloats on TPUs."""
 
     precision: int = 16
diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index 878b3bfb0d..6bf9a94172 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -47,6 +47,7 @@ from pytorch_lightning.plugins import (
     SingleDevicePlugin,
     SingleTPUPlugin,
     TPUHalfPrecisionPlugin,
+    TPUPrecisionPlugin,
     TPUSpawnPlugin,
     TrainingTypePlugin,
     TrainingTypePluginsRegistry,
@@ -592,6 +593,17 @@ class AcceleratorConnector:
         if self.use_ipu:
             return IPUPrecisionPlugin(self.precision)
 
+        if self.use_tpu:
+            if self.precision == 32:
+                return TPUPrecisionPlugin()
+            elif self.precision == 64:
+                raise MisconfigurationException(
+                    "`Trainer(accelerator='tpu', precision=64)` is not implemented."
+                    " Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`"
+                    " requesting this feature."
+                )
+            elif self.precision in (16, "bf16"):
+                return TPUHalfPrecisionPlugin()
         if self._distrib_type == DistributedType.DEEPSPEED or isinstance(self._training_type_plugin, DeepSpeedPlugin):
             return DeepSpeedPrecisionPlugin(self.precision)
 
@@ -601,9 +613,6 @@ class AcceleratorConnector:
         if self.precision == 64:
             return DoublePrecisionPlugin()
         if self.precision in (16, "bf16"):
-            if self.use_tpu:
-                return TPUHalfPrecisionPlugin()
-
             if self.amp_type == AMPType.NATIVE:
                 if self.amp_level is not None:
                     raise MisconfigurationException(
diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py
index 6f168b9275..64239399fd 100644
--- a/tests/accelerators/test_accelerator_connector.py
+++ b/tests/accelerators/test_accelerator_connector.py
@@ -976,3 +976,13 @@ def test_strategy_choice_ddp_cpu_slurm(device_count_mock, setup_distributed_mock
 
     with pytest.raises(SystemExit):
         trainer.fit(model)
+
+
+def test_unsupported_tpu_choice(monkeypatch):
+    import pytorch_lightning.utilities.imports as imports
+    from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
+
+    monkeypatch.setattr(imports, "_XLA_AVAILABLE", True)
+    monkeypatch.setattr(AcceleratorConnector, "has_tpu", True)
+    with pytest.raises(MisconfigurationException, match=r"accelerator='tpu', precision=64\)` is not implemented"):
+        Trainer(accelerator="tpu", precision=64)
diff --git a/tests/accelerators/test_tpu.py b/tests/accelerators/test_tpu.py
index 25743d5b3b..62789d1a54 100644
--- a/tests/accelerators/test_tpu.py
+++ b/tests/accelerators/test_tpu.py
@@ -22,7 +22,7 @@ from torch import nn
 from pytorch_lightning import Trainer
 from pytorch_lightning.accelerators.cpu import CPUAccelerator
 from pytorch_lightning.accelerators.tpu import TPUAccelerator
-from pytorch_lightning.plugins import TPUSpawnPlugin
+from pytorch_lightning.plugins import TPUPrecisionPlugin, TPUSpawnPlugin
 from pytorch_lightning.utilities import find_shared_parameters
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers.boring_model import BoringModel
@@ -284,3 +284,13 @@ def test_auto_parameters_tying_tpus_nested_module(tmpdir):
     trainer.fit(model)
 
     assert torch.all(torch.eq(model.net_a.layer.weight, model.net_b.layer.weight))
+
+
+def test_tpu_invalid_raises():
+    accelerator = TPUAccelerator(object(), TPUSpawnPlugin())
+    with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `TPUPrecisionPlugin"):
+        accelerator.setup(object())
+
+    accelerator = TPUAccelerator(TPUPrecisionPlugin(), object())
+    with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `SingleTPUPlugin` or `TPUSpawnPlugi"):
+        accelerator.setup(object())
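
Note (not part of the diff): the Python sketch below is a standalone illustration of what this change establishes, namely a common `TPUPrecisionPlugin` base that both plain 32-bit TPU precision and the bfloat16 `TPUHalfPrecisionPlugin` share, so a single isinstance check in `TPUAccelerator.setup` covers every supported TPU precision. All class and function names here are simplified local stand-ins, not the real pytorch_lightning imports, and the validation function only mirrors the checks added above.

# Standalone sketch: hypothetical stand-in classes, not the real pytorch_lightning APIs.

class PrecisionPlugin:  # stand-in for pytorch_lightning.plugins.precision.PrecisionPlugin
    ...


class TPUPrecisionPlugin(PrecisionPlugin):  # new base class: default (32-bit) precision on TPU
    ...


class TPUHalfPrecisionPlugin(TPUPrecisionPlugin):  # bfloat16 on TPU, now a TPUPrecisionPlugin subclass
    ...


class SingleTPUPlugin:  # stand-in training type plugin (single TPU core)
    ...


class TPUSpawnPlugin:  # stand-in training type plugin (TPU spawn training)
    ...


def validate_tpu_plugins(precision_plugin, training_type_plugin):
    # Mirrors the two isinstance checks that `TPUAccelerator.setup` performs after this change.
    if not isinstance(precision_plugin, TPUPrecisionPlugin):
        raise ValueError(
            f"The `TPUAccelerator` can only be used with a `TPUPrecisionPlugin`, found: {precision_plugin}."
        )
    if not isinstance(training_type_plugin, (SingleTPUPlugin, TPUSpawnPlugin)):
        raise ValueError(
            "The `TPUAccelerator` can only be used with a `SingleTPUPlugin` or `TPUSpawnPlugin`,"
            f" found {training_type_plugin}."
        )


# Both supported combinations pass, since `TPUHalfPrecisionPlugin` is a `TPUPrecisionPlugin`.
validate_tpu_plugins(TPUPrecisionPlugin(), TPUSpawnPlugin())
validate_tpu_plugins(TPUHalfPrecisionPlugin(), SingleTPUPlugin())

# An unrelated precision plugin is rejected, as exercised by `test_tpu_invalid_raises`.
try:
    validate_tpu_plugins(object(), TPUSpawnPlugin())
except ValueError as err:
    print(err)  # "The `TPUAccelerator` can only be used with a `TPUPrecisionPlugin`, found: ..."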