[TPU] Proper half-precision implementation for XLA (#18213)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Adrian Wälchli authored on 2023-08-11 17:37:41 +02:00, committed by GitHub
parent b88b8b3937
commit 7fe8756917
23 changed files with 288 additions and 232 deletions

View File

@ -116,7 +116,7 @@ By default, TPU training will use 32-bit precision. To enable it, do
import lightning.pytorch as pl
my_model = MyLightningModule()
trainer = pl.Trainer(accelerator="tpu", precision="16-mixed")
trainer = pl.Trainer(accelerator="tpu", precision="16-true")
trainer.fit(my_model)
Under the hood the xla library will use the `bfloat16 type <https://en.wikipedia.org/wiki/Bfloat16_floating-point_format>`_.
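For comparison, the same true half-precision options are available through Fabric; a minimal sketch, assuming a TPU runtime with torch_xla installed:

from lightning.fabric import Fabric

# "bf16-true" runs in bfloat16 (XLA_USE_BF16); "16-true" runs in float16 (XLA_USE_F16)
fabric = Fabric(accelerator="tpu", precision="bf16-true")
fabric.launch()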

View File

@ -16,7 +16,7 @@ from typing import Any
import lightning.fabric as fabric
from lightning.fabric.accelerators import XLAAccelerator
from lightning.fabric.plugins.precision import XLABf16Precision, XLAPrecision
from lightning.fabric.plugins.precision import XLAPrecision
from lightning.fabric.strategies import _StrategyRegistry
from lightning.fabric.strategies.single_xla import SingleDeviceXLAStrategy
from lightning.fabric.utilities.rank_zero import rank_zero_deprecation
@ -28,6 +28,7 @@ def _patch_sys_modules() -> None:
sys.modules["lightning.fabric.accelerators.tpu"] = self
sys.modules["lightning.fabric.plugins.precision.tpu"] = self
sys.modules["lightning.fabric.plugins.precision.tpu_bf16"] = self
sys.modules["lightning.fabric.plugins.precision.xlabf16"] = self
class SingleTPUStrategy(SingleDeviceXLAStrategy):
@ -72,20 +73,35 @@ class TPUPrecision(XLAPrecision):
rank_zero_deprecation(
"The `TPUPrecision` class is deprecated. Use `lightning.fabric.plugins.precision.XLAPrecision`" " instead."
)
super().__init__(*args, **kwargs)
super().__init__(precision="32-true")
class XLABf16Precision(XLAPrecision):
"""Legacy class.
Use :class:`~lightning.fabric.plugins.precision.xla.XLAPrecision` instead.
"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
rank_zero_deprecation(
"The `XLABf16Precision` class is deprecated. Use"
" `lightning.fabric.plugins.precision.XLAPrecision` instead."
)
super().__init__(precision="bf16-true")
class TPUBf16Precision(XLABf16Precision):
"""Legacy class.
Use :class:`~lightning.fabric.plugins.precision.xlabf16.XLABf16Precision` instead.
Use :class:`~lightning.fabric.plugins.precision.xla.XLAPrecision` instead.
"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
rank_zero_deprecation(
"The `TPUBf16Precision` class is deprecated. Use"
" `lightning.fabric.plugins.precision.XLABf16Precision` instead."
" `lightning.fabric.plugins.precision.XLAPrecision` instead."
)
super().__init__(*args, **kwargs)
@ -97,6 +113,8 @@ def _patch_classes() -> None:
setattr(fabric.plugins.precision, "TPUPrecision", TPUPrecision)
setattr(fabric.plugins, "TPUBf16Precision", TPUBf16Precision)
setattr(fabric.plugins.precision, "TPUBf16Precision", TPUBf16Precision)
setattr(fabric.plugins, "XLABf16Precision", XLABf16Precision)
setattr(fabric.plugins.precision, "XLABf16Precision", XLABf16Precision)
_patch_sys_modules()
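The deprecated names stay importable but now simply forward to the unified plugin; a small sketch of the intended behavior, assuming torch_xla is installed so the availability check in `XLAPrecision` passes:

from lightning.fabric.plugins import TPUBf16Precision, XLAPrecision

# instantiating the legacy class emits a deprecation warning and
# delegates to XLAPrecision(precision="bf16-true")
plugin = TPUBf16Precision()
assert isinstance(plugin, XLAPrecision)
assert plugin.precision == "bf16-true"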

View File

@ -29,7 +29,6 @@ from lightning.fabric.plugins import (
HalfPrecision,
MixedPrecision,
Precision,
XLABf16Precision,
XLAPrecision,
)
from lightning.fabric.plugins.environments import (
@ -432,18 +431,8 @@ class _Connector:
self._validate_precision_choice()
if isinstance(self._precision_instance, Precision):
return self._precision_instance
if isinstance(self.accelerator, XLAAccelerator):
if self._precision_input == "32-true":
return XLAPrecision()
if self._precision_input in ("16-mixed", "bf16-mixed"):
if self._precision_input == "16-mixed":
rank_zero_warn(
"You passed `Fabric(accelerator='tpu', precision='16-mixed')` but AMP with fp16"
" is not supported with TPUs. Using `precision='bf16-mixed'` instead."
)
return XLABf16Precision()
return XLAPrecision(self._precision_input) # type: ignore
if isinstance(self.strategy, DeepSpeedStrategy):
return DeepSpeedPrecision(self._precision_input) # type: ignore
if isinstance(self.strategy, FSDPStrategy):
@ -477,18 +466,15 @@ class _Connector:
def _validate_precision_choice(self) -> None:
"""Validate the combination of choices for precision, and accelerator."""
if isinstance(self.accelerator, XLAAccelerator):
if self._precision_input == "64-true":
raise NotImplementedError(
"`Fabric(accelerator='tpu', precision='64-true')` is not implemented."
" Please, open an issue in `https://github.com/Lightning-AI/lightning/issues`"
" requesting this feature."
)
if self._precision_instance and not isinstance(self._precision_instance, (XLAPrecision, XLABf16Precision)):
raise ValueError(
f"The `XLAAccelerator` can only be used with a `XLAPrecision` plugin,"
f" found: {self._precision_instance}."
)
if (
isinstance(self.accelerator, XLAAccelerator)
and self._precision_instance
and not isinstance(self._precision_instance, XLAPrecision)
):
raise ValueError(
f"The `XLAAccelerator` can only be used with a `XLAPrecision` plugin,"
f" found: {self._precision_instance}."
)
def _lazy_init_strategy(self) -> None:
"""Lazily set missing attributes on the previously instantiated strategy."""

View File

@ -22,7 +22,6 @@ from lightning.fabric.plugins.precision.fsdp import FSDPPrecision
from lightning.fabric.plugins.precision.half import HalfPrecision
from lightning.fabric.plugins.precision.precision import Precision
from lightning.fabric.plugins.precision.xla import XLAPrecision
from lightning.fabric.plugins.precision.xlabf16 import XLABf16Precision
__all__ = [
"ClusterEnvironment",
@ -35,6 +34,5 @@ __all__ = [
"HalfPrecision",
"MixedPrecision",
"XLAPrecision",
"XLABf16Precision",
"FSDPPrecision",
]

View File

@ -19,7 +19,6 @@ from lightning.fabric.plugins.precision.half import HalfPrecision
from lightning.fabric.plugins.precision.precision import Precision
from lightning.fabric.plugins.precision.transformer_engine import TransformerEnginePrecision
from lightning.fabric.plugins.precision.xla import XLAPrecision
from lightning.fabric.plugins.precision.xlabf16 import XLABf16Precision
__all__ = [
"DeepSpeedPrecision",
@ -28,7 +27,6 @@ __all__ = [
"MixedPrecision",
"Precision",
"XLAPrecision",
"XLABf16Precision",
"FSDPPrecision",
"TransformerEnginePrecision",
]

View File

@ -11,20 +11,59 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import os
from typing import Any, Literal
import torch
from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor
from typing_extensions import get_args
from lightning.fabric.accelerators.xla import _XLA_AVAILABLE
from lightning.fabric.plugins.precision.precision import Precision
from lightning.fabric.plugins.precision.utils import _convert_fp_tensor
from lightning.fabric.utilities.types import Optimizable
_PRECISION_INPUT = Literal["32-true", "16-true", "bf16-true"]
class XLAPrecision(Precision):
"""Precision plugin with XLA."""
"""Plugin for training with XLA.
def __init__(self, *args: Any, **kwargs: Any) -> None:
Args:
precision: Full precision (32-true) or half precision (16-true, bf16-true).
Raises:
ValueError:
If unsupported ``precision`` is provided.
"""
def __init__(self, precision: _PRECISION_INPUT) -> None:
if not _XLA_AVAILABLE:
raise ModuleNotFoundError(str(_XLA_AVAILABLE))
super().__init__(*args, **kwargs)
supported_precision = get_args(_PRECISION_INPUT)
if precision not in supported_precision:
raise ValueError(
f"`precision={precision!r})` is not supported in XLA."
f" `precision` must be one of: {supported_precision}."
)
self.precision = precision
if precision == "16-true":
os.environ["XLA_USE_F16"] = "1"
self._desired_dtype = torch.float16
elif precision == "bf16-true":
os.environ["XLA_USE_BF16"] = "1"
self._desired_dtype = torch.bfloat16
else:
self._desired_dtype = torch.float32
def convert_input(self, data: Any) -> Any:
return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_dtype)
def convert_output(self, data: Any) -> Any:
return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())
def optimizer_step(
self,
@ -35,3 +74,7 @@ class XLAPrecision(Precision):
# you always want to `xm.mark_step()` after `optimizer.step` for better performance, so we set `barrier=True`
return xm.optimizer_step(optimizer, optimizer_args=kwargs, barrier=True)
def teardown(self) -> None:
os.environ.pop("XLA_USE_BF16", None)
os.environ.pop("XLA_USE_F16", None)

View File

@ -1,41 +0,0 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any, Literal
import torch
from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor
from lightning.fabric.plugins.precision import XLAPrecision
from lightning.fabric.plugins.precision.utils import _convert_fp_tensor
class XLABf16Precision(XLAPrecision):
"""Plugin that enables mixed bf16 with XLA."""
precision: Literal["bf16-mixed"] = "bf16-mixed"
def __init__(self) -> None:
super().__init__()
os.environ["XLA_USE_BF16"] = "1"
def convert_input(self, data: Any) -> Any:
return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.bfloat16)
def convert_output(self, data: Any) -> Any:
return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())
def teardown(self) -> None:
os.environ.pop("XLA_USE_BF16", None)

View File

@ -91,7 +91,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Allowed accessing rank information in the main process before processes are launched when using the `XLAStrategy` ([#18194](https://github.com/Lightning-AI/lightning/pull/18194))
- Added support for true half-precision training via `Trainer(precision="16-true"|"bf16-true")` ([#18193](https://github.com/Lightning-AI/lightning/pull/18193), [#18217](https://github.com/Lightning-AI/lightning/pull/18217), [#18219](https://github.com/Lightning-AI/lightning/pull/18219))
- Added support for true half-precision training via `Trainer(precision="16-true"|"bf16-true")` ([#18193](https://github.com/Lightning-AI/lightning/pull/18193), [#18217](https://github.com/Lightning-AI/lightning/pull/18217), [#18213](https://github.com/Lightning-AI/lightning/pull/18213), [#18219](https://github.com/Lightning-AI/lightning/pull/18219))
- Added automatic process cleanup to avoid zombie child processes and stalls when exceptions are raised ([#18218](https://github.com/Lightning-AI/lightning/pull/18218))

View File

@ -18,7 +18,7 @@ from typing import Any
import lightning.pytorch as pl
from lightning.fabric.strategies import _StrategyRegistry
from lightning.pytorch.accelerators.xla import XLAAccelerator
from lightning.pytorch.plugins.precision import XLABf16PrecisionPlugin, XLAPrecisionPlugin
from lightning.pytorch.plugins.precision import XLAPrecisionPlugin
from lightning.pytorch.strategies.single_xla import SingleDeviceXLAStrategy
from lightning.pytorch.utilities.rank_zero import rank_zero_deprecation
@ -29,6 +29,7 @@ def _patch_sys_modules() -> None:
sys.modules["lightning.pytorch.accelerators.tpu"] = self
sys.modules["lightning.pytorch.plugins.precision.tpu"] = self
sys.modules["lightning.pytorch.plugins.precision.tpu_bf16"] = self
sys.modules["lightning.pytorch.plugins.precision.xlabf16"] = self
class SingleTPUStrategy(SingleDeviceXLAStrategy):
@ -74,22 +75,37 @@ class TPUPrecisionPlugin(XLAPrecisionPlugin):
"The `TPUPrecisionPlugin` class is deprecated. Use `lightning.pytorch.plugins.precision.XLAPrecisionPlugin`"
" instead."
)
super().__init__(*args, **kwargs)
super().__init__(precision="32-true")
class TPUBf16PrecisionPlugin(XLABf16PrecisionPlugin):
class TPUBf16PrecisionPlugin(XLAPrecisionPlugin):
"""Legacy class.
Use :class:`~lightning.pytorch.plugins.precision.xlabf16.XLABf16PrecisionPlugin` instead.
Use :class:`~lightning.pytorch.plugins.precision.xla.XLAPrecisionPlugin` instead.
"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
rank_zero_deprecation(
"The `TPUBf16PrecisionPlugin` class is deprecated. Use"
" `lightning.pytorch.plugins.precision.XLABf16PrecisionPlugin` instead."
" `lightning.pytorch.plugins.precision.XLAPrecisionPlugin` instead."
)
super().__init__(*args, **kwargs)
super().__init__(precision="bf16-true")
class XLABf16PrecisionPlugin(XLAPrecisionPlugin):
"""Legacy class.
Use :class:`~lightning.pytorch.plugins.precision.xla.XLAPrecisionPlugin` instead.
"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
rank_zero_deprecation(
"The `XLABf16PrecisionPlugin` class is deprecated. Use"
" `lightning.pytorch.plugins.precision.XLAPrecisionPlugin` instead."
)
super().__init__(precision="bf16-true")
def _patch_classes() -> None:
@ -99,6 +115,8 @@ def _patch_classes() -> None:
setattr(pl.plugins.precision, "TPUPrecisionPlugin", TPUPrecisionPlugin)
setattr(pl.plugins, "TPUBf16PrecisionPlugin", TPUBf16PrecisionPlugin)
setattr(pl.plugins.precision, "TPUBf16PrecisionPlugin", TPUBf16PrecisionPlugin)
setattr(pl.plugins, "XLABf16PrecisionPlugin", XLABf16PrecisionPlugin)
setattr(pl.plugins.precision, "XLABf16PrecisionPlugin", XLABf16PrecisionPlugin)
_patch_sys_modules()

View File

@ -10,7 +10,6 @@ from lightning.pytorch.plugins.precision.fsdp import FSDPMixedPrecisionPlugin, F
from lightning.pytorch.plugins.precision.half import HalfPrecisionPlugin
from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin
from lightning.pytorch.plugins.precision.xla import XLAPrecisionPlugin
from lightning.pytorch.plugins.precision.xlabf16 import XLABf16PrecisionPlugin
PLUGIN = Union[PrecisionPlugin, ClusterEnvironment, CheckpointIO, LayerSync]
PLUGIN_INPUT = Union[PLUGIN, str]
@ -28,7 +27,6 @@ __all__ = [
"FSDPMixedPrecisionPlugin",
"FSDPPrecisionPlugin",
"XLAPrecisionPlugin",
"XLABf16PrecisionPlugin",
"LayerSync",
"TorchSyncBatchNorm",
]

View File

@ -18,7 +18,6 @@ from lightning.pytorch.plugins.precision.fsdp import FSDPMixedPrecisionPlugin, F
from lightning.pytorch.plugins.precision.half import HalfPrecisionPlugin
from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin
from lightning.pytorch.plugins.precision.xla import XLAPrecisionPlugin
from lightning.pytorch.plugins.precision.xlabf16 import XLABf16PrecisionPlugin
__all__ = [
"DeepSpeedPrecisionPlugin",
@ -29,5 +28,4 @@ __all__ = [
"MixedPrecisionPlugin",
"PrecisionPlugin",
"XLAPrecisionPlugin",
"XLABf16PrecisionPlugin",
]

View File

@ -11,30 +11,53 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from functools import partial
from typing import Any, Callable
import torch
from typing_extensions import get_args
import lightning.pytorch as pl
from lightning.fabric.accelerators.xla import _XLA_AVAILABLE
from lightning.fabric.plugins.precision.xla import _PRECISION_INPUT
from lightning.fabric.utilities.types import Optimizable
from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin
from lightning.pytorch.utilities.exceptions import MisconfigurationException
class XLAPrecisionPlugin(PrecisionPlugin):
"""Precision plugin with XLA."""
"""Plugin for training with XLA.
def __init__(self, *args: Any, **kwargs: Any) -> None:
Args:
precision: Full precision (32-true) or half precision (16-true, bf16-true).
Raises:
ValueError:
If unsupported ``precision`` is provided.
"""
def __init__(self, precision: _PRECISION_INPUT = "32-true") -> None:
if not _XLA_AVAILABLE:
raise ModuleNotFoundError(str(_XLA_AVAILABLE))
super().__init__(*args, **kwargs)
def _tpu_wrap_closure(self, optimizer: Optimizable, closure: Callable[[], Any]) -> Any:
import torch_xla.core.xla_model as xm
supported_precision = get_args(_PRECISION_INPUT)
if precision not in supported_precision:
raise ValueError(
f"`precision={precision!r})` is not supported in XLA."
f" `precision` must be one of: {supported_precision}."
)
self.precision = precision
closure_result = closure()
xm.reduce_gradients(optimizer)
return closure_result
if precision == "16-true":
os.environ["XLA_USE_F16"] = "1"
self._desired_dtype = torch.float16
elif precision == "bf16-true":
os.environ["XLA_USE_BF16"] = "1"
self._desired_dtype = torch.bfloat16
else:
self._desired_dtype = torch.float32
def optimizer_step( # type: ignore[override]
self,
@ -45,7 +68,7 @@ class XLAPrecisionPlugin(PrecisionPlugin):
) -> Any:
import torch_xla.core.xla_model as xm
closure = partial(self._tpu_wrap_closure, optimizer, closure)
closure = partial(self._xla_wrap_closure, optimizer, closure)
closure = partial(self._wrap_closure, model, optimizer, closure)
closure_result = optimizer.step(closure=closure, **kwargs)
xm.mark_step()
@ -59,3 +82,14 @@ class XLAPrecisionPlugin(PrecisionPlugin):
" requesting this feature."
)
return closure_result
def teardown(self) -> None:
os.environ.pop("XLA_USE_BF16", None)
os.environ.pop("XLA_USE_F16", None)
def _xla_wrap_closure(self, optimizer: Optimizable, closure: Callable[[], Any]) -> Any:
import torch_xla.core.xla_model as xm
closure_result = closure()
xm.reduce_gradients(optimizer)
return closure_result
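On the Trainer side, the plugin can be passed explicitly or selected through the precision flag; a usage sketch, assuming a TPU runtime:

import lightning.pytorch as pl
from lightning.pytorch.plugins import XLAPrecisionPlugin

# explicit plugin ...
trainer = pl.Trainer(accelerator="tpu", devices=1, plugins=XLAPrecisionPlugin(precision="16-true"))

# ... or the equivalent shorthand resolved by the accelerator connector
trainer = pl.Trainer(accelerator="tpu", devices=1, precision="16-true")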

View File

@ -1,35 +0,0 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any, List, Literal, Tuple
import torch.nn as nn
from torch.optim import Optimizer
from lightning.pytorch.plugins.precision import XLAPrecisionPlugin
class XLABf16PrecisionPlugin(XLAPrecisionPlugin):
"""Plugin that enables mixed bf16 with XLA."""
precision: Literal["bf16-mixed"] = "bf16-mixed"
def connect(
self, model: nn.Module, optimizers: List[Optimizer], lr_schedulers: List[Any]
) -> Tuple[nn.Module, List[Optimizer], List[Any]]:
os.environ["XLA_USE_BF16"] = "1"
return super().connect(model=model, optimizers=optimizers, lr_schedulers=lr_schedulers)
def teardown(self) -> None:
os.environ.pop("XLA_USE_BF16", None)

View File

@ -44,7 +44,6 @@ from lightning.pytorch.plugins import (
MixedPrecisionPlugin,
PLUGIN_INPUT,
PrecisionPlugin,
XLABf16PrecisionPlugin,
XLAPrecisionPlugin,
)
from lightning.pytorch.plugins.layer_sync import LayerSync, TorchSyncBatchNorm
@ -507,16 +506,6 @@ class _AcceleratorConnector:
if isinstance(self.accelerator, HPUAccelerator):
return HPUPrecisionPlugin(self._precision_flag)
if isinstance(self.accelerator, XLAAccelerator):
if self._precision_flag == "32-true":
return XLAPrecisionPlugin()
if self._precision_flag in ("16-mixed", "bf16-mixed"):
if self._precision_flag == "16-mixed":
rank_zero_warn(
"You passed `Trainer(accelerator='tpu', precision='16-mixed')` but AMP with fp16"
" is not supported on TPUs. Using `precision='bf16-mixed'` instead."
)
return XLABf16PrecisionPlugin()
if _LIGHTNING_COLOSSALAI_AVAILABLE:
from lightning_colossalai import ColossalAIPrecisionPlugin, ColossalAIStrategy
@ -524,6 +513,8 @@ class _AcceleratorConnector:
if isinstance(self.strategy, ColossalAIStrategy):
return ColossalAIPrecisionPlugin(self._precision_flag)
if isinstance(self.accelerator, XLAAccelerator):
return XLAPrecisionPlugin(self._precision_flag) # type: ignore
if isinstance(self.strategy, DeepSpeedStrategy):
return DeepSpeedPrecisionPlugin(self._precision_flag) # type: ignore[arg-type]
if isinstance(self.strategy, FSDPStrategy):
@ -553,20 +544,15 @@ class _AcceleratorConnector:
def _validate_precision_choice(self) -> None:
"""Validate the combination of choices for precision, AMP type, and accelerator."""
if isinstance(self.accelerator, XLAAccelerator):
if self._precision_flag == "64-true":
raise MisconfigurationException(
"`Trainer(accelerator='tpu', precision='64-true')` is not implemented."
" Please, open an issue in `https://github.com/Lightning-AI/lightning/issues`"
" requesting this feature."
)
if self._precision_plugin_flag and not isinstance(
self._precision_plugin_flag, (XLAPrecisionPlugin, XLABf16PrecisionPlugin)
):
raise ValueError(
f"The `XLAAccelerator` can only be used with a `XLAPrecisionPlugin`,"
f" found: {self._precision_plugin_flag}."
)
if (
isinstance(self.accelerator, XLAAccelerator)
and self._precision_plugin_flag
and not isinstance(self._precision_plugin_flag, XLAPrecisionPlugin)
):
raise ValueError(
f"The `XLAAccelerator` can only be used with a `XLAPrecisionPlugin`,"
f" found: {self._precision_plugin_flag}."
)
if _lightning_habana_available():
from lightning_habana import HPUAccelerator

View File

@ -30,6 +30,8 @@ def test_graveyard_single_tpu(import_path, name):
("lightning.fabric.plugins", "TPUBf16Precision"),
("lightning.fabric.plugins.precision", "TPUBf16Precision"),
("lightning.fabric.plugins.precision.tpu_bf16", "TPUBf16Precision"),
("lightning.fabric.plugins.precision", "XLABf16Precision"),
("lightning.fabric.plugins.precision.xlabf16", "XLABf16Precision"),
],
)
def test_graveyard_no_device(import_path, name):

View File

@ -0,0 +1,63 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
from unittest import mock
import pytest
import torch
from lightning.fabric.plugins import XLAPrecision
@mock.patch.dict(os.environ, {}, clear=True)
def test_precision_input_validation(xla_available):
XLAPrecision(precision="32-true")
XLAPrecision(precision="16-true")
XLAPrecision(precision="bf16-true")
with pytest.raises(ValueError, match=re.escape("`precision='16')` is not supported in XLA")):
XLAPrecision("16")
with pytest.raises(ValueError, match=re.escape("`precision='16-mixed')` is not supported in XLA")):
XLAPrecision("16-mixed")
with pytest.raises(ValueError, match=re.escape("`precision='bf16-mixed')` is not supported in XLA")):
XLAPrecision("bf16-mixed")
with pytest.raises(ValueError, match=re.escape("`precision='64-true')` is not supported in XLA")):
XLAPrecision("64-true")
@pytest.mark.parametrize(
("precision", "expected_dtype"),
[
("bf16-true", torch.bfloat16),
("16-true", torch.half),
],
)
@mock.patch.dict(os.environ, {}, clear=True)
def test_selected_dtype(precision, expected_dtype, xla_available):
plugin = XLAPrecision(precision=precision)
assert plugin.precision == precision
assert plugin._desired_dtype == expected_dtype
def test_teardown(xla_available):
plugin = XLAPrecision(precision="16-true")
assert os.environ["XLA_USE_F16"] == "1"
plugin.teardown()
assert "XLA_USE_B16" not in os.environ
plugin = XLAPrecision(precision="bf16-true")
assert os.environ["XLA_USE_BF16"] == "1"
plugin.teardown()
assert "XLA_USE_BF16" not in os.environ

View File

@ -1,23 +0,0 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from lightning.fabric.plugins import XLABf16Precision
def test_teardown(xla_available):
plugin = XLABf16Precision()
assert os.environ.get("XLA_USE_BF16") == "1"
plugin.teardown()
assert "XLA_USE_BF16" not in os.environ

View File

@ -655,16 +655,9 @@ def test_strategy_choice_ddp_cpu_slurm(strategy):
@mock.patch.dict(os.environ, {}, clear=True)
@mock.patch("lightning.fabric.accelerators.mps.MPSAccelerator.is_available", return_value=False)
def test_unsupported_tpu_choice(_, tpu_available):
with pytest.raises(NotImplementedError, match=r"accelerator='tpu', precision='64-true'\)` is not implemented"):
_Connector(accelerator="tpu", precision="64-true")
# if user didn't set strategy, _Connector will choose the TPUSingleStrategy or XLAStrategy
with pytest.raises(
ValueError, match="XLAAccelerator` can only be used with a `SingleDeviceXLAStrategy`"
), pytest.warns(
UserWarning, match=r"accelerator='tpu', precision='16-mixed'\)` but AMP with fp16 is not supported"
):
_Connector(accelerator="tpu", precision="16-mixed", strategy="ddp")
# if user didn't set strategy, _Connector will choose the SingleDeviceXLAStrategy or XLAStrategy
with pytest.raises(ValueError, match="XLAAccelerator` can only be used with a `SingleDeviceXLAStrategy`"):
_Connector(accelerator="tpu", precision="16-true", strategy="ddp")
# wrong precision plugin type
strategy = XLAStrategy(accelerator=XLAAccelerator(), precision=Precision())
@ -672,7 +665,7 @@ def test_unsupported_tpu_choice(_, tpu_available):
_Connector(strategy=strategy)
# wrong strategy type
strategy = DDPStrategy(accelerator=XLAAccelerator(), precision=XLAPrecision())
strategy = DDPStrategy(accelerator=XLAAccelerator(), precision=XLAPrecision(precision="16-true"))
with pytest.raises(ValueError, match="XLAAccelerator` can only be used with a `SingleDeviceXLAStrategy`"):
_Connector(strategy=strategy)
@ -1043,6 +1036,7 @@ def test_connector_auto_selection(monkeypatch, is_interactive):
assert connector.strategy.launcher.is_interactive_compatible
@mock.patch("lightning.fabric.accelerators.mps.MPSAccelerator.is_available", return_value=False)
def test_xla_fsdp_automatic_strategy_selection(monkeypatch, tpu_available):
import lightning.fabric.strategies as strategies

View File

@ -30,6 +30,8 @@ def test_graveyard_single_tpu(import_path, name):
("lightning.pytorch.plugins", "TPUBf16PrecisionPlugin"),
("lightning.pytorch.plugins.precision", "TPUBf16PrecisionPlugin"),
("lightning.pytorch.plugins.precision.tpu_bf16", "TPUBf16PrecisionPlugin"),
("lightning.pytorch.plugins.precision", "XLABf16PrecisionPlugin"),
("lightning.pytorch.plugins.precision.xlabf16", "XLABf16PrecisionPlugin"),
],
)
def test_graveyard_no_device(import_path, name):

View File

@ -102,7 +102,7 @@ def test_model_multiple_tpu_devices(tmpdir):
def test_model_16bit_tpu_devices_1(tmpdir):
trainer_options = {
"default_root_dir": tmpdir,
"precision": "16-mixed",
"precision": "16-true",
"enable_progress_bar": False,
"max_epochs": 2,
"accelerator": "tpu",
@ -121,7 +121,7 @@ def test_model_16bit_tpu_devices_1(tmpdir):
def test_model_16bit_tpu_index(tmpdir, tpu_core):
trainer_options = {
"default_root_dir": tmpdir,
"precision": "16-mixed",
"precision": "16-true",
"enable_progress_bar": False,
"max_epochs": 2,
"accelerator": "tpu",
@ -143,7 +143,7 @@ def test_model_16bit_tpu_index(tmpdir, tpu_core):
def test_model_16bit_multiple_tpu_devices(tmpdir):
trainer_options = {
"default_root_dir": tmpdir,
"precision": "16-mixed",
"precision": "16-true",
"enable_progress_bar": False,
"max_epochs": 1,
"accelerator": "tpu",

View File

@ -11,18 +11,66 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
from unittest import mock
from unittest.mock import Mock
import pytest
import torch
from lightning.pytorch.plugins import XLAPrecisionPlugin
from tests_pytorch.helpers.runif import RunIf
@RunIf(tpu=True)
@mock.patch.dict(os.environ, {}, clear=True)
def test_optimizer_step_calls_mark_step():
plugin = XLAPrecisionPlugin()
plugin = XLAPrecisionPlugin(precision="32-true")
optimizer = Mock()
with mock.patch("torch_xla.core.xla_model") as xm_mock:
plugin.optimizer_step(optimizer=optimizer, model=Mock(), closure=Mock())
optimizer.step.assert_called_once()
xm_mock.mark_step.assert_called_once()
@mock.patch.dict(os.environ, {}, clear=True)
def test_precision_input_validation(xla_available):
XLAPrecisionPlugin(precision="32-true")
XLAPrecisionPlugin(precision="16-true")
XLAPrecisionPlugin(precision="bf16-true")
with pytest.raises(ValueError, match=re.escape("`precision='16')` is not supported in XLA")):
XLAPrecisionPlugin("16")
with pytest.raises(ValueError, match=re.escape("`precision='16-mixed')` is not supported in XLA")):
XLAPrecisionPlugin("16-mixed")
with pytest.raises(ValueError, match=re.escape("`precision='bf16-mixed')` is not supported in XLA")):
XLAPrecisionPlugin("bf16-mixed")
with pytest.raises(ValueError, match=re.escape("`precision='64-true')` is not supported in XLA")):
XLAPrecisionPlugin("64-true")
@pytest.mark.parametrize(
("precision", "expected_dtype"),
[
("bf16-true", torch.bfloat16),
("16-true", torch.half),
],
)
@mock.patch.dict(os.environ, {}, clear=True)
def test_selected_dtype(precision, expected_dtype, xla_available):
plugin = XLAPrecisionPlugin(precision=precision)
assert plugin.precision == precision
assert plugin._desired_dtype == expected_dtype
def test_teardown(xla_available):
plugin = XLAPrecisionPlugin(precision="16-true")
assert os.environ["XLA_USE_F16"] == "1"
plugin.teardown()
assert "XLA_USE_B16" not in os.environ
plugin = XLAPrecisionPlugin(precision="bf16-true")
assert os.environ["XLA_USE_BF16"] == "1"
plugin.teardown()
assert "XLA_USE_BF16" not in os.environ

View File

@ -1,25 +0,0 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from unittest.mock import Mock
from lightning.pytorch.plugins import XLABf16PrecisionPlugin
def test_teardown(xla_available):
plugin = XLABf16PrecisionPlugin()
plugin.connect(Mock(), Mock(), Mock())
assert os.environ.get("XLA_USE_BF16") == "1"
plugin.teardown()
assert "XLA_USE_BF16" not in os.environ

View File

@ -549,17 +549,11 @@ def test_check_fsdp_strategy_and_fallback():
Trainer(accelerator="cpu", strategy="fsdp")
def test_unsupported_tpu_choice(tpu_available):
with pytest.raises(
MisconfigurationException, match=r"accelerator='tpu', precision='64-true'\)` is not implemented"
):
Trainer(accelerator="tpu", precision="64-true")
# if user didn't set strategy, AcceleratorConnector will choose "single_xla" or "xla"
with pytest.raises(
ValueError, match="XLAAccelerator` can only be used with a `SingleDeviceXLAStrategy`"
), pytest.warns(UserWarning, match=r"accelerator='tpu', precision=16-mixed\)` but AMP with fp16 is not supported"):
Trainer(accelerator="tpu", precision="16-mixed", strategy="ddp")
@mock.patch.dict(os.environ, {}, clear=True)
def test_unsupported_tpu_choice(xla_available, tpu_available):
# if user didn't set strategy, _Connector will choose the SingleDeviceXLAStrategy or XLAStrategy
with pytest.raises(ValueError, match="XLAAccelerator` can only be used with a `SingleDeviceXLAStrategy`"):
Trainer(accelerator="tpu", precision="16-true", strategy="ddp")
def mock_ipu_available(monkeypatch, value=True):