Add backward-compatibility for LightningLite in PL (#14735)
parent e3e71670e6
commit c0ff7a1b77
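For orientation, a minimal sketch (not part of the diff) of what this backward-compatibility layer restores: user code can keep importing `LightningLite` from `pytorch_lightning.lite` and keep passing the deprecated `gpus`/`tpu_cores` flags, which now only emit a deprecation warning and are translated to `accelerator`/`devices` before being forwarded to the `lightning_lite` implementation. The class name and argument names come from the diff below; the surrounding script is illustrative.

import torch
from pytorch_lightning.lite import LightningLite  # import path restored by this commit


class Lite(LightningLite):
    def run(self):
        model = torch.nn.Linear(4, 4)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        model, optimizer = self.setup(model, optimizer)  # Lite handles device placement
        loss = model(torch.rand(2, 4, device=self.device)).sum()
        self.backward(loss)
        optimizer.step()


# Deprecated flag still accepted here: emits "Setting `Lite(gpus=1)` is deprecated in v1.8.0 ..."
# and is translated to accelerator="cuda", devices=1 (so this line needs a CUDA GPU to run).
Lite(gpus=1).run()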
@@ -38,10 +38,10 @@ import torchvision.transforms as T
from torch.optim.lr_scheduler import StepLR
from torchmetrics.classification import Accuracy

from lightning_lite.lite import LightningLite  # import LightningLite
from pytorch_lightning import seed_everything
from pytorch_lightning.demos.boring_classes import Net
from pytorch_lightning.demos.mnist_datamodule import MNIST
from pytorch_lightning.lite import LightningLite  # import LightningLite

DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets")
@@ -34,10 +34,10 @@ import torchvision.transforms as T
from torch.optim.lr_scheduler import StepLR
from torchmetrics import Accuracy

from lightning_lite.lite import LightningLite
from pytorch_lightning import seed_everything
from pytorch_lightning.demos.boring_classes import Net
from pytorch_lightning.demos.mnist_datamodule import MNIST
from pytorch_lightning.lite import LightningLite

DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets")
@@ -22,10 +22,10 @@ import torchvision.transforms as T
from torch.optim.lr_scheduler import StepLR
from torchmetrics import Accuracy

from lightning_lite.lite import LightningLite
from pytorch_lightning import seed_everything
from pytorch_lightning.demos.boring_classes import Net
from pytorch_lightning.demos.mnist_datamodule import MNIST
from pytorch_lightning.lite import LightningLite
from pytorch_lightning.loops import Loop

DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets")
@@ -52,11 +52,11 @@ from lightning_lite.strategies import (
    XLAStrategy,
)
from lightning_lite.strategies.ddp_spawn import _DDP_FORK_ALIASES
from lightning_lite.utilities import _StrategyType, rank_zero_deprecation, rank_zero_info, rank_zero_warn
from lightning_lite.utilities import _StrategyType, rank_zero_info, rank_zero_warn
from lightning_lite.utilities.device_parser import determine_root_gpu_device
from lightning_lite.utilities.imports import _HPU_AVAILABLE, _IPU_AVAILABLE, _IS_INTERACTIVE, _TPU_AVAILABLE

_PLUGIN = Union[Strategy, Precision, ClusterEnvironment, CheckpointIO]
_PLUGIN = Union[Precision, ClusterEnvironment, CheckpointIO]
_PLUGIN_INPUT = Union[_PLUGIN, str]
@@ -99,8 +99,6 @@ class _Connector:
        num_nodes: int = 1,
        precision: Union[int, str] = 32,
        plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None,
        tpu_cores: Optional[Union[List[int], str, int]] = None,  # deprecated
        gpus: Optional[Union[List[int], str, int]] = None,  # deprecated
    ) -> None:
        # 1. Parsing flags
        # Get registered strategies, built-in accelerators and precision plugins
@@ -125,9 +123,7 @@ class _Connector:
            precision=precision,
            plugins=plugins,
        )
        self._check_device_config_and_set_final_flags(
            devices=devices, num_nodes=num_nodes, gpus=gpus, tpu_cores=tpu_cores
        )
        self._check_device_config_and_set_final_flags(devices=devices, num_nodes=num_nodes)

        # 2. Instantiate Accelerator
        # handle `auto`, `None` and `gpu`
@@ -278,11 +274,7 @@ class _Connector:
            self._parallel_devices = self._strategy_flag.parallel_devices

    def _check_device_config_and_set_final_flags(
        self,
        devices: Optional[Union[List[int], str, int]],
        num_nodes: int,
        gpus: Optional[Union[List[int], str, int]],
        tpu_cores: Optional[Union[List[int], str, int]],
        self, devices: Optional[Union[List[int], str, int]], num_nodes: int
    ) -> None:
        self._num_nodes_flag = int(num_nodes) if num_nodes is not None else 1
        self._devices_flag = devices
@@ -298,56 +290,12 @@ class _Connector:
                f" using {accelerator_name} accelerator."
            )

        # TODO: Delete this method when num_processes, gpus, ipus and tpu_cores gets removed
        self._map_deprecated_devices_specific_info_to_accelerator_and_device_flag(devices, gpus, tpu_cores)

        if self._devices_flag == "auto" and self._accelerator_flag is None:
            raise ValueError(
                f"You passed `devices={devices}` but haven't specified"
                " `accelerator=('auto'|'tpu'|'gpu'|'cpu'|'mps')` for the devices mapping."
            )

    def _map_deprecated_devices_specific_info_to_accelerator_and_device_flag(
        self,
        devices: Optional[Union[List[int], str, int]],
        gpus: Optional[Union[List[int], str, int]],
        tpu_cores: Optional[Union[List[int], str, int]],
    ) -> None:
        """Emit deprecation warnings for num_processes, gpus, ipus, tpu_cores and set the `devices_flag` and
        `accelerator_flag`."""
        if gpus is not None:
            rank_zero_deprecation(
                f"Setting `Lite(gpus={gpus!r})` is deprecated in v1.7 and will be removed"
                f" in v2.0. Please use `Lite(accelerator='gpu', devices={gpus!r})` instead."
            )
        if tpu_cores is not None:
            rank_zero_deprecation(
                f"Setting `Lite(tpu_cores={tpu_cores!r})` is deprecated in v1.7 and will be removed"
                f" in v2.0. Please use `Lite(accelerator='tpu', devices={tpu_cores!r})` instead."
            )
        self._gpus: Optional[Union[List[int], str, int]] = gpus
        self._tpu_cores: Optional[Union[List[int], str, int]] = tpu_cores
        deprecated_devices_specific_flag = gpus or tpu_cores
        if deprecated_devices_specific_flag and deprecated_devices_specific_flag not in ([], 0, "0"):
            if devices:
                # TODO: improve error message
                rank_zero_warn(
                    f"The flag `devices={devices}` will be ignored, "
                    f"instead the device specific number {deprecated_devices_specific_flag} will be used"
                )

            if [(gpus is not None), (tpu_cores is not None)].count(True) > 1:
                # TODO: improve error message
                rank_zero_warn("more than one device specific flag has been set")
            self._devices_flag = deprecated_devices_specific_flag

            if self._accelerator_flag is None:
                # set accelerator type based on num_processes, gpus, ipus, tpu_cores
                if tpu_cores:
                    self._accelerator_flag = "tpu"
                if gpus:
                    self._accelerator_flag = "cuda"

    def _choose_auto_accelerator(self) -> str:
        """Choose the accelerator type (str) based on availability when ``accelerator='auto'``."""
        if self._accelerator_flag == "auto":
@@ -392,9 +340,6 @@ class _Connector:

        self._set_devices_flag_if_auto_passed()

        self._gpus = self._devices_flag if not self._gpus else self._gpus
        self._tpu_cores = self._devices_flag if not self._tpu_cores else self._tpu_cores

        self._devices_flag = self.accelerator.parse_devices(self._devices_flag)
        if not self._parallel_devices:
            self._parallel_devices = self.accelerator.get_parallel_devices(self._devices_flag)
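Taken together, the connector hunks above strip the deprecated `gpus`/`tpu_cores` handling out of `lightning_lite` itself; the backward-compatible handling reappears only in the `pytorch_lightning.lite` wrapper added later in this commit. A minimal sketch of the resulting split (illustrative, not part of the diff):

from lightning_lite.connector import _Connector

_Connector(accelerator="cpu", devices=1)  # the remaining, flag-free signature
# _Connector(gpus=1) would now fail with a TypeError, since the deprecated keyword no longer exists here;
# pytorch_lightning.lite.LightningLite(gpus=1) still works and emits a deprecation warning instead.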
@@ -64,8 +64,6 @@ class LightningLite(ABC):
        precision: Double precision (``64``), full precision (``32``), half precision (``16``),
            or bfloat16 precision (``"bf16"``).
        plugins: One or several custom plugins
        gpus: Provides the same function as the ``devices`` argument but implies ``accelerator="gpu"``.
        tpu_cores: Provides the same function as the ``devices`` argument but implies ``accelerator="tpu"``.
    """

    def __init__(
@@ -76,8 +74,6 @@ class LightningLite(ABC):
        num_nodes: int = 1,
        precision: Union[int, str] = 32,
        plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None,
        gpus: Optional[Union[List[int], str, int]] = None,
        tpu_cores: Optional[Union[List[int], str, int]] = None,
    ) -> None:
        self._connector = _Connector(
            accelerator=accelerator,
@@ -86,8 +82,6 @@ class LightningLite(ABC):
            num_nodes=num_nodes,
            precision=precision,
            plugins=plugins,
            tpu_cores=tpu_cores,
            gpus=gpus,
        )
        self._strategy: Strategy = self._connector.strategy
        self._accelerator: Accelerator = self._connector.accelerator
@@ -54,7 +54,6 @@ class XLAStrategy(DDPSpawnStrategy):
        parallel_devices: Optional[List[torch.device]] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[Precision] = None,
        **_: Any,
    ) -> None:
        super().__init__(
            accelerator=accelerator,
@@ -261,7 +261,9 @@ class EarlyStopping(Callback):

    @staticmethod
    def _log_info(trainer: Optional["pl.Trainer"], message: str, log_rank_zero_only: bool) -> None:
        rank = _get_rank(strategy=(trainer.strategy if trainer is not None else None))  # type: ignore[arg-type]
        rank = _get_rank(
            strategy=(trainer.strategy if trainer is not None else None),  # type: ignore[arg-type]
        )
        if trainer is not None and trainer.world_size <= 1:
            rank = None
        message = rank_prefixed_message(message, rank)
@@ -0,0 +1,17 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pytorch_lightning.lite.lite import LightningLite

__all__ = ["LightningLite"]
@@ -0,0 +1,308 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC
from typing import List, Optional, Tuple, Union

from lightning_utilities.core.rank_zero import rank_zero_deprecation, rank_zero_warn

from lightning_lite.connector import _PLUGIN_INPUT as _LITE_PLUGIN_INPUT
from lightning_lite.lite import LightningLite as _NewLightningLite
from lightning_lite.plugins import CheckpointIO, ClusterEnvironment
from lightning_lite.plugins import DeepSpeedPrecision as LiteDeepSpeedPrecision
from lightning_lite.plugins import DoublePrecision as LiteDoublePrecision
from lightning_lite.plugins import NativeMixedPrecision as LiteNativeMixedPrecision
from lightning_lite.plugins import Precision as LitePrecision
from lightning_lite.plugins import TPUBf16Precision as LiteTPUBf16Precision
from lightning_lite.plugins import TPUPrecision as LiteTPUPrecision
from lightning_lite.strategies import DataParallelStrategy as LiteDataParallelStrategy
from lightning_lite.strategies import DDPShardedStrategy as LiteDDPShardedStrategy
from lightning_lite.strategies import DDPSpawnShardedStrategy as LiteDDPSpawnShardedStrategy
from lightning_lite.strategies import DDPSpawnStrategy as LiteDDPSpawnStrategy
from lightning_lite.strategies import DDPStrategy as LiteDDPStrategy
from lightning_lite.strategies import DeepSpeedStrategy as LiteDeepSpeedStrategy
from lightning_lite.strategies import SingleDeviceStrategy as LiteSingleDeviceStrategy
from lightning_lite.strategies import SingleTPUStrategy as LiteSingleTPUStrategy
from lightning_lite.strategies import Strategy as LiteStrategy
from lightning_lite.strategies import XLAStrategy
from pytorch_lightning.accelerators import Accelerator as PLAccelerator
from pytorch_lightning.plugins import DeepSpeedPrecisionPlugin as PLDeepSpeedPrecisionPlugin
from pytorch_lightning.plugins import DoublePrecisionPlugin as PLDoublePrecisionPlugin
from pytorch_lightning.plugins import NativeMixedPrecisionPlugin as PLNativeMixedPrecisionPlugin
from pytorch_lightning.plugins import PrecisionPlugin as PLPrecisionPlugin
from pytorch_lightning.plugins import TPUBf16PrecisionPlugin as PLTPUBf16PrecisionPlugin
from pytorch_lightning.plugins import TPUPrecisionPlugin as PLTPUPrecisionPlugin
from pytorch_lightning.strategies import DataParallelStrategy as PLDataParallelStrategy
from pytorch_lightning.strategies import DDPShardedStrategy as PLDDPShardedStrategy
from pytorch_lightning.strategies import DDPSpawnShardedStrategy as PLDDPSpawnShardedStrategy
from pytorch_lightning.strategies import DDPSpawnStrategy as PLDDPSpawnStrategy
from pytorch_lightning.strategies import DDPStrategy as PLDDPStrategy
from pytorch_lightning.strategies import DeepSpeedStrategy as PLDeepSpeedStrategy
from pytorch_lightning.strategies import SingleDeviceStrategy as PLSingleDeviceStrategy
from pytorch_lightning.strategies import SingleTPUStrategy as PLSingleTPUStrategy
from pytorch_lightning.strategies import Strategy as PLStrategy
from pytorch_lightning.strategies import TPUSpawnStrategy as PLTPUSpawnStrategy

_PL_PLUGIN = Union[PLPrecisionPlugin, ClusterEnvironment, CheckpointIO]
_PL_PLUGIN_INPUT = Union[_PL_PLUGIN, str]


class LightningLite(_NewLightningLite, ABC):
    """Lite accelerates your PyTorch training or inference code with minimal changes required.

    - Automatic placement of models and data onto the device.
    - Automatic support for mixed and double precision (smaller memory footprint).
    - Seamless switching between hardware (CPU, GPU, TPU) and distributed training strategies
      (data-parallel training, sharded training, etc.).
    - Automated spawning of processes, no launch utilities required.
    - Multi-node support.

    Args:
        accelerator: The hardware to run on. Possible choices are:
            ``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
        strategy: Strategy for how to run across multiple devices. Possible choices are:
            ``"dp"``, ``"ddp"``, ``"ddp_spawn"``, ``"deepspeed"``, ``"ddp_sharded"``.
        devices: Number of devices to train on (``int``), which GPUs to train on (``list`` or ``str``), or ``"auto"``.
            The value applies per node.
        num_nodes: Number of GPU nodes for distributed training.
        precision: Double precision (``64``), full precision (``32``), half precision (``16``),
            or bfloat16 precision (``"bf16"``).
        plugins: One or several custom plugins
        gpus: Provides the same function as the ``devices`` argument but implies ``accelerator="gpu"``.

            .. deprecated:: v1.8.0
                ``gpus`` has been deprecated in v1.8.0 and will be removed in v1.10.0.
                Please use ``accelerator='gpu'`` and ``devices=x`` instead.

        tpu_cores: Provides the same function as the ``devices`` argument but implies ``accelerator="tpu"``.

            .. deprecated:: v1.8.0
                ``tpu_cores`` has been deprecated in v1.8.0 and will be removed in v1.10.0.
                Please use ``accelerator='tpu'`` and ``devices=x`` instead.
    """

    def __init__(
        self,
        accelerator: Optional[Union[str, PLAccelerator]] = None,
        strategy: Optional[Union[str, PLStrategy]] = None,
        devices: Optional[Union[List[int], str, int]] = None,
        num_nodes: int = 1,
        precision: Union[int, str] = 32,
        plugins: Optional[Union[_PL_PLUGIN_INPUT, List[_PL_PLUGIN_INPUT]]] = None,
        gpus: Optional[Union[List[int], str, int]] = None,
        tpu_cores: Optional[Union[List[int], str, int]] = None,
    ) -> None:

        if gpus is not None or tpu_cores is not None:
            devices, accelerator = _convert_deprecated_device_flags(
                accelerator=accelerator,
                devices=devices,
                gpus=gpus,
                tpu_cores=tpu_cores,
            )

        lite_plugins: Optional[Union[_LITE_PLUGIN_INPUT, List[_LITE_PLUGIN_INPUT]]]
        if isinstance(plugins, PLPrecisionPlugin):
            lite_plugins = _to_lite_precision_plugin(plugins)
        elif isinstance(plugins, list):
            lite_plugins = [
                _to_lite_precision_plugin(plugin) if isinstance(plugin, PLPrecisionPlugin) else plugin
                for plugin in plugins
            ]
        else:
            lite_plugins = plugins

        super().__init__(
            accelerator=accelerator,
            strategy=(_to_lite_strategy(strategy) if isinstance(strategy, PLStrategy) else strategy),
            devices=devices,
            num_nodes=num_nodes,
            precision=precision,
            plugins=lite_plugins,
        )
def _convert_deprecated_device_flags(
    accelerator: Optional[Union[str, PLAccelerator]],
    devices: Optional[Union[List[int], str, int]],
    gpus: Optional[Union[List[int], str, int]],
    tpu_cores: Optional[Union[List[int], str, int]],
) -> Tuple[Optional[Union[List[int], str, int]], Optional[Union[str, PLAccelerator]]]:
    """Emit deprecation warnings for gpus and tpu_cores and translate them into the new accelerator and devices.

    Similar implementation as in ``pytorch_lightning.trainer.connectors.accelerator_connector``.
    """
    if gpus is not None:
        rank_zero_deprecation(
            f"Setting `Lite(gpus={gpus!r})` is deprecated in v1.8.0 and will be removed"
            f" in v1.10.0. Please use `Lite(accelerator='gpu', devices={gpus!r})` instead."
        )
    if tpu_cores is not None:
        rank_zero_deprecation(
            f"Setting `Lite(tpu_cores={tpu_cores!r})` is deprecated in v1.8.0 and will be removed"
            f" in v1.10.0. Please use `Lite(accelerator='tpu', devices={tpu_cores!r})` instead."
        )
    deprecated_devices_specific_flag = gpus or tpu_cores
    if deprecated_devices_specific_flag and deprecated_devices_specific_flag not in ([], 0, "0"):
        if devices:
            rank_zero_warn(
                f"The option `devices={devices}` will be ignored and the device specific number"
                f" {deprecated_devices_specific_flag} will be used instead."
            )

        if gpus is not None and tpu_cores is not None:
            rank_zero_warn(
                f"Both `Lite(gpus={gpus!r}, tpu_cores={tpu_cores!r})` were specified. Please choose only one of"
                " the two."
            )

        if accelerator is None:
            if tpu_cores:
                accelerator = "tpu"
            if gpus:
                accelerator = "cuda"

    return deprecated_devices_specific_flag, accelerator

def _to_lite_strategy(strategy: PLStrategy) -> LiteStrategy:
    """Re-instantiates a PL-Strategy as the corresponding Lite-Strategy."""

    if type(strategy) is PLDDPStrategy:
        return LiteDDPStrategy(
            accelerator=strategy.accelerator,
            parallel_devices=strategy.parallel_devices,
            cluster_environment=strategy.cluster_environment,
            checkpoint_io=strategy.checkpoint_io,
            precision_plugin=_to_lite_precision_plugin(strategy.precision_plugin),
            process_group_backend=strategy.process_group_backend,
            timeout=strategy._timeout,
            **strategy._ddp_kwargs,
        )

    if type(strategy) is PLDDPSpawnStrategy:
        return LiteDDPSpawnStrategy(
            accelerator=strategy.accelerator,
            parallel_devices=strategy.parallel_devices,
            cluster_environment=strategy.cluster_environment,
            checkpoint_io=strategy.checkpoint_io,
            precision_plugin=_to_lite_precision_plugin(strategy.precision_plugin),
            process_group_backend=strategy.process_group_backend,
            timeout=strategy._timeout,
            start_method=strategy._start_method,
            **strategy._ddp_kwargs,
        )

    if type(strategy) is PLTPUSpawnStrategy:
        return XLAStrategy(
            accelerator=strategy.accelerator,
            parallel_devices=strategy.parallel_devices,
            checkpoint_io=strategy.checkpoint_io,
            precision_plugin=_to_lite_precision_plugin(strategy.precision_plugin),
        )

    if type(strategy) is PLDeepSpeedStrategy:
        return LiteDeepSpeedStrategy(
            accelerator=strategy.accelerator,
            parallel_devices=strategy.parallel_devices,
            cluster_environment=strategy.cluster_environment,
            precision_plugin=_to_lite_precision_plugin(strategy.precision_plugin),
            process_group_backend=strategy.process_group_backend,
            config=strategy.config,
            remote_device=strategy.remote_device,
            load_full_weights=strategy.load_full_weights,
            loss_scale=strategy.loss_scale,
            initial_scale_power=strategy.initial_scale_power,
            loss_scale_window=strategy.loss_scale_window,
            hysteresis=strategy.hysteresis,
            min_loss_scale=strategy.min_loss_scale,
        )

    if type(strategy) is PLDataParallelStrategy:
        return LiteDataParallelStrategy(
            accelerator=strategy.accelerator,
            parallel_devices=strategy.parallel_devices,
            checkpoint_io=strategy.checkpoint_io,
            precision_plugin=_to_lite_precision_plugin(strategy.precision_plugin),
        )

    if type(strategy) is PLDDPShardedStrategy:
        return LiteDDPShardedStrategy(
            accelerator=strategy.accelerator,
            parallel_devices=strategy.parallel_devices,
            cluster_environment=strategy.cluster_environment,
            checkpoint_io=strategy.checkpoint_io,
            precision_plugin=_to_lite_precision_plugin(strategy.precision_plugin),
            process_group_backend=strategy.process_group_backend,
            timeout=strategy._timeout,
            **strategy._ddp_kwargs,
        )

    if type(strategy) is PLDDPSpawnShardedStrategy:
        return LiteDDPSpawnShardedStrategy(
            accelerator=strategy.accelerator,
            parallel_devices=strategy.parallel_devices,
            cluster_environment=strategy.cluster_environment,
            checkpoint_io=strategy.checkpoint_io,
            precision_plugin=_to_lite_precision_plugin(strategy.precision_plugin),
            process_group_backend=strategy.process_group_backend,
            timeout=strategy._timeout,
            start_method=strategy._start_method,
            **strategy._ddp_kwargs,
        )

    if type(strategy) is PLSingleDeviceStrategy:
        return LiteSingleDeviceStrategy(
            device=strategy.root_device,
            accelerator=strategy.accelerator,
            checkpoint_io=strategy.checkpoint_io,
            precision_plugin=_to_lite_precision_plugin(strategy.precision_plugin),
        )

    if type(strategy) is PLSingleTPUStrategy:
        return LiteSingleTPUStrategy(
            device=strategy.root_device.index,
            accelerator=strategy.accelerator,
            checkpoint_io=strategy.checkpoint_io,
            precision_plugin=_to_lite_precision_plugin(strategy.precision_plugin),
        )


def _to_lite_precision_plugin(plugin: Optional[PLPrecisionPlugin]) -> LitePrecision:
    """Re-instantiates a PL-PrecisionPlugin as the corresponding Lite-Precision plugin."""

    if type(plugin) is PLPrecisionPlugin:
        return LitePrecision()

    if type(plugin) is PLNativeMixedPrecisionPlugin:
        return LiteNativeMixedPrecision(precision=plugin.precision, device=plugin.device, scaler=plugin.scaler)

    if type(plugin) is PLDoublePrecisionPlugin:
        return LiteDoublePrecision()

    if type(plugin) is PLDeepSpeedPrecisionPlugin:
        return LiteDeepSpeedPrecision(precision=plugin.precision, amp_type=plugin.amp_type, amp_level=plugin.amp_level)

    if type(plugin) is PLTPUPrecisionPlugin:
        return LiteTPUPrecision()

    if type(plugin) is PLTPUBf16PrecisionPlugin:
        return LiteTPUBf16Precision()

    # No backward compatibility for custom plugins / subclasses, as we can't re-instantiate these plugins
    raise TypeError(
        "You passed an unsupported plugin as input to Lite(plugins=...) or to a strategy. If you built a custom plugin,"
        " please change it to subclass the `lightning_lite.plugins.precision.Precision` class. Otherwise, please open"
        " an issue on the Lightning GitHub repository with your use case."
    )
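To make the conversion layer above concrete, a small worked sketch (illustrative only: both helpers are module-private, so this is not a supported entry point, just a restatement of what the code does):

from lightning_lite.plugins import DoublePrecision
from pytorch_lightning.lite.lite import _convert_deprecated_device_flags, _to_lite_precision_plugin
from pytorch_lightning.plugins import DoublePrecisionPlugin

# Deprecated `gpus=2` is translated to `devices=2` with `accelerator="cuda"` (a deprecation warning is emitted).
devices, accelerator = _convert_deprecated_device_flags(accelerator=None, devices=None, gpus=2, tpu_cores=None)
assert (devices, accelerator) == (2, "cuda")

# A PL precision plugin is re-instantiated as the matching Lite plugin before being handed to lightning_lite.
assert isinstance(_to_lite_precision_plugin(DoublePrecisionPlugin()), DoublePrecision)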
@@ -16,9 +16,8 @@ from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
from pytorch_lightning.plugins.precision.tpu import TPUPrecisionPlugin
from pytorch_lightning.plugins.precision.tpu_bf16 import TPUBf16PrecisionPlugin
from pytorch_lightning.strategies.strategy import Strategy

PLUGIN = Union[Strategy, PrecisionPlugin, ClusterEnvironment, CheckpointIO, LayerSync]
PLUGIN = Union[PrecisionPlugin, ClusterEnvironment, CheckpointIO, LayerSync]
PLUGIN_INPUT = Union[PLUGIN, str]

__all__ = [
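With `Strategy` removed from the `PLUGIN` union above, strategy objects are no longer typed as plugin inputs; the dedicated `strategy` argument is the supported route. A minimal sketch of the intended call (illustrative, not taken from this diff):

from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPStrategy

# Pass the strategy object via `strategy=`, not inside `plugins=[...]`.
trainer = Trainer(accelerator="cpu", devices=2, strategy=DDPStrategy(find_unused_parameters=False))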
@@ -262,21 +262,11 @@ def test_accelerator_cpu(*_):
    connector = _Connector(accelerator="cpu")
    assert isinstance(connector.accelerator, CPUAccelerator)

    with pytest.raises(
        RuntimeError,
        match="CUDAAccelerator can not run on your system since the accelerator is not available",
    ):
        with pytest.deprecated_call(match=r"is deprecated in v1.7 and will be removed"):
            _Connector(gpus=1)

    with pytest.raises(
        RuntimeError,
        match="CUDAAccelerator can not run on your system since the accelerator is not available.",
    ):
        _Connector(accelerator="cuda")

    with pytest.deprecated_call(match=r"is deprecated in v1.7 and will be removed"):
        _Connector(accelerator="cpu", gpus=1)
        _Connector(accelerator="cuda", devices=1)


@mock.patch("lightning_lite.accelerators.cuda.num_cuda_devices", return_value=2)
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test deprecated functionality which will be removed in v1.10.0."""
from re import escape
from unittest import mock

import numpy
@@ -19,10 +20,13 @@ import pytest
import torch
from torch.utils.data import DataLoader

from lightning_lite.accelerators import CUDAAccelerator as LiteCUDAAccelerator
from lightning_lite.accelerators import TPUAccelerator as LiteTPUAccelerator
from pytorch_lightning import Trainer
from pytorch_lightning.accelerators.cpu import CPUAccelerator
from pytorch_lightning.core.mixins.device_dtype_mixin import DeviceDtypeModuleMixin
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
from pytorch_lightning.lite import LightningLite
from pytorch_lightning.overrides import LightningDistributedModule, LightningParallelModule
from pytorch_lightning.overrides.base import unwrap_lightning_module
from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel, unwrap_lightning_module_sharded
@@ -260,3 +264,24 @@ def test_v1_10_deprecated_seed_utilities():
def test_v1_10_deprecated_accelerator_setup_environment_method():
    with pytest.deprecated_call(match="`Accelerator.setup_environment` has been deprecated in deprecated in v1.8.0"):
        CPUAccelerator().setup_environment(torch.device("cpu"))


class EmptyLite(LightningLite):
    def run(self):
        pass


def test_lite_convert_deprecated_gpus_argument(cuda_count_2):
    with pytest.deprecated_call(match=escape("Setting `Lite(gpus=2)` is deprecated in v1.8.0")):
        lite = EmptyLite(gpus=2)
    assert isinstance(lite._accelerator, LiteCUDAAccelerator)
    assert lite._connector._parallel_devices == [torch.device("cuda", 0), torch.device("cuda", 1)]


@RunIf(skip_windows=True)
@mock.patch("lightning_lite.accelerators.TPUAccelerator.is_available", return_value=True)
def test_lite_convert_deprecated_tpus_argument(*_):
    with pytest.deprecated_call(match=escape("Setting `Lite(tpu_cores=8)` is deprecated in v1.8.0")):
        lite = EmptyLite(tpu_cores=8)
    assert isinstance(lite._accelerator, LiteTPUAccelerator)
    assert lite._connector._parallel_devices == list(range(8))
@@ -0,0 +1,79 @@
from re import escape

import pytest

from lightning_lite.accelerators import CPUAccelerator as LiteCPUAccelerator
from lightning_lite.plugins import DoublePrecision as LiteDoublePrecision
from lightning_lite.plugins import Precision as LitePrecision
from lightning_lite.plugins.environments import SLURMEnvironment
from lightning_lite.strategies import DDPStrategy as LiteDDPStrategy
from lightning_lite.strategies import DeepSpeedStrategy as LiteDeepSpeedStrategy
from lightning_lite.strategies import SingleDeviceStrategy as LiteSingleDeviceStrategy
from pytorch_lightning.accelerators import CUDAAccelerator as PLCUDAAccelerator
from pytorch_lightning.lite import LightningLite
from pytorch_lightning.plugins import DoublePrecisionPlugin as PLDoublePrecisionPlugin
from pytorch_lightning.plugins import PrecisionPlugin as PLPrecisionPlugin
from pytorch_lightning.strategies import DDPStrategy as PLDDPStrategy
from pytorch_lightning.strategies import DeepSpeedStrategy as PLDeepSpeedStrategy
from tests_pytorch.helpers.runif import RunIf


class EmptyLite(LightningLite):
    def run(self):
        pass


def test_lite_convert_pl_strategies_and_plugins(cuda_count_2):
    """Tests a few examples of passing PL-accelerators/strategies/plugins to the soon deprecated PL version of
    Lightning Lite for backward compatibility.

    Not all possible combinations of input arguments are tested.
    """

    # defaults
    lite = EmptyLite()
    assert isinstance(lite._accelerator, LiteCPUAccelerator)
    assert isinstance(lite._precision_plugin, LitePrecision)
    assert isinstance(lite._strategy, LiteSingleDeviceStrategy)

    # accelerator and strategy passed separately
    lite = EmptyLite(accelerator=PLCUDAAccelerator(), strategy=PLDDPStrategy())
    assert isinstance(lite._accelerator, PLCUDAAccelerator)
    assert isinstance(lite._precision_plugin, LitePrecision)
    assert isinstance(lite._strategy, LiteDDPStrategy)

    # accelerator passed to strategy
    lite = EmptyLite(strategy=PLDDPStrategy(accelerator=PLCUDAAccelerator()))
    assert isinstance(lite._accelerator, PLCUDAAccelerator)
    assert isinstance(lite._strategy, LiteDDPStrategy)

    # kwargs passed to strategy
    lite = EmptyLite(strategy=PLDDPStrategy(find_unused_parameters=False))
    assert isinstance(lite._strategy, LiteDDPStrategy)
    assert lite._strategy._ddp_kwargs == dict(find_unused_parameters=False)

    # plugins = instance
    lite = EmptyLite(plugins=PLDoublePrecisionPlugin())
    assert isinstance(lite._precision_plugin, LiteDoublePrecision)

    # plugins = list
    lite = EmptyLite(plugins=[PLDoublePrecisionPlugin(), SLURMEnvironment()], devices=2)
    assert isinstance(lite._precision_plugin, LiteDoublePrecision)
    assert isinstance(lite._strategy.cluster_environment, SLURMEnvironment)


def test_lite_convert_custom_plugin():
    class CustomPrecisionPlugin(PLPrecisionPlugin):
        pass

    with pytest.raises(TypeError, match=escape("You passed an unsupported plugin as input to Lite(plugins=...)")):
        EmptyLite(plugins=CustomPrecisionPlugin())


@RunIf(deepspeed=True)
def test_lite_convert_pl_strategies_deepspeed():
    lite = EmptyLite(strategy=PLDeepSpeedStrategy(stage=2, initial_scale_power=32, loss_scale_window=500))
    assert isinstance(lite._strategy, LiteDeepSpeedStrategy)
    assert lite._strategy.config["zero_optimization"]["stage"] == 2
    assert lite._strategy.initial_scale_power == 32
    assert lite._strategy.loss_scale_window == 500