diff --git a/src/lightning_lite/strategies/__init__.py b/src/lightning_lite/strategies/__init__.py
index b76af7a22d..cdbcb6bad1 100644
--- a/src/lightning_lite/strategies/__init__.py
+++ b/src/lightning_lite/strategies/__init__.py
@@ -14,6 +14,7 @@
 from lightning_lite.strategies.parallel import ParallelStrategy  # noqa: F401
 from lightning_lite.strategies.registry import _call_register_strategies, _StrategyRegistry
 from lightning_lite.strategies.single_device import SingleDeviceStrategy  # noqa: F401
+from lightning_lite.strategies.single_tpu import SingleTPUStrategy  # noqa: F401
 from lightning_lite.strategies.strategy import Strategy  # noqa: F401
 
 STRATEGY_REGISTRY = _StrategyRegistry()
diff --git a/src/lightning_lite/strategies/single_tpu.py b/src/lightning_lite/strategies/single_tpu.py
new file mode 100644
index 0000000000..fe1dad21a7
--- /dev/null
+++ b/src/lightning_lite/strategies/single_tpu.py
@@ -0,0 +1,58 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Dict, Optional
+
+from lightning_lite.accelerators import Accelerator
+from lightning_lite.plugins.io.checkpoint_plugin import CheckpointIO
+from lightning_lite.plugins.io.xla_plugin import XLACheckpointIO
+from lightning_lite.plugins.precision import Precision
+from lightning_lite.strategies.single_device import SingleDeviceStrategy
+
+
+class SingleTPUStrategy(SingleDeviceStrategy):
+    """Strategy for training on a single TPU device."""
+
+    def __init__(
+        self,
+        device: int,
+        accelerator: Optional[Accelerator] = None,
+        checkpoint_io: Optional[CheckpointIO] = None,
+        precision_plugin: Optional[Precision] = None,
+    ):
+        import torch_xla.core.xla_model as xm
+
+        super().__init__(
+            accelerator=accelerator,
+            device=xm.xla_device(device),
+            checkpoint_io=checkpoint_io,
+            precision_plugin=precision_plugin,
+        )
+
+    @property
+    def checkpoint_io(self) -> CheckpointIO:
+        if self._checkpoint_io is None:
+            self._checkpoint_io = XLACheckpointIO()
+        return self._checkpoint_io
+
+    @checkpoint_io.setter
+    def checkpoint_io(self, io: Optional[CheckpointIO]) -> None:
+        self._checkpoint_io = io
+
+    @property
+    def is_distributed(self) -> bool:
+        return False
+
+    @classmethod
+    def register_strategies(cls, strategy_registry: Dict) -> None:
+        strategy_registry.register("single_tpu", cls, description=f"{cls.__name__}")
diff --git a/tests/tests_lite/strategies/test_registry.py b/tests/tests_lite/strategies/test_registry.py
index 7d6edfd449..d94198d571 100644
--- a/tests/tests_lite/strategies/test_registry.py
+++ b/tests/tests_lite/strategies/test_registry.py
@@ -41,4 +41,6 @@ def test_strategy_registry_with_new_strategy():
 
 
 def test_available_strategies_in_registry():
-    assert STRATEGY_REGISTRY.available_strategies() == []
+    assert STRATEGY_REGISTRY.available_strategies() == [
+        "single_tpu",
+    ]
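
Reviewer note (not part of the diff): a minimal usage sketch of the new strategy, assuming a TPU runtime with torch_xla installed. The device index 0 is an arbitrary example; the assertions only restate behavior visible in the diff (registration under "single_tpu", the lazy XLACheckpointIO fallback, and is_distributed being False).

    # Usage sketch, not included in this PR. Requires torch_xla and a TPU device.
    from lightning_lite.plugins.io.xla_plugin import XLACheckpointIO
    from lightning_lite.strategies import STRATEGY_REGISTRY, SingleTPUStrategy

    # The new class registers itself under the "single_tpu" key,
    # as asserted by the updated test_available_strategies_in_registry.
    assert "single_tpu" in STRATEGY_REGISTRY.available_strategies()

    # Construct the strategy for TPU core 0; __init__ resolves the device
    # through torch_xla's xm.xla_device(...).
    strategy = SingleTPUStrategy(device=0)

    # checkpoint_io lazily falls back to XLACheckpointIO when none was passed,
    # and a single TPU core is not treated as a distributed setup.
    assert isinstance(strategy.checkpoint_io, XLACheckpointIO)
    assert strategy.is_distributed is False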