From 7fe8756917cdf7b5842d61cd5e1889dc007fb629 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 11 Aug 2023 17:37:41 +0200 Subject: [PATCH] [TPU] Proper half-precision implementation for XLA (#18213) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../accelerators/tpu_intermediate.rst | 2 +- src/lightning/fabric/_graveyard/tpu.py | 26 ++++++-- src/lightning/fabric/connector.py | 34 +++------- src/lightning/fabric/plugins/__init__.py | 2 - .../fabric/plugins/precision/__init__.py | 2 - src/lightning/fabric/plugins/precision/xla.py | 51 +++++++++++++-- .../fabric/plugins/precision/xlabf16.py | 41 ------------ src/lightning/pytorch/CHANGELOG.md | 2 +- src/lightning/pytorch/_graveyard/tpu.py | 30 +++++++-- src/lightning/pytorch/plugins/__init__.py | 2 - .../pytorch/plugins/precision/__init__.py | 2 - .../pytorch/plugins/precision/xla.py | 52 ++++++++++++--- .../pytorch/plugins/precision/xlabf16.py | 35 ----------- .../connectors/accelerator_connector.py | 36 ++++------- tests/tests_fabric/graveyard/test_tpu.py | 2 + .../plugins/precision/test_xla.py | 63 +++++++++++++++++++ .../plugins/precision/test_xlabf16.py | 23 ------- tests/tests_fabric/test_connector.py | 16 ++--- tests/tests_pytorch/graveyard/test_tpu.py | 2 + tests/tests_pytorch/models/test_tpu.py | 6 +- .../plugins/precision/test_xla.py | 50 ++++++++++++++- .../plugins/precision/test_xlabf16.py | 25 -------- .../connectors/test_accelerator_connector.py | 16 ++--- 23 files changed, 288 insertions(+), 232 deletions(-) delete mode 100644 src/lightning/fabric/plugins/precision/xlabf16.py delete mode 100644 src/lightning/pytorch/plugins/precision/xlabf16.py create mode 100644 tests/tests_fabric/plugins/precision/test_xla.py delete mode 100644 tests/tests_fabric/plugins/precision/test_xlabf16.py delete mode 100644 tests/tests_pytorch/plugins/precision/test_xlabf16.py diff --git a/docs/source-pytorch/accelerators/tpu_intermediate.rst b/docs/source-pytorch/accelerators/tpu_intermediate.rst index 579d2cc5cf..8dfe63f336 100644 --- a/docs/source-pytorch/accelerators/tpu_intermediate.rst +++ b/docs/source-pytorch/accelerators/tpu_intermediate.rst @@ -116,7 +116,7 @@ By default, TPU training will use 32-bit precision. To enable it, do import lightning.pytorch as pl my_model = MyLightningModule() - trainer = pl.Trainer(accelerator="tpu", precision="16-mixed") + trainer = pl.Trainer(accelerator="tpu", precision="16-true") trainer.fit(my_model) Under the hood the xla library will use the `bfloat16 type `_. 
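For orientation, the documentation change above reflects the new semantics introduced by this patch: on TPUs, 16-bit is requested as a "true" half-precision mode ("16-true" or "bf16-true") rather than the AMP-style "16-mixed". A minimal usage sketch (illustrative only; `MyLightningModule` is the placeholder model from the docs snippet):

    import lightning.pytorch as pl

    my_model = MyLightningModule()
    # "bf16-true" selects bfloat16, the TPU-native half precision;
    # "16-true" selects float16 via the XLA_USE_F16 environment variable set by the plugin below.
    trainer = pl.Trainer(accelerator="tpu", precision="bf16-true")
    trainer.fit(my_model)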
diff --git a/src/lightning/fabric/_graveyard/tpu.py b/src/lightning/fabric/_graveyard/tpu.py index b38cfc4708..c537ffc032 100644 --- a/src/lightning/fabric/_graveyard/tpu.py +++ b/src/lightning/fabric/_graveyard/tpu.py @@ -16,7 +16,7 @@ from typing import Any import lightning.fabric as fabric from lightning.fabric.accelerators import XLAAccelerator -from lightning.fabric.plugins.precision import XLABf16Precision, XLAPrecision +from lightning.fabric.plugins.precision import XLAPrecision from lightning.fabric.strategies import _StrategyRegistry from lightning.fabric.strategies.single_xla import SingleDeviceXLAStrategy from lightning.fabric.utilities.rank_zero import rank_zero_deprecation @@ -28,6 +28,7 @@ def _patch_sys_modules() -> None: sys.modules["lightning.fabric.accelerators.tpu"] = self sys.modules["lightning.fabric.plugins.precision.tpu"] = self sys.modules["lightning.fabric.plugins.precision.tpu_bf16"] = self + sys.modules["lightning.fabric.plugins.precision.xlabf16"] = self class SingleTPUStrategy(SingleDeviceXLAStrategy): @@ -72,20 +73,35 @@ class TPUPrecision(XLAPrecision): rank_zero_deprecation( "The `TPUPrecision` class is deprecated. Use `lightning.fabric.plugins.precision.XLAPrecision`" " instead." ) - super().__init__(*args, **kwargs) + super().__init__(precision="32-true") + + +class XLABf16Precision(XLAPrecision): + """Legacy class. + + Use :class:`~lightning.fabric.plugins.precision.xla.XLAPrecision` instead. + + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + rank_zero_deprecation( + "The `XLABf16Precision` class is deprecated. Use" + " `lightning.fabric.plugins.precision.XLAPrecision` instead." + ) + super().__init__(precision="bf16-true") class TPUBf16Precision(XLABf16Precision): """Legacy class. - Use :class:`~lightning.fabric.plugins.precision.xlabf16.XLABf16Precision` instead. + Use :class:`~lightning.fabric.plugins.precision.xla.XLAPrecision` instead. """ def __init__(self, *args: Any, **kwargs: Any) -> None: rank_zero_deprecation( "The `TPUBf16Precision` class is deprecated. Use" - " `lightning.fabric.plugins.precision.XLABf16Precision` instead." + " `lightning.fabric.plugins.precision.XLAPrecision` instead." ) super().__init__(*args, **kwargs) @@ -97,6 +113,8 @@ def _patch_classes() -> None: setattr(fabric.plugins.precision, "TPUPrecision", TPUPrecision) setattr(fabric.plugins, "TPUBf16Precision", TPUBf16Precision) setattr(fabric.plugins.precision, "TPUBf16Precision", TPUBf16Precision) + setattr(fabric.plugins, "XLABf16Precision", XLABf16Precision) + setattr(fabric.plugins.precision, "XLABf16Precision", XLABf16Precision) _patch_sys_modules() diff --git a/src/lightning/fabric/connector.py b/src/lightning/fabric/connector.py index d06750b666..bb752df762 100644 --- a/src/lightning/fabric/connector.py +++ b/src/lightning/fabric/connector.py @@ -29,7 +29,6 @@ from lightning.fabric.plugins import ( HalfPrecision, MixedPrecision, Precision, - XLABf16Precision, XLAPrecision, ) from lightning.fabric.plugins.environments import ( @@ -432,18 +431,8 @@ class _Connector: self._validate_precision_choice() if isinstance(self._precision_instance, Precision): return self._precision_instance - if isinstance(self.accelerator, XLAAccelerator): - if self._precision_input == "32-true": - return XLAPrecision() - if self._precision_input in ("16-mixed", "bf16-mixed"): - if self._precision_input == "16-mixed": - rank_zero_warn( - "You passed `Fabric(accelerator='tpu', precision='16-mixed')` but AMP with fp16" - " is not supported with TPUs. 
Using `precision='bf16-mixed'` instead." - ) - return XLABf16Precision() - + return XLAPrecision(self._precision_input) # type: ignore if isinstance(self.strategy, DeepSpeedStrategy): return DeepSpeedPrecision(self._precision_input) # type: ignore if isinstance(self.strategy, FSDPStrategy): @@ -477,18 +466,15 @@ class _Connector: def _validate_precision_choice(self) -> None: """Validate the combination of choices for precision, and accelerator.""" - if isinstance(self.accelerator, XLAAccelerator): - if self._precision_input == "64-true": - raise NotImplementedError( - "`Fabric(accelerator='tpu', precision='64-true')` is not implemented." - " Please, open an issue in `https://github.com/Lightning-AI/lightning/issues`" - " requesting this feature." - ) - if self._precision_instance and not isinstance(self._precision_instance, (XLAPrecision, XLABf16Precision)): - raise ValueError( - f"The `XLAAccelerator` can only be used with a `XLAPrecision` plugin," - f" found: {self._precision_instance}." - ) + if ( + isinstance(self.accelerator, XLAAccelerator) + and self._precision_instance + and not isinstance(self._precision_instance, XLAPrecision) + ): + raise ValueError( + f"The `XLAAccelerator` can only be used with a `XLAPrecision` plugin," + f" found: {self._precision_instance}." + ) def _lazy_init_strategy(self) -> None: """Lazily set missing attributes on the previously instantiated strategy.""" diff --git a/src/lightning/fabric/plugins/__init__.py b/src/lightning/fabric/plugins/__init__.py index 30a11b8c11..c28e904923 100644 --- a/src/lightning/fabric/plugins/__init__.py +++ b/src/lightning/fabric/plugins/__init__.py @@ -22,7 +22,6 @@ from lightning.fabric.plugins.precision.fsdp import FSDPPrecision from lightning.fabric.plugins.precision.half import HalfPrecision from lightning.fabric.plugins.precision.precision import Precision from lightning.fabric.plugins.precision.xla import XLAPrecision -from lightning.fabric.plugins.precision.xlabf16 import XLABf16Precision __all__ = [ "ClusterEnvironment", @@ -35,6 +34,5 @@ __all__ = [ "HalfPrecision", "MixedPrecision", "XLAPrecision", - "XLABf16Precision", "FSDPPrecision", ] diff --git a/src/lightning/fabric/plugins/precision/__init__.py b/src/lightning/fabric/plugins/precision/__init__.py index b5b1ca0ef0..1aecb6a32e 100644 --- a/src/lightning/fabric/plugins/precision/__init__.py +++ b/src/lightning/fabric/plugins/precision/__init__.py @@ -19,7 +19,6 @@ from lightning.fabric.plugins.precision.half import HalfPrecision from lightning.fabric.plugins.precision.precision import Precision from lightning.fabric.plugins.precision.transformer_engine import TransformerEnginePrecision from lightning.fabric.plugins.precision.xla import XLAPrecision -from lightning.fabric.plugins.precision.xlabf16 import XLABf16Precision __all__ = [ "DeepSpeedPrecision", @@ -28,7 +27,6 @@ __all__ = [ "MixedPrecision", "Precision", "XLAPrecision", - "XLABf16Precision", "FSDPPrecision", "TransformerEnginePrecision", ] diff --git a/src/lightning/fabric/plugins/precision/xla.py b/src/lightning/fabric/plugins/precision/xla.py index 20e1f44880..e837828fae 100644 --- a/src/lightning/fabric/plugins/precision/xla.py +++ b/src/lightning/fabric/plugins/precision/xla.py @@ -11,20 +11,59 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any +import os +from typing import Any, Literal + +import torch +from lightning_utilities.core.apply_func import apply_to_collection +from torch import Tensor +from typing_extensions import get_args from lightning.fabric.accelerators.xla import _XLA_AVAILABLE from lightning.fabric.plugins.precision.precision import Precision +from lightning.fabric.plugins.precision.utils import _convert_fp_tensor from lightning.fabric.utilities.types import Optimizable +_PRECISION_INPUT = Literal["32-true", "16-true", "bf16-true"] + class XLAPrecision(Precision): - """Precision plugin with XLA.""" + """Plugin for training with XLA. - def __init__(self, *args: Any, **kwargs: Any) -> None: + Args: + precision: Full precision (32-true) or half precision (16-true, bf16-true). + + Raises: + ValueError: + If unsupported ``precision`` is provided. + + """ + + def __init__(self, precision: _PRECISION_INPUT) -> None: if not _XLA_AVAILABLE: raise ModuleNotFoundError(str(_XLA_AVAILABLE)) - super().__init__(*args, **kwargs) + supported_precision = get_args(_PRECISION_INPUT) + if precision not in supported_precision: + raise ValueError( + f"`precision={precision!r})` is not supported in XLA." + f" `precision` must be one of: {supported_precision}." + ) + self.precision = precision + + if precision == "16-true": + os.environ["XLA_USE_F16"] = "1" + self._desired_dtype = torch.float16 + elif precision == "bf16-true": + os.environ["XLA_USE_BF16"] = "1" + self._desired_dtype = torch.bfloat16 + else: + self._desired_dtype = torch.float32 + + def convert_input(self, data: Any) -> Any: + return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_dtype) + + def convert_output(self, data: Any) -> Any: + return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype()) def optimizer_step( self, @@ -35,3 +74,7 @@ class XLAPrecision(Precision): # you always want to `xm.mark_step()` after `optimizer.step` for better performance, so we set `barrier=True` return xm.optimizer_step(optimizer, optimizer_args=kwargs, barrier=True) + + def teardown(self) -> None: + os.environ.pop("XLA_USE_BF16", None) + os.environ.pop("XLA_USE_F16", None) diff --git a/src/lightning/fabric/plugins/precision/xlabf16.py b/src/lightning/fabric/plugins/precision/xlabf16.py deleted file mode 100644 index 34e5d4cc01..0000000000 --- a/src/lightning/fabric/plugins/precision/xlabf16.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright The Lightning AI team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import os -from typing import Any, Literal - -import torch -from lightning_utilities.core.apply_func import apply_to_collection -from torch import Tensor - -from lightning.fabric.plugins.precision import XLAPrecision -from lightning.fabric.plugins.precision.utils import _convert_fp_tensor - - -class XLABf16Precision(XLAPrecision): - """Plugin that enables mixed bf16 with XLA.""" - - precision: Literal["bf16-mixed"] = "bf16-mixed" - - def __init__(self) -> None: - super().__init__() - os.environ["XLA_USE_BF16"] = "1" - - def convert_input(self, data: Any) -> Any: - return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.bfloat16) - - def convert_output(self, data: Any) -> Any: - return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype()) - - def teardown(self) -> None: - os.environ.pop("XLA_USE_BF16", None) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 9655d7b59e..a53c99e647 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -91,7 +91,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Allowed accessing rank information in the main process before processes are launched when using the `XLAStrategy` ([#18194](https://github.com/Lightning-AI/lightning/pull/18194)) -- Added support for true half-precision training via `Trainer(precision="16-true"|"bf16-true")` ([#18193](https://github.com/Lightning-AI/lightning/pull/18193), [#18217](https://github.com/Lightning-AI/lightning/pull/18217), [#18219](https://github.com/Lightning-AI/lightning/pull/18219)) +- Added support for true half-precision training via `Trainer(precision="16-true"|"bf16-true")` ([#18193](https://github.com/Lightning-AI/lightning/pull/18193), [#18217](https://github.com/Lightning-AI/lightning/pull/18217), [#18213](https://github.com/Lightning-AI/lightning/pull/18213), [#18219](https://github.com/Lightning-AI/lightning/pull/18219)) - Added automatic process cleanup to avoid zombie child processes and stalls when exceptions are raised ([#18218](https://github.com/Lightning-AI/lightning/pull/18218)) diff --git a/src/lightning/pytorch/_graveyard/tpu.py b/src/lightning/pytorch/_graveyard/tpu.py index dde1729735..eb1d493645 100644 --- a/src/lightning/pytorch/_graveyard/tpu.py +++ b/src/lightning/pytorch/_graveyard/tpu.py @@ -18,7 +18,7 @@ from typing import Any import lightning.pytorch as pl from lightning.fabric.strategies import _StrategyRegistry from lightning.pytorch.accelerators.xla import XLAAccelerator -from lightning.pytorch.plugins.precision import XLABf16PrecisionPlugin, XLAPrecisionPlugin +from lightning.pytorch.plugins.precision import XLAPrecisionPlugin from lightning.pytorch.strategies.single_xla import SingleDeviceXLAStrategy from lightning.pytorch.utilities.rank_zero import rank_zero_deprecation @@ -29,6 +29,7 @@ def _patch_sys_modules() -> None: sys.modules["lightning.pytorch.accelerators.tpu"] = self sys.modules["lightning.pytorch.plugins.precision.tpu"] = self sys.modules["lightning.pytorch.plugins.precision.tpu_bf16"] = self + sys.modules["lightning.pytorch.plugins.precision.xlabf16"] = self class SingleTPUStrategy(SingleDeviceXLAStrategy): @@ -74,22 +75,37 @@ class TPUPrecisionPlugin(XLAPrecisionPlugin): "The `TPUPrecisionPlugin` class is deprecated. Use `lightning.pytorch.plugins.precision.XLAPrecisionPlugin`" " instead." 
         )
-        super().__init__(*args, **kwargs)
+        super().__init__(precision="32-true")
 
 
-class TPUBf16PrecisionPlugin(XLABf16PrecisionPlugin):
+class TPUBf16PrecisionPlugin(XLAPrecisionPlugin):
     """Legacy class.
 
-    Use :class:`~lightning.pytorch.plugins.precision.xlabf16.XLABf16PrecisionPlugin` instead.
+    Use :class:`~lightning.pytorch.plugins.precision.xla.XLAPrecisionPlugin` instead.
 
     """
 
     def __init__(self, *args: Any, **kwargs: Any) -> None:
         rank_zero_deprecation(
             "The `TPUBf16PrecisionPlugin` class is deprecated. Use"
-            " `lightning.pytorch.plugins.precision.XLABf16PrecisionPlugin` instead."
+            " `lightning.pytorch.plugins.precision.XLAPrecisionPlugin` instead."
         )
-        super().__init__(*args, **kwargs)
+        super().__init__(precision="bf16-true")
+
+
+class XLABf16PrecisionPlugin(XLAPrecisionPlugin):
+    """Legacy class.
+
+    Use :class:`~lightning.pytorch.plugins.precision.xla.XLAPrecisionPlugin` instead.
+
+    """
+
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        rank_zero_deprecation(
+            "The `XLABf16PrecisionPlugin` class is deprecated. Use"
+            " `lightning.pytorch.plugins.precision.XLAPrecisionPlugin` instead."
+        )
+        super().__init__(precision="bf16-true")
 
 
 def _patch_classes() -> None:
@@ -99,6 +115,8 @@ def _patch_classes() -> None:
     setattr(pl.plugins.precision, "TPUPrecisionPlugin", TPUPrecisionPlugin)
     setattr(pl.plugins, "TPUBf16PrecisionPlugin", TPUBf16PrecisionPlugin)
     setattr(pl.plugins.precision, "TPUBf16PrecisionPlugin", TPUBf16PrecisionPlugin)
+    setattr(pl.plugins, "XLABf16PrecisionPlugin", XLABf16PrecisionPlugin)
+    setattr(pl.plugins.precision, "XLABf16PrecisionPlugin", XLABf16PrecisionPlugin)
 
 
 _patch_sys_modules()
diff --git a/src/lightning/pytorch/plugins/__init__.py b/src/lightning/pytorch/plugins/__init__.py
index 05280e3585..2848cf205f 100644
--- a/src/lightning/pytorch/plugins/__init__.py
+++ b/src/lightning/pytorch/plugins/__init__.py
@@ -10,7 +10,6 @@ from lightning.pytorch.plugins.precision.fsdp import FSDPMixedPrecisionPlugin, F
 from lightning.pytorch.plugins.precision.half import HalfPrecisionPlugin
 from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin
 from lightning.pytorch.plugins.precision.xla import XLAPrecisionPlugin
-from lightning.pytorch.plugins.precision.xlabf16 import XLABf16PrecisionPlugin
 
 PLUGIN = Union[PrecisionPlugin, ClusterEnvironment, CheckpointIO, LayerSync]
 PLUGIN_INPUT = Union[PLUGIN, str]
@@ -28,7 +27,6 @@ __all__ = [
     "FSDPMixedPrecisionPlugin",
     "FSDPPrecisionPlugin",
     "XLAPrecisionPlugin",
-    "XLABf16PrecisionPlugin",
     "LayerSync",
     "TorchSyncBatchNorm",
 ]
diff --git a/src/lightning/pytorch/plugins/precision/__init__.py b/src/lightning/pytorch/plugins/precision/__init__.py
index 95e537494b..64a58337f5 100644
--- a/src/lightning/pytorch/plugins/precision/__init__.py
+++ b/src/lightning/pytorch/plugins/precision/__init__.py
@@ -18,7 +18,6 @@ from lightning.pytorch.plugins.precision.fsdp import FSDPMixedPrecisionPlugin, F
 from lightning.pytorch.plugins.precision.half import HalfPrecisionPlugin
 from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin
 from lightning.pytorch.plugins.precision.xla import XLAPrecisionPlugin
-from lightning.pytorch.plugins.precision.xlabf16 import XLABf16PrecisionPlugin
 
 __all__ = [
     "DeepSpeedPrecisionPlugin",
@@ -29,5 +28,4 @@ __all__ = [
     "MixedPrecisionPlugin",
     "PrecisionPlugin",
     "XLAPrecisionPlugin",
-    "XLABf16PrecisionPlugin",
 ]
diff --git a/src/lightning/pytorch/plugins/precision/xla.py b/src/lightning/pytorch/plugins/precision/xla.py
index c8bae9cc68..00c3db5f90 100644
--- 
a/src/lightning/pytorch/plugins/precision/xla.py +++ b/src/lightning/pytorch/plugins/precision/xla.py @@ -11,30 +11,53 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os from functools import partial from typing import Any, Callable +import torch +from typing_extensions import get_args + import lightning.pytorch as pl from lightning.fabric.accelerators.xla import _XLA_AVAILABLE +from lightning.fabric.plugins.precision.xla import _PRECISION_INPUT from lightning.fabric.utilities.types import Optimizable from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin from lightning.pytorch.utilities.exceptions import MisconfigurationException class XLAPrecisionPlugin(PrecisionPlugin): - """Precision plugin with XLA.""" + """Plugin for training with XLA. - def __init__(self, *args: Any, **kwargs: Any) -> None: + Args: + precision: Full precision (32-true) or half precision (16-true, bf16-true). + + Raises: + ValueError: + If unsupported ``precision`` is provided. + + """ + + def __init__(self, precision: _PRECISION_INPUT = "32-true") -> None: if not _XLA_AVAILABLE: raise ModuleNotFoundError(str(_XLA_AVAILABLE)) - super().__init__(*args, **kwargs) - def _tpu_wrap_closure(self, optimizer: Optimizable, closure: Callable[[], Any]) -> Any: - import torch_xla.core.xla_model as xm + supported_precision = get_args(_PRECISION_INPUT) + if precision not in supported_precision: + raise ValueError( + f"`precision={precision!r})` is not supported in XLA." + f" `precision` must be one of: {supported_precision}." + ) + self.precision = precision - closure_result = closure() - xm.reduce_gradients(optimizer) - return closure_result + if precision == "16-true": + os.environ["XLA_USE_F16"] = "1" + self._desired_dtype = torch.float16 + elif precision == "bf16-true": + os.environ["XLA_USE_BF16"] = "1" + self._desired_dtype = torch.bfloat16 + else: + self._desired_dtype = torch.float32 def optimizer_step( # type: ignore[override] self, @@ -45,7 +68,7 @@ class XLAPrecisionPlugin(PrecisionPlugin): ) -> Any: import torch_xla.core.xla_model as xm - closure = partial(self._tpu_wrap_closure, optimizer, closure) + closure = partial(self._xla_wrap_closure, optimizer, closure) closure = partial(self._wrap_closure, model, optimizer, closure) closure_result = optimizer.step(closure=closure, **kwargs) xm.mark_step() @@ -59,3 +82,14 @@ class XLAPrecisionPlugin(PrecisionPlugin): " requesting this feature." ) return closure_result + + def teardown(self) -> None: + os.environ.pop("XLA_USE_BF16", None) + os.environ.pop("XLA_USE_F16", None) + + def _xla_wrap_closure(self, optimizer: Optimizable, closure: Callable[[], Any]) -> Any: + import torch_xla.core.xla_model as xm + + closure_result = closure() + xm.reduce_gradients(optimizer) + return closure_result diff --git a/src/lightning/pytorch/plugins/precision/xlabf16.py b/src/lightning/pytorch/plugins/precision/xlabf16.py deleted file mode 100644 index e9c0bdd1fe..0000000000 --- a/src/lightning/pytorch/plugins/precision/xlabf16.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright The Lightning AI team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os -from typing import Any, List, Literal, Tuple - -import torch.nn as nn -from torch.optim import Optimizer - -from lightning.pytorch.plugins.precision import XLAPrecisionPlugin - - -class XLABf16PrecisionPlugin(XLAPrecisionPlugin): - """Plugin that enables mixed bf16 with XLA.""" - - precision: Literal["bf16-mixed"] = "bf16-mixed" - - def connect( - self, model: nn.Module, optimizers: List[Optimizer], lr_schedulers: List[Any] - ) -> Tuple[nn.Module, List[Optimizer], List[Any]]: - os.environ["XLA_USE_BF16"] = "1" - return super().connect(model=model, optimizers=optimizers, lr_schedulers=lr_schedulers) - - def teardown(self) -> None: - os.environ.pop("XLA_USE_BF16", None) diff --git a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py index a67a709c7f..90edc09d3b 100644 --- a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py +++ b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py @@ -44,7 +44,6 @@ from lightning.pytorch.plugins import ( MixedPrecisionPlugin, PLUGIN_INPUT, PrecisionPlugin, - XLABf16PrecisionPlugin, XLAPrecisionPlugin, ) from lightning.pytorch.plugins.layer_sync import LayerSync, TorchSyncBatchNorm @@ -507,16 +506,6 @@ class _AcceleratorConnector: if isinstance(self.accelerator, HPUAccelerator): return HPUPrecisionPlugin(self._precision_flag) - if isinstance(self.accelerator, XLAAccelerator): - if self._precision_flag == "32-true": - return XLAPrecisionPlugin() - if self._precision_flag in ("16-mixed", "bf16-mixed"): - if self._precision_flag == "16-mixed": - rank_zero_warn( - "You passed `Trainer(accelerator='tpu', precision='16-mixed')` but AMP with fp16" - " is not supported on TPUs. Using `precision='bf16-mixed'` instead." - ) - return XLABf16PrecisionPlugin() if _LIGHTNING_COLOSSALAI_AVAILABLE: from lightning_colossalai import ColossalAIPrecisionPlugin, ColossalAIStrategy @@ -524,6 +513,8 @@ class _AcceleratorConnector: if isinstance(self.strategy, ColossalAIStrategy): return ColossalAIPrecisionPlugin(self._precision_flag) + if isinstance(self.accelerator, XLAAccelerator): + return XLAPrecisionPlugin(self._precision_flag) # type: ignore if isinstance(self.strategy, DeepSpeedStrategy): return DeepSpeedPrecisionPlugin(self._precision_flag) # type: ignore[arg-type] if isinstance(self.strategy, FSDPStrategy): @@ -553,20 +544,15 @@ class _AcceleratorConnector: def _validate_precision_choice(self) -> None: """Validate the combination of choices for precision, AMP type, and accelerator.""" - if isinstance(self.accelerator, XLAAccelerator): - if self._precision_flag == "64-true": - raise MisconfigurationException( - "`Trainer(accelerator='tpu', precision='64-true')` is not implemented." - " Please, open an issue in `https://github.com/Lightning-AI/lightning/issues`" - " requesting this feature." 
- ) - if self._precision_plugin_flag and not isinstance( - self._precision_plugin_flag, (XLAPrecisionPlugin, XLABf16PrecisionPlugin) - ): - raise ValueError( - f"The `XLAAccelerator` can only be used with a `XLAPrecisionPlugin`," - f" found: {self._precision_plugin_flag}." - ) + if ( + isinstance(self.accelerator, XLAAccelerator) + and self._precision_plugin_flag + and not isinstance(self._precision_plugin_flag, XLAPrecisionPlugin) + ): + raise ValueError( + f"The `XLAAccelerator` can only be used with a `XLAPrecisionPlugin`," + f" found: {self._precision_plugin_flag}." + ) if _lightning_habana_available(): from lightning_habana import HPUAccelerator diff --git a/tests/tests_fabric/graveyard/test_tpu.py b/tests/tests_fabric/graveyard/test_tpu.py index 1c65f8604f..5b72d60491 100644 --- a/tests/tests_fabric/graveyard/test_tpu.py +++ b/tests/tests_fabric/graveyard/test_tpu.py @@ -30,6 +30,8 @@ def test_graveyard_single_tpu(import_path, name): ("lightning.fabric.plugins", "TPUBf16Precision"), ("lightning.fabric.plugins.precision", "TPUBf16Precision"), ("lightning.fabric.plugins.precision.tpu_bf16", "TPUBf16Precision"), + ("lightning.fabric.plugins.precision", "XLABf16Precision"), + ("lightning.fabric.plugins.precision.xlabf16", "XLABf16Precision"), ], ) def test_graveyard_no_device(import_path, name): diff --git a/tests/tests_fabric/plugins/precision/test_xla.py b/tests/tests_fabric/plugins/precision/test_xla.py new file mode 100644 index 0000000000..cfdc32112a --- /dev/null +++ b/tests/tests_fabric/plugins/precision/test_xla.py @@ -0,0 +1,63 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os
+import re
+from unittest import mock
+
+import pytest
+import torch
+
+from lightning.fabric.plugins import XLAPrecision
+
+
+@mock.patch.dict(os.environ, {}, clear=True)
+def test_precision_input_validation(xla_available):
+    XLAPrecision(precision="32-true")
+    XLAPrecision(precision="16-true")
+    XLAPrecision(precision="bf16-true")
+
+    with pytest.raises(ValueError, match=re.escape("`precision='16')` is not supported in XLA")):
+        XLAPrecision("16")
+    with pytest.raises(ValueError, match=re.escape("`precision='16-mixed')` is not supported in XLA")):
+        XLAPrecision("16-mixed")
+    with pytest.raises(ValueError, match=re.escape("`precision='bf16-mixed')` is not supported in XLA")):
+        XLAPrecision("bf16-mixed")
+    with pytest.raises(ValueError, match=re.escape("`precision='64-true')` is not supported in XLA")):
+        XLAPrecision("64-true")
+
+
+@pytest.mark.parametrize(
+    ("precision", "expected_dtype"),
+    [
+        ("bf16-true", torch.bfloat16),
+        ("16-true", torch.half),
+    ],
+)
+@mock.patch.dict(os.environ, {}, clear=True)
+def test_selected_dtype(precision, expected_dtype, xla_available):
+    plugin = XLAPrecision(precision=precision)
+    assert plugin.precision == precision
+    assert plugin._desired_dtype == expected_dtype
+
+
+def test_teardown(xla_available):
+    plugin = XLAPrecision(precision="16-true")
+    assert os.environ["XLA_USE_F16"] == "1"
+    plugin.teardown()
+    assert "XLA_USE_F16" not in os.environ
+
+    plugin = XLAPrecision(precision="bf16-true")
+    assert os.environ["XLA_USE_BF16"] == "1"
+    plugin.teardown()
+    assert "XLA_USE_BF16" not in os.environ
diff --git a/tests/tests_fabric/plugins/precision/test_xlabf16.py b/tests/tests_fabric/plugins/precision/test_xlabf16.py
deleted file mode 100644
index 0b3362471d..0000000000
--- a/tests/tests_fabric/plugins/precision/test_xlabf16.py
+++ /dev/null
@@ -1,23 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import os - -from lightning.fabric.plugins import XLABf16Precision - - -def test_teardown(xla_available): - plugin = XLABf16Precision() - assert os.environ.get("XLA_USE_BF16") == "1" - plugin.teardown() - assert "XLA_USE_BF16" not in os.environ diff --git a/tests/tests_fabric/test_connector.py b/tests/tests_fabric/test_connector.py index f6a333897e..184df7b710 100644 --- a/tests/tests_fabric/test_connector.py +++ b/tests/tests_fabric/test_connector.py @@ -655,16 +655,9 @@ def test_strategy_choice_ddp_cpu_slurm(strategy): @mock.patch.dict(os.environ, {}, clear=True) @mock.patch("lightning.fabric.accelerators.mps.MPSAccelerator.is_available", return_value=False) def test_unsupported_tpu_choice(_, tpu_available): - with pytest.raises(NotImplementedError, match=r"accelerator='tpu', precision='64-true'\)` is not implemented"): - _Connector(accelerator="tpu", precision="64-true") - - # if user didn't set strategy, _Connector will choose the TPUSingleStrategy or XLAStrategy - with pytest.raises( - ValueError, match="XLAAccelerator` can only be used with a `SingleDeviceXLAStrategy`" - ), pytest.warns( - UserWarning, match=r"accelerator='tpu', precision='16-mixed'\)` but AMP with fp16 is not supported" - ): - _Connector(accelerator="tpu", precision="16-mixed", strategy="ddp") + # if user didn't set strategy, _Connector will choose the SingleDeviceXLAStrategy or XLAStrategy + with pytest.raises(ValueError, match="XLAAccelerator` can only be used with a `SingleDeviceXLAStrategy`"): + _Connector(accelerator="tpu", precision="16-true", strategy="ddp") # wrong precision plugin type strategy = XLAStrategy(accelerator=XLAAccelerator(), precision=Precision()) @@ -672,7 +665,7 @@ def test_unsupported_tpu_choice(_, tpu_available): _Connector(strategy=strategy) # wrong strategy type - strategy = DDPStrategy(accelerator=XLAAccelerator(), precision=XLAPrecision()) + strategy = DDPStrategy(accelerator=XLAAccelerator(), precision=XLAPrecision(precision="16-true")) with pytest.raises(ValueError, match="XLAAccelerator` can only be used with a `SingleDeviceXLAStrategy`"): _Connector(strategy=strategy) @@ -1043,6 +1036,7 @@ def test_connector_auto_selection(monkeypatch, is_interactive): assert connector.strategy.launcher.is_interactive_compatible +@mock.patch("lightning.fabric.accelerators.mps.MPSAccelerator.is_available", return_value=False) def test_xla_fsdp_automatic_strategy_selection(monkeypatch, tpu_available): import lightning.fabric.strategies as strategies diff --git a/tests/tests_pytorch/graveyard/test_tpu.py b/tests/tests_pytorch/graveyard/test_tpu.py index 8ee5df717e..0e010ca53d 100644 --- a/tests/tests_pytorch/graveyard/test_tpu.py +++ b/tests/tests_pytorch/graveyard/test_tpu.py @@ -30,6 +30,8 @@ def test_graveyard_single_tpu(import_path, name): ("lightning.pytorch.plugins", "TPUBf16PrecisionPlugin"), ("lightning.pytorch.plugins.precision", "TPUBf16PrecisionPlugin"), ("lightning.pytorch.plugins.precision.tpu_bf16", "TPUBf16PrecisionPlugin"), + ("lightning.pytorch.plugins.precision", "XLABf16PrecisionPlugin"), + ("lightning.pytorch.plugins.precision.xlabf16", "XLABf16PrecisionPlugin"), ], ) def test_graveyard_no_device(import_path, name): diff --git a/tests/tests_pytorch/models/test_tpu.py b/tests/tests_pytorch/models/test_tpu.py index 13c7bf923a..96bd3eef5c 100644 --- a/tests/tests_pytorch/models/test_tpu.py +++ b/tests/tests_pytorch/models/test_tpu.py @@ -102,7 +102,7 @@ def test_model_multiple_tpu_devices(tmpdir): def test_model_16bit_tpu_devices_1(tmpdir): trainer_options = { "default_root_dir": 
tmpdir,
-        "precision": "16-mixed",
+        "precision": "16-true",
         "enable_progress_bar": False,
         "max_epochs": 2,
         "accelerator": "tpu",
@@ -121,7 +121,7 @@ def test_model_16bit_tpu_index(tmpdir, tpu_core):
     trainer_options = {
         "default_root_dir": tmpdir,
-        "precision": "16-mixed",
+        "precision": "16-true",
         "enable_progress_bar": False,
         "max_epochs": 2,
         "accelerator": "tpu",
@@ -143,7 +143,7 @@ def test_model_16bit_multiple_tpu_devices(tmpdir):
     trainer_options = {
         "default_root_dir": tmpdir,
-        "precision": "16-mixed",
+        "precision": "16-true",
         "enable_progress_bar": False,
         "max_epochs": 1,
         "accelerator": "tpu",
diff --git a/tests/tests_pytorch/plugins/precision/test_xla.py b/tests/tests_pytorch/plugins/precision/test_xla.py
index 110bcb8c62..e99a12ca70 100644
--- a/tests/tests_pytorch/plugins/precision/test_xla.py
+++ b/tests/tests_pytorch/plugins/precision/test_xla.py
@@ -11,18 +11,66 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
+import re
 from unittest import mock
 from unittest.mock import Mock
 
+import pytest
+import torch
+
 from lightning.pytorch.plugins import XLAPrecisionPlugin
 from tests_pytorch.helpers.runif import RunIf
 
 
 @RunIf(tpu=True)
+@mock.patch.dict(os.environ, {}, clear=True)
 def test_optimizer_step_calls_mark_step():
-    plugin = XLAPrecisionPlugin()
+    plugin = XLAPrecisionPlugin(precision="32-true")
     optimizer = Mock()
     with mock.patch("torch_xla.core.xla_model") as xm_mock:
         plugin.optimizer_step(optimizer=optimizer, model=Mock(), closure=Mock())
     optimizer.step.assert_called_once()
     xm_mock.mark_step.assert_called_once()
+
+
+@mock.patch.dict(os.environ, {}, clear=True)
+def test_precision_input_validation(xla_available):
+    XLAPrecisionPlugin(precision="32-true")
+    XLAPrecisionPlugin(precision="16-true")
+    XLAPrecisionPlugin(precision="bf16-true")
+
+    with pytest.raises(ValueError, match=re.escape("`precision='16')` is not supported in XLA")):
+        XLAPrecisionPlugin("16")
+    with pytest.raises(ValueError, match=re.escape("`precision='16-mixed')` is not supported in XLA")):
+        XLAPrecisionPlugin("16-mixed")
+    with pytest.raises(ValueError, match=re.escape("`precision='bf16-mixed')` is not supported in XLA")):
+        XLAPrecisionPlugin("bf16-mixed")
+    with pytest.raises(ValueError, match=re.escape("`precision='64-true')` is not supported in XLA")):
+        XLAPrecisionPlugin("64-true")
+
+
+@pytest.mark.parametrize(
+    ("precision", "expected_dtype"),
+    [
+        ("bf16-true", torch.bfloat16),
+        ("16-true", torch.half),
+    ],
+)
+@mock.patch.dict(os.environ, {}, clear=True)
+def test_selected_dtype(precision, expected_dtype, xla_available):
+    plugin = XLAPrecisionPlugin(precision=precision)
+    assert plugin.precision == precision
+    assert plugin._desired_dtype == expected_dtype
+
+
+def test_teardown(xla_available):
+    plugin = XLAPrecisionPlugin(precision="16-true")
+    assert os.environ["XLA_USE_F16"] == "1"
+    plugin.teardown()
+    assert "XLA_USE_F16" not in os.environ
+
+    plugin = XLAPrecisionPlugin(precision="bf16-true")
+    assert os.environ["XLA_USE_BF16"] == "1"
+    plugin.teardown()
+    assert "XLA_USE_BF16" not in os.environ
diff --git a/tests/tests_pytorch/plugins/precision/test_xlabf16.py b/tests/tests_pytorch/plugins/precision/test_xlabf16.py
deleted file mode 100644
index bfc89d6758..0000000000
--- a/tests/tests_pytorch/plugins/precision/test_xlabf16.py
+++ /dev/null
@@ 
-1,25 +0,0 @@ -# Copyright The Lightning AI team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os -from unittest.mock import Mock - -from lightning.pytorch.plugins import XLABf16PrecisionPlugin - - -def test_teardown(xla_available): - plugin = XLABf16PrecisionPlugin() - plugin.connect(Mock(), Mock(), Mock()) - assert os.environ.get("XLA_USE_BF16") == "1" - plugin.teardown() - assert "XLA_USE_BF16" not in os.environ diff --git a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py index 2821f70275..48c3cb3571 100644 --- a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py @@ -549,17 +549,11 @@ def test_check_fsdp_strategy_and_fallback(): Trainer(accelerator="cpu", strategy="fsdp") -def test_unsupported_tpu_choice(tpu_available): - with pytest.raises( - MisconfigurationException, match=r"accelerator='tpu', precision='64-true'\)` is not implemented" - ): - Trainer(accelerator="tpu", precision="64-true") - - # if user didn't set strategy, AcceleratorConnector will choose "single_xla" or "xla" - with pytest.raises( - ValueError, match="XLAAccelerator` can only be used with a `SingleDeviceXLAStrategy`" - ), pytest.warns(UserWarning, match=r"accelerator='tpu', precision=16-mixed\)` but AMP with fp16 is not supported"): - Trainer(accelerator="tpu", precision="16-mixed", strategy="ddp") +@mock.patch.dict(os.environ, {}, clear=True) +def test_unsupported_tpu_choice(xla_available, tpu_available): + # if user didn't set strategy, _Connector will choose the SingleDeviceXLAStrategy or XLAStrategy + with pytest.raises(ValueError, match="XLAAccelerator` can only be used with a `SingleDeviceXLAStrategy`"): + Trainer(accelerator="tpu", precision="16-true", strategy="ddp") def mock_ipu_available(monkeypatch, value=True):
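A usage sketch of the reworked plugin (illustrative only, not part of the diff; assumes a TPU host with `torch_xla` installed, and the tiny model and optimizer below are placeholders):

    import torch
    from lightning.fabric import Fabric
    from lightning.fabric.plugins import XLAPrecision

    # The plugin drives XLA's dtype handling through environment variables:
    # "16-true" sets XLA_USE_F16=1 (float16), "bf16-true" sets XLA_USE_BF16=1 (bfloat16),
    # and teardown() removes them again.
    precision = XLAPrecision(precision="bf16-true")

    # Equivalent to Fabric(accelerator="tpu", precision="bf16-true"); after this patch the
    # connector only verifies that an XLAAccelerator is paired with an XLAPrecision instance.
    fabric = Fabric(accelerator="tpu", devices=1, plugins=precision)
    fabric.launch()

    model = torch.nn.Linear(32, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    model, optimizer = fabric.setup(model, optimizer)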