[TPU] Proper half-precision implementation for XLA (#18213)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent b88b8b3937
commit 7fe8756917
@@ -116,7 +116,7 @@ By default, TPU training will use 32-bit precision. To enable it, do
     import lightning.pytorch as pl

     my_model = MyLightningModule()
-    trainer = pl.Trainer(accelerator="tpu", precision="16-mixed")
+    trainer = pl.Trainer(accelerator="tpu", precision="16-true")
     trainer.fit(my_model)

 Under the hood the xla library will use the `bfloat16 type <https://en.wikipedia.org/wiki/Bfloat16_floating-point_format>`_.
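For comparison, a minimal sketch of the equivalent Fabric call (not part of the diff; it assumes a working `torch_xla` environment, and `"16-true"`/`"bf16-true"` select float16/bfloat16 respectively):

    from lightning.fabric import Fabric

    # true half precision on TPU with Fabric (illustrative only)
    fabric = Fabric(accelerator="tpu", precision="bf16-true")
    fabric.launch()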
@@ -16,7 +16,7 @@ from typing import Any

 import lightning.fabric as fabric
 from lightning.fabric.accelerators import XLAAccelerator
-from lightning.fabric.plugins.precision import XLABf16Precision, XLAPrecision
+from lightning.fabric.plugins.precision import XLAPrecision
 from lightning.fabric.strategies import _StrategyRegistry
 from lightning.fabric.strategies.single_xla import SingleDeviceXLAStrategy
 from lightning.fabric.utilities.rank_zero import rank_zero_deprecation

@@ -28,6 +28,7 @@ def _patch_sys_modules() -> None:
     sys.modules["lightning.fabric.accelerators.tpu"] = self
     sys.modules["lightning.fabric.plugins.precision.tpu"] = self
     sys.modules["lightning.fabric.plugins.precision.tpu_bf16"] = self
+    sys.modules["lightning.fabric.plugins.precision.xlabf16"] = self


 class SingleTPUStrategy(SingleDeviceXLAStrategy):

@@ -72,20 +73,35 @@ class TPUPrecision(XLAPrecision):
         rank_zero_deprecation(
             "The `TPUPrecision` class is deprecated. Use `lightning.fabric.plugins.precision.XLAPrecision`" " instead."
         )
-        super().__init__(*args, **kwargs)
+        super().__init__(precision="32-true")


+class XLABf16Precision(XLAPrecision):
+    """Legacy class.
+
+    Use :class:`~lightning.fabric.plugins.precision.xla.XLAPrecision` instead.
+
+    """
+
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        rank_zero_deprecation(
+            "The `XLABf16Precision` class is deprecated. Use"
+            " `lightning.fabric.plugins.precision.XLAPrecision` instead."
+        )
+        super().__init__(precision="bf16-true")
+
+
 class TPUBf16Precision(XLABf16Precision):
     """Legacy class.

-    Use :class:`~lightning.fabric.plugins.precision.xlabf16.XLABf16Precision` instead.
+    Use :class:`~lightning.fabric.plugins.precision.xla.XLAPrecision` instead.

     """

     def __init__(self, *args: Any, **kwargs: Any) -> None:
         rank_zero_deprecation(
             "The `TPUBf16Precision` class is deprecated. Use"
-            " `lightning.fabric.plugins.precision.XLABf16Precision` instead."
+            " `lightning.fabric.plugins.precision.XLAPrecision` instead."
         )
         super().__init__(*args, **kwargs)

@@ -97,6 +113,8 @@ def _patch_classes() -> None:
     setattr(fabric.plugins.precision, "TPUPrecision", TPUPrecision)
     setattr(fabric.plugins, "TPUBf16Precision", TPUBf16Precision)
     setattr(fabric.plugins.precision, "TPUBf16Precision", TPUBf16Precision)
+    setattr(fabric.plugins, "XLABf16Precision", XLABf16Precision)
+    setattr(fabric.plugins.precision, "XLABf16Precision", XLABf16Precision)


 _patch_sys_modules()
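A quick sketch of what the patched legacy aliases now do (illustrative only; instantiating them still requires `torch_xla`, since the unified plugin checks availability in `__init__`):

    from lightning.fabric.plugins.precision import TPUBf16Precision

    # the deprecated alias emits a deprecation message and simply forwards to the
    # unified XLAPrecision with a fixed "bf16-true" value
    plugin = TPUBf16Precision()
    assert plugin.precision == "bf16-true"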
@@ -29,7 +29,6 @@ from lightning.fabric.plugins import (
     HalfPrecision,
     MixedPrecision,
     Precision,
-    XLABf16Precision,
     XLAPrecision,
 )
 from lightning.fabric.plugins.environments import (

@@ -432,18 +431,8 @@ class _Connector:
         self._validate_precision_choice()
         if isinstance(self._precision_instance, Precision):
             return self._precision_instance

         if isinstance(self.accelerator, XLAAccelerator):
-            if self._precision_input == "32-true":
-                return XLAPrecision()
-            if self._precision_input in ("16-mixed", "bf16-mixed"):
-                if self._precision_input == "16-mixed":
-                    rank_zero_warn(
-                        "You passed `Fabric(accelerator='tpu', precision='16-mixed')` but AMP with fp16"
-                        " is not supported with TPUs. Using `precision='bf16-mixed'` instead."
-                    )
-                return XLABf16Precision()
-
+            return XLAPrecision(self._precision_input)  # type: ignore
         if isinstance(self.strategy, DeepSpeedStrategy):
             return DeepSpeedPrecision(self._precision_input)  # type: ignore
         if isinstance(self.strategy, FSDPStrategy):

@@ -477,18 +466,15 @@ class _Connector:

     def _validate_precision_choice(self) -> None:
         """Validate the combination of choices for precision, and accelerator."""
-        if isinstance(self.accelerator, XLAAccelerator):
-            if self._precision_input == "64-true":
-                raise NotImplementedError(
-                    "`Fabric(accelerator='tpu', precision='64-true')` is not implemented."
-                    " Please, open an issue in `https://github.com/Lightning-AI/lightning/issues`"
-                    " requesting this feature."
-                )
-            if self._precision_instance and not isinstance(self._precision_instance, (XLAPrecision, XLABf16Precision)):
-                raise ValueError(
-                    f"The `XLAAccelerator` can only be used with a `XLAPrecision` plugin,"
-                    f" found: {self._precision_instance}."
-                )
+        if (
+            isinstance(self.accelerator, XLAAccelerator)
+            and self._precision_instance
+            and not isinstance(self._precision_instance, XLAPrecision)
+        ):
+            raise ValueError(
+                f"The `XLAAccelerator` can only be used with a `XLAPrecision` plugin,"
+                f" found: {self._precision_instance}."
+            )

     def _lazy_init_strategy(self) -> None:
         """Lazily set missing attributes on the previously instantiated strategy."""
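The tightened validation can be exercised roughly the way the updated connector test further down does; a sketch assuming a TPU/XLA environment is available:

    import pytest

    from lightning.fabric.accelerators import XLAAccelerator
    from lightning.fabric.connector import _Connector
    from lightning.fabric.plugins import Precision
    from lightning.fabric.strategies import XLAStrategy

    # a non-XLA precision plugin combined with the XLA accelerator is rejected
    strategy = XLAStrategy(accelerator=XLAAccelerator(), precision=Precision())
    with pytest.raises(ValueError, match="`XLAPrecision` plugin"):
        _Connector(strategy=strategy)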
@@ -22,7 +22,6 @@ from lightning.fabric.plugins.precision.fsdp import FSDPPrecision
 from lightning.fabric.plugins.precision.half import HalfPrecision
 from lightning.fabric.plugins.precision.precision import Precision
 from lightning.fabric.plugins.precision.xla import XLAPrecision
-from lightning.fabric.plugins.precision.xlabf16 import XLABf16Precision

 __all__ = [
     "ClusterEnvironment",

@@ -35,6 +34,5 @@ __all__ = [
     "HalfPrecision",
     "MixedPrecision",
     "XLAPrecision",
-    "XLABf16Precision",
     "FSDPPrecision",
 ]
@@ -19,7 +19,6 @@ from lightning.fabric.plugins.precision.half import HalfPrecision
 from lightning.fabric.plugins.precision.precision import Precision
 from lightning.fabric.plugins.precision.transformer_engine import TransformerEnginePrecision
 from lightning.fabric.plugins.precision.xla import XLAPrecision
-from lightning.fabric.plugins.precision.xlabf16 import XLABf16Precision

 __all__ = [
     "DeepSpeedPrecision",

@@ -28,7 +27,6 @@ __all__ = [
     "MixedPrecision",
     "Precision",
     "XLAPrecision",
-    "XLABf16Precision",
     "FSDPPrecision",
     "TransformerEnginePrecision",
 ]
@@ -11,20 +11,59 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any
+import os
+from typing import Any, Literal
+
+import torch
+from lightning_utilities.core.apply_func import apply_to_collection
+from torch import Tensor
+from typing_extensions import get_args

 from lightning.fabric.accelerators.xla import _XLA_AVAILABLE
 from lightning.fabric.plugins.precision.precision import Precision
+from lightning.fabric.plugins.precision.utils import _convert_fp_tensor
 from lightning.fabric.utilities.types import Optimizable

+_PRECISION_INPUT = Literal["32-true", "16-true", "bf16-true"]
+

 class XLAPrecision(Precision):
-    """Precision plugin with XLA."""
+    """Plugin for training with XLA.

-    def __init__(self, *args: Any, **kwargs: Any) -> None:
+    Args:
+        precision: Full precision (32-true) or half precision (16-true, bf16-true).
+
+    Raises:
+        ValueError:
+            If unsupported ``precision`` is provided.
+
+    """
+
+    def __init__(self, precision: _PRECISION_INPUT) -> None:
         if not _XLA_AVAILABLE:
             raise ModuleNotFoundError(str(_XLA_AVAILABLE))
-        super().__init__(*args, **kwargs)
+        supported_precision = get_args(_PRECISION_INPUT)
+        if precision not in supported_precision:
+            raise ValueError(
+                f"`precision={precision!r})` is not supported in XLA."
+                f" `precision` must be one of: {supported_precision}."
+            )
+        self.precision = precision
+
+        if precision == "16-true":
+            os.environ["XLA_USE_F16"] = "1"
+            self._desired_dtype = torch.float16
+        elif precision == "bf16-true":
+            os.environ["XLA_USE_BF16"] = "1"
+            self._desired_dtype = torch.bfloat16
+        else:
+            self._desired_dtype = torch.float32
+
+    def convert_input(self, data: Any) -> Any:
+        return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_dtype)
+
+    def convert_output(self, data: Any) -> Any:
+        return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())
+
     def optimizer_step(
         self,

@@ -35,3 +74,7 @@ class XLAPrecision(Precision):

         # you always want to `xm.mark_step()` after `optimizer.step` for better performance, so we set `barrier=True`
         return xm.optimizer_step(optimizer, optimizer_args=kwargs, barrier=True)
+
+    def teardown(self) -> None:
+        os.environ.pop("XLA_USE_BF16", None)
+        os.environ.pop("XLA_USE_F16", None)
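The behaviour of the rewritten plugin in one short sketch (illustrative only; `__init__` raises `ModuleNotFoundError` unless `torch_xla` is installed, and the tensors are arbitrary examples):

    import torch

    from lightning.fabric.plugins.precision import XLAPrecision

    plugin = XLAPrecision(precision="bf16-true")          # also sets XLA_USE_BF16=1
    batch = {"x": torch.rand(2, 3), "y": torch.ones(2)}
    batch = plugin.convert_input(batch)                   # floating tensors -> torch.bfloat16
    assert batch["x"].dtype is torch.bfloat16
    outputs = plugin.convert_output(batch)                # back to torch.get_default_dtype()
    plugin.teardown()                                     # removes the XLA_USE_* variables again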
@@ -1,41 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import os
-from typing import Any, Literal
-
-import torch
-from lightning_utilities.core.apply_func import apply_to_collection
-from torch import Tensor
-
-from lightning.fabric.plugins.precision import XLAPrecision
-from lightning.fabric.plugins.precision.utils import _convert_fp_tensor
-
-
-class XLABf16Precision(XLAPrecision):
-    """Plugin that enables mixed bf16 with XLA."""
-
-    precision: Literal["bf16-mixed"] = "bf16-mixed"
-
-    def __init__(self) -> None:
-        super().__init__()
-        os.environ["XLA_USE_BF16"] = "1"
-
-    def convert_input(self, data: Any) -> Any:
-        return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.bfloat16)
-
-    def convert_output(self, data: Any) -> Any:
-        return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())
-
-    def teardown(self) -> None:
-        os.environ.pop("XLA_USE_BF16", None)
@@ -91,7 +91,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Allowed accessing rank information in the main process before processes are launched when using the `XLAStrategy` ([#18194](https://github.com/Lightning-AI/lightning/pull/18194))


-- Added support for true half-precision training via `Trainer(precision="16-true"|"bf16-true")` ([#18193](https://github.com/Lightning-AI/lightning/pull/18193), [#18217](https://github.com/Lightning-AI/lightning/pull/18217), [#18219](https://github.com/Lightning-AI/lightning/pull/18219))
+- Added support for true half-precision training via `Trainer(precision="16-true"|"bf16-true")` ([#18193](https://github.com/Lightning-AI/lightning/pull/18193), [#18217](https://github.com/Lightning-AI/lightning/pull/18217), [#18213](https://github.com/Lightning-AI/lightning/pull/18213), [#18219](https://github.com/Lightning-AI/lightning/pull/18219))


 - Added automatic process cleanup to avoid zombie child processes and stalls when exceptions are raised ([#18218](https://github.com/Lightning-AI/lightning/pull/18218))
@@ -18,7 +18,7 @@ from typing import Any
 import lightning.pytorch as pl
 from lightning.fabric.strategies import _StrategyRegistry
 from lightning.pytorch.accelerators.xla import XLAAccelerator
-from lightning.pytorch.plugins.precision import XLABf16PrecisionPlugin, XLAPrecisionPlugin
+from lightning.pytorch.plugins.precision import XLAPrecisionPlugin
 from lightning.pytorch.strategies.single_xla import SingleDeviceXLAStrategy
 from lightning.pytorch.utilities.rank_zero import rank_zero_deprecation

@@ -29,6 +29,7 @@ def _patch_sys_modules() -> None:
     sys.modules["lightning.pytorch.accelerators.tpu"] = self
     sys.modules["lightning.pytorch.plugins.precision.tpu"] = self
     sys.modules["lightning.pytorch.plugins.precision.tpu_bf16"] = self
+    sys.modules["lightning.pytorch.plugins.precision.xlabf16"] = self


 class SingleTPUStrategy(SingleDeviceXLAStrategy):

@@ -74,22 +75,37 @@ class TPUPrecisionPlugin(XLAPrecisionPlugin):
             "The `TPUPrecisionPlugin` class is deprecated. Use `lightning.pytorch.plugins.precision.XLAPrecisionPlugin`"
             " instead."
         )
-        super().__init__(*args, **kwargs)
+        super().__init__(precision="32-true")


-class TPUBf16PrecisionPlugin(XLABf16PrecisionPlugin):
+class TPUBf16PrecisionPlugin(XLAPrecisionPlugin):
     """Legacy class.

-    Use :class:`~lightning.pytorch.plugins.precision.xlabf16.XLABf16PrecisionPlugin` instead.
+    Use :class:`~lightning.pytorch.plugins.precision.xlabf16.XLAPrecisionPlugin` instead.

     """

     def __init__(self, *args: Any, **kwargs: Any) -> None:
         rank_zero_deprecation(
             "The `TPUBf16PrecisionPlugin` class is deprecated. Use"
-            " `lightning.pytorch.plugins.precision.XLABf16PrecisionPlugin` instead."
+            " `lightning.pytorch.plugins.precision.XLAPrecisionPlugin` instead."
         )
-        super().__init__(*args, **kwargs)
+        super().__init__(precision="bf16-true")
+
+
+class XLABf16PrecisionPlugin(XLAPrecisionPlugin):
+    """Legacy class.
+
+    Use :class:`~lightning.pytorch.plugins.precision.xlabf16.XLAPrecisionPlugin` instead.
+
+    """
+
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        rank_zero_deprecation(
+            "The `XLABf16PrecisionPlugin` class is deprecated. Use"
+            " `lightning.pytorch.plugins.precision.XLAPrecisionPlugin` instead."
+        )
+        super().__init__(precision="bf16-true")


 def _patch_classes() -> None:

@@ -99,6 +115,8 @@ def _patch_classes() -> None:
     setattr(pl.plugins.precision, "TPUPrecisionPlugin", TPUPrecisionPlugin)
     setattr(pl.plugins, "TPUBf16PrecisionPlugin", TPUBf16PrecisionPlugin)
     setattr(pl.plugins.precision, "TPUBf16PrecisionPlugin", TPUBf16PrecisionPlugin)
+    setattr(pl.plugins, "XLABf16PrecisionPlugin", XLABf16PrecisionPlugin)
+    setattr(pl.plugins.precision, "XLABf16PrecisionPlugin", XLABf16PrecisionPlugin)


 _patch_sys_modules()
@@ -10,7 +10,6 @@ from lightning.pytorch.plugins.precision.fsdp import FSDPMixedPrecisionPlugin, F
 from lightning.pytorch.plugins.precision.half import HalfPrecisionPlugin
 from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin
 from lightning.pytorch.plugins.precision.xla import XLAPrecisionPlugin
-from lightning.pytorch.plugins.precision.xlabf16 import XLABf16PrecisionPlugin

 PLUGIN = Union[PrecisionPlugin, ClusterEnvironment, CheckpointIO, LayerSync]
 PLUGIN_INPUT = Union[PLUGIN, str]

@@ -28,7 +27,6 @@ __all__ = [
     "FSDPMixedPrecisionPlugin",
     "FSDPPrecisionPlugin",
     "XLAPrecisionPlugin",
-    "XLABf16PrecisionPlugin",
     "LayerSync",
     "TorchSyncBatchNorm",
 ]
@@ -18,7 +18,6 @@ from lightning.pytorch.plugins.precision.fsdp import FSDPMixedPrecisionPlugin, F
 from lightning.pytorch.plugins.precision.half import HalfPrecisionPlugin
 from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin
 from lightning.pytorch.plugins.precision.xla import XLAPrecisionPlugin
-from lightning.pytorch.plugins.precision.xlabf16 import XLABf16PrecisionPlugin

 __all__ = [
     "DeepSpeedPrecisionPlugin",

@@ -29,5 +28,4 @@ __all__ = [
     "MixedPrecisionPlugin",
     "PrecisionPlugin",
     "XLAPrecisionPlugin",
-    "XLABf16PrecisionPlugin",
 ]
@@ -11,30 +11,53 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 from functools import partial
 from typing import Any, Callable

+import torch
+from typing_extensions import get_args
+
 import lightning.pytorch as pl
 from lightning.fabric.accelerators.xla import _XLA_AVAILABLE
+from lightning.fabric.plugins.precision.xla import _PRECISION_INPUT
 from lightning.fabric.utilities.types import Optimizable
 from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin
 from lightning.pytorch.utilities.exceptions import MisconfigurationException


 class XLAPrecisionPlugin(PrecisionPlugin):
-    """Precision plugin with XLA."""
+    """Plugin for training with XLA.

-    def __init__(self, *args: Any, **kwargs: Any) -> None:
+    Args:
+        precision: Full precision (32-true) or half precision (16-true, bf16-true).
+
+    Raises:
+        ValueError:
+            If unsupported ``precision`` is provided.
+
+    """
+
+    def __init__(self, precision: _PRECISION_INPUT = "32-true") -> None:
         if not _XLA_AVAILABLE:
             raise ModuleNotFoundError(str(_XLA_AVAILABLE))
-        super().__init__(*args, **kwargs)
-
-    def _tpu_wrap_closure(self, optimizer: Optimizable, closure: Callable[[], Any]) -> Any:
-        import torch_xla.core.xla_model as xm
-
-        closure_result = closure()
-        xm.reduce_gradients(optimizer)
-        return closure_result
+        supported_precision = get_args(_PRECISION_INPUT)
+        if precision not in supported_precision:
+            raise ValueError(
+                f"`precision={precision!r})` is not supported in XLA."
+                f" `precision` must be one of: {supported_precision}."
+            )
+        self.precision = precision
+
+        if precision == "16-true":
+            os.environ["XLA_USE_F16"] = "1"
+            self._desired_dtype = torch.float16
+        elif precision == "bf16-true":
+            os.environ["XLA_USE_BF16"] = "1"
+            self._desired_dtype = torch.bfloat16
+        else:
+            self._desired_dtype = torch.float32

     def optimizer_step(  # type: ignore[override]
         self,

@@ -45,7 +68,7 @@ class XLAPrecisionPlugin(PrecisionPlugin):
     ) -> Any:
         import torch_xla.core.xla_model as xm

-        closure = partial(self._tpu_wrap_closure, optimizer, closure)
+        closure = partial(self._xla_wrap_closure, optimizer, closure)
         closure = partial(self._wrap_closure, model, optimizer, closure)
         closure_result = optimizer.step(closure=closure, **kwargs)
         xm.mark_step()

@@ -59,3 +82,14 @@ class XLAPrecisionPlugin(PrecisionPlugin):
                 " requesting this feature."
             )
         return closure_result
+
+    def teardown(self) -> None:
+        os.environ.pop("XLA_USE_BF16", None)
+        os.environ.pop("XLA_USE_F16", None)
+
+    def _xla_wrap_closure(self, optimizer: Optimizable, closure: Callable[[], Any]) -> Any:
+        import torch_xla.core.xla_model as xm
+
+        closure_result = closure()
+        xm.reduce_gradients(optimizer)
+        return closure_result
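For reference, a hedged sketch of how the reworked plugin is reached from the Trainer after this change; both spellings should end up with the same `XLAPrecisionPlugin` configuration (TPU runtime assumed):

    import lightning.pytorch as pl
    from lightning.pytorch.plugins.precision import XLAPrecisionPlugin

    # via the precision flag (the connector builds XLAPrecisionPlugin("bf16-true") internally) ...
    trainer = pl.Trainer(accelerator="tpu", devices=1, precision="bf16-true")
    # ... or by passing the plugin explicitly
    trainer = pl.Trainer(accelerator="tpu", devices=1, plugins=XLAPrecisionPlugin(precision="bf16-true"))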
@@ -1,35 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import os
-from typing import Any, List, Literal, Tuple
-
-import torch.nn as nn
-from torch.optim import Optimizer
-
-from lightning.pytorch.plugins.precision import XLAPrecisionPlugin
-
-
-class XLABf16PrecisionPlugin(XLAPrecisionPlugin):
-    """Plugin that enables mixed bf16 with XLA."""
-
-    precision: Literal["bf16-mixed"] = "bf16-mixed"
-
-    def connect(
-        self, model: nn.Module, optimizers: List[Optimizer], lr_schedulers: List[Any]
-    ) -> Tuple[nn.Module, List[Optimizer], List[Any]]:
-        os.environ["XLA_USE_BF16"] = "1"
-        return super().connect(model=model, optimizers=optimizers, lr_schedulers=lr_schedulers)
-
-    def teardown(self) -> None:
-        os.environ.pop("XLA_USE_BF16", None)
@@ -44,7 +44,6 @@ from lightning.pytorch.plugins import (
     MixedPrecisionPlugin,
     PLUGIN_INPUT,
     PrecisionPlugin,
-    XLABf16PrecisionPlugin,
     XLAPrecisionPlugin,
 )
 from lightning.pytorch.plugins.layer_sync import LayerSync, TorchSyncBatchNorm

@@ -507,16 +506,6 @@ class _AcceleratorConnector:

         if isinstance(self.accelerator, HPUAccelerator):
             return HPUPrecisionPlugin(self._precision_flag)
-        if isinstance(self.accelerator, XLAAccelerator):
-            if self._precision_flag == "32-true":
-                return XLAPrecisionPlugin()
-            if self._precision_flag in ("16-mixed", "bf16-mixed"):
-                if self._precision_flag == "16-mixed":
-                    rank_zero_warn(
-                        "You passed `Trainer(accelerator='tpu', precision='16-mixed')` but AMP with fp16"
-                        " is not supported on TPUs. Using `precision='bf16-mixed'` instead."
-                    )
-                return XLABf16PrecisionPlugin()

         if _LIGHTNING_COLOSSALAI_AVAILABLE:
             from lightning_colossalai import ColossalAIPrecisionPlugin, ColossalAIStrategy

@@ -524,6 +513,8 @@ class _AcceleratorConnector:
             if isinstance(self.strategy, ColossalAIStrategy):
                 return ColossalAIPrecisionPlugin(self._precision_flag)

+        if isinstance(self.accelerator, XLAAccelerator):
+            return XLAPrecisionPlugin(self._precision_flag)  # type: ignore
         if isinstance(self.strategy, DeepSpeedStrategy):
             return DeepSpeedPrecisionPlugin(self._precision_flag)  # type: ignore[arg-type]
         if isinstance(self.strategy, FSDPStrategy):

@@ -553,20 +544,15 @@ class _AcceleratorConnector:

     def _validate_precision_choice(self) -> None:
         """Validate the combination of choices for precision, AMP type, and accelerator."""
-        if isinstance(self.accelerator, XLAAccelerator):
-            if self._precision_flag == "64-true":
-                raise MisconfigurationException(
-                    "`Trainer(accelerator='tpu', precision='64-true')` is not implemented."
-                    " Please, open an issue in `https://github.com/Lightning-AI/lightning/issues`"
-                    " requesting this feature."
-                )
-            if self._precision_plugin_flag and not isinstance(
-                self._precision_plugin_flag, (XLAPrecisionPlugin, XLABf16PrecisionPlugin)
-            ):
-                raise ValueError(
-                    f"The `XLAAccelerator` can only be used with a `XLAPrecisionPlugin`,"
-                    f" found: {self._precision_plugin_flag}."
-                )
+        if (
+            isinstance(self.accelerator, XLAAccelerator)
+            and self._precision_plugin_flag
+            and not isinstance(self._precision_plugin_flag, XLAPrecisionPlugin)
+        ):
+            raise ValueError(
+                f"The `XLAAccelerator` can only be used with a `XLAPrecisionPlugin`,"
+                f" found: {self._precision_plugin_flag}."
+            )
         if _lightning_habana_available():
             from lightning_habana import HPUAccelerator
@@ -30,6 +30,8 @@ def test_graveyard_single_tpu(import_path, name):
         ("lightning.fabric.plugins", "TPUBf16Precision"),
         ("lightning.fabric.plugins.precision", "TPUBf16Precision"),
         ("lightning.fabric.plugins.precision.tpu_bf16", "TPUBf16Precision"),
+        ("lightning.fabric.plugins.precision", "XLABf16Precision"),
+        ("lightning.fabric.plugins.precision.xlabf16", "XLABf16Precision"),
     ],
 )
 def test_graveyard_no_device(import_path, name):
@@ -0,0 +1,63 @@
+# Copyright The Lightning AI team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import re
+from unittest import mock
+
+import pytest
+import torch
+
+from lightning.fabric.plugins import XLAPrecision
+
+
+@mock.patch.dict(os.environ, {}, clear=True)
+def test_precision_input_validation(xla_available):
+    XLAPrecision(precision="32-true")
+    XLAPrecision(precision="16-true")
+    XLAPrecision(precision="bf16-true")
+
+    with pytest.raises(ValueError, match=re.escape("`precision='16')` is not supported in XLA")):
+        XLAPrecision("16")
+    with pytest.raises(ValueError, match=re.escape("`precision='16-mixed')` is not supported in XLA")):
+        XLAPrecision("16-mixed")
+    with pytest.raises(ValueError, match=re.escape("`precision='bf16-mixed')` is not supported in XLA")):
+        XLAPrecision("bf16-mixed")
+    with pytest.raises(ValueError, match=re.escape("`precision='64-true')` is not supported in XLA")):
+        XLAPrecision("64-true")
+
+
+@pytest.mark.parametrize(
+    ("precision", "expected_dtype"),
+    [
+        ("bf16-true", torch.bfloat16),
+        ("16-true", torch.half),
+    ],
+)
+@mock.patch.dict(os.environ, {}, clear=True)
+def test_selected_dtype(precision, expected_dtype, xla_available):
+    plugin = XLAPrecision(precision=precision)
+    assert plugin.precision == precision
+    assert plugin._desired_dtype == expected_dtype
+
+
+def test_teardown(xla_available):
+    plugin = XLAPrecision(precision="16-true")
+    assert os.environ["XLA_USE_F16"] == "1"
+    plugin.teardown()
+    assert "XLA_USE_B16" not in os.environ
+
+    plugin = XLAPrecision(precision="bf16-true")
+    assert os.environ["XLA_USE_BF16"] == "1"
+    plugin.teardown()
+    assert "XLA_USE_BF16" not in os.environ
@@ -1,23 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import os
-
-from lightning.fabric.plugins import XLABf16Precision
-
-
-def test_teardown(xla_available):
-    plugin = XLABf16Precision()
-    assert os.environ.get("XLA_USE_BF16") == "1"
-    plugin.teardown()
-    assert "XLA_USE_BF16" not in os.environ
@@ -655,16 +655,9 @@ def test_strategy_choice_ddp_cpu_slurm(strategy):
 @mock.patch.dict(os.environ, {}, clear=True)
 @mock.patch("lightning.fabric.accelerators.mps.MPSAccelerator.is_available", return_value=False)
 def test_unsupported_tpu_choice(_, tpu_available):
-    with pytest.raises(NotImplementedError, match=r"accelerator='tpu', precision='64-true'\)` is not implemented"):
-        _Connector(accelerator="tpu", precision="64-true")
-
-    # if user didn't set strategy, _Connector will choose the TPUSingleStrategy or XLAStrategy
-    with pytest.raises(
-        ValueError, match="XLAAccelerator` can only be used with a `SingleDeviceXLAStrategy`"
-    ), pytest.warns(
-        UserWarning, match=r"accelerator='tpu', precision='16-mixed'\)` but AMP with fp16 is not supported"
-    ):
-        _Connector(accelerator="tpu", precision="16-mixed", strategy="ddp")
+    # if user didn't set strategy, _Connector will choose the SingleDeviceXLAStrategy or XLAStrategy
+    with pytest.raises(ValueError, match="XLAAccelerator` can only be used with a `SingleDeviceXLAStrategy`"):
+        _Connector(accelerator="tpu", precision="16-true", strategy="ddp")

     # wrong precision plugin type
     strategy = XLAStrategy(accelerator=XLAAccelerator(), precision=Precision())

@@ -672,7 +665,7 @@ def test_unsupported_tpu_choice(_, tpu_available):
         _Connector(strategy=strategy)

     # wrong strategy type
-    strategy = DDPStrategy(accelerator=XLAAccelerator(), precision=XLAPrecision())
+    strategy = DDPStrategy(accelerator=XLAAccelerator(), precision=XLAPrecision(precision="16-true"))
     with pytest.raises(ValueError, match="XLAAccelerator` can only be used with a `SingleDeviceXLAStrategy`"):
         _Connector(strategy=strategy)

@@ -1043,6 +1036,7 @@ def test_connector_auto_selection(monkeypatch, is_interactive):
         assert connector.strategy.launcher.is_interactive_compatible


+@mock.patch("lightning.fabric.accelerators.mps.MPSAccelerator.is_available", return_value=False)
 def test_xla_fsdp_automatic_strategy_selection(monkeypatch, tpu_available):
     import lightning.fabric.strategies as strategies
@@ -30,6 +30,8 @@ def test_graveyard_single_tpu(import_path, name):
         ("lightning.pytorch.plugins", "TPUBf16PrecisionPlugin"),
         ("lightning.pytorch.plugins.precision", "TPUBf16PrecisionPlugin"),
         ("lightning.pytorch.plugins.precision.tpu_bf16", "TPUBf16PrecisionPlugin"),
+        ("lightning.pytorch.plugins.precision", "XLABf16PrecisionPlugin"),
+        ("lightning.pytorch.plugins.precision.xlabf16", "XLABf16PrecisionPlugin"),
     ],
 )
 def test_graveyard_no_device(import_path, name):
@@ -102,7 +102,7 @@ def test_model_multiple_tpu_devices(tmpdir):
 def test_model_16bit_tpu_devices_1(tmpdir):
     trainer_options = {
         "default_root_dir": tmpdir,
-        "precision": "16-mixed",
+        "precision": "16-true",
         "enable_progress_bar": False,
         "max_epochs": 2,
         "accelerator": "tpu",

@@ -121,7 +121,7 @@ def test_model_16bit_tpu_devices_1(tmpdir):
 def test_model_16bit_tpu_index(tmpdir, tpu_core):
     trainer_options = {
         "default_root_dir": tmpdir,
-        "precision": "16-mixed",
+        "precision": "16-true",
         "enable_progress_bar": False,
         "max_epochs": 2,
         "accelerator": "tpu",

@@ -143,7 +143,7 @@ def test_model_16bit_tpu_index(tmpdir, tpu_core):
 def test_model_16bit_multiple_tpu_devices(tmpdir):
     trainer_options = {
         "default_root_dir": tmpdir,
-        "precision": "16-mixed",
+        "precision": "16-true",
         "enable_progress_bar": False,
         "max_epochs": 1,
         "accelerator": "tpu",
@@ -11,18 +11,66 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
+import re
 from unittest import mock
 from unittest.mock import Mock

+import pytest
+import torch
+
 from lightning.pytorch.plugins import XLAPrecisionPlugin
 from tests_pytorch.helpers.runif import RunIf


 @RunIf(tpu=True)
+@mock.patch.dict(os.environ, {}, clear=True)
 def test_optimizer_step_calls_mark_step():
-    plugin = XLAPrecisionPlugin()
+    plugin = XLAPrecisionPlugin(precision="32-true")
     optimizer = Mock()
     with mock.patch("torch_xla.core.xla_model") as xm_mock:
         plugin.optimizer_step(optimizer=optimizer, model=Mock(), closure=Mock())
     optimizer.step.assert_called_once()
     xm_mock.mark_step.assert_called_once()
+
+
+@mock.patch.dict(os.environ, {}, clear=True)
+def test_precision_input_validation(xla_available):
+    XLAPrecisionPlugin(precision="32-true")
+    XLAPrecisionPlugin(precision="16-true")
+    XLAPrecisionPlugin(precision="bf16-true")
+
+    with pytest.raises(ValueError, match=re.escape("`precision='16')` is not supported in XLA")):
+        XLAPrecisionPlugin("16")
+    with pytest.raises(ValueError, match=re.escape("`precision='16-mixed')` is not supported in XLA")):
+        XLAPrecisionPlugin("16-mixed")
+    with pytest.raises(ValueError, match=re.escape("`precision='bf16-mixed')` is not supported in XLA")):
+        XLAPrecisionPlugin("bf16-mixed")
+    with pytest.raises(ValueError, match=re.escape("`precision='64-true')` is not supported in XLA")):
+        XLAPrecisionPlugin("64-true")
+
+
+@pytest.mark.parametrize(
+    ("precision", "expected_dtype"),
+    [
+        ("bf16-true", torch.bfloat16),
+        ("16-true", torch.half),
+    ],
+)
+@mock.patch.dict(os.environ, {}, clear=True)
+def test_selected_dtype(precision, expected_dtype, xla_available):
+    plugin = XLAPrecisionPlugin(precision=precision)
+    assert plugin.precision == precision
+    assert plugin._desired_dtype == expected_dtype
+
+
+def test_teardown(xla_available):
+    plugin = XLAPrecisionPlugin(precision="16-true")
+    assert os.environ["XLA_USE_F16"] == "1"
+    plugin.teardown()
+    assert "XLA_USE_B16" not in os.environ
+
+    plugin = XLAPrecisionPlugin(precision="bf16-true")
+    assert os.environ["XLA_USE_BF16"] == "1"
+    plugin.teardown()
+    assert "XLA_USE_BF16" not in os.environ
@@ -1,25 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import os
-from unittest.mock import Mock
-
-from lightning.pytorch.plugins import XLABf16PrecisionPlugin
-
-
-def test_teardown(xla_available):
-    plugin = XLABf16PrecisionPlugin()
-    plugin.connect(Mock(), Mock(), Mock())
-    assert os.environ.get("XLA_USE_BF16") == "1"
-    plugin.teardown()
-    assert "XLA_USE_BF16" not in os.environ
@@ -549,17 +549,11 @@ def test_check_fsdp_strategy_and_fallback():
         Trainer(accelerator="cpu", strategy="fsdp")


-def test_unsupported_tpu_choice(tpu_available):
-    with pytest.raises(
-        MisconfigurationException, match=r"accelerator='tpu', precision='64-true'\)` is not implemented"
-    ):
-        Trainer(accelerator="tpu", precision="64-true")
-
-    # if user didn't set strategy, AcceleratorConnector will choose "single_xla" or "xla"
-    with pytest.raises(
-        ValueError, match="XLAAccelerator` can only be used with a `SingleDeviceXLAStrategy`"
-    ), pytest.warns(UserWarning, match=r"accelerator='tpu', precision=16-mixed\)` but AMP with fp16 is not supported"):
-        Trainer(accelerator="tpu", precision="16-mixed", strategy="ddp")
+@mock.patch.dict(os.environ, {}, clear=True)
+def test_unsupported_tpu_choice(xla_available, tpu_available):
+    # if user didn't set strategy, _Connector will choose the SingleDeviceXLAStrategy or XLAStrategy
+    with pytest.raises(ValueError, match="XLAAccelerator` can only be used with a `SingleDeviceXLAStrategy`"):
+        Trainer(accelerator="tpu", precision="16-true", strategy="ddp")


 def mock_ipu_available(monkeypatch, value=True):