Add backward-compatibility for LightningLite in PL (#14735)

This commit is contained in:
awaelchli 2022-09-19 16:07:16 +02:00 committed by Adrian Wälchli
parent e3e71670e6
commit c0ff7a1b77
14 changed files with 441 additions and 83 deletions

View File

@ -38,10 +38,10 @@ import torchvision.transforms as T
from torch.optim.lr_scheduler import StepLR
from torchmetrics.classification import Accuracy
from lightning_lite.lite import LightningLite # import LightningLite
from pytorch_lightning import seed_everything
from pytorch_lightning.demos.boring_classes import Net
from pytorch_lightning.demos.mnist_datamodule import MNIST
from pytorch_lightning.lite import LightningLite # import LightningLite
DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets")

View File

@ -34,10 +34,10 @@ import torchvision.transforms as T
from torch.optim.lr_scheduler import StepLR
from torchmetrics import Accuracy
from lightning_lite.lite import LightningLite
from pytorch_lightning import seed_everything
from pytorch_lightning.demos.boring_classes import Net
from pytorch_lightning.demos.mnist_datamodule import MNIST
from pytorch_lightning.lite import LightningLite
DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets")

View File

@ -22,10 +22,10 @@ import torchvision.transforms as T
from torch.optim.lr_scheduler import StepLR
from torchmetrics import Accuracy
from lightning_lite.lite import LightningLite
from pytorch_lightning import seed_everything
from pytorch_lightning.demos.boring_classes import Net
from pytorch_lightning.demos.mnist_datamodule import MNIST
from pytorch_lightning.lite import LightningLite
from pytorch_lightning.loops import Loop
DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets")

View File

@ -52,11 +52,11 @@ from lightning_lite.strategies import (
XLAStrategy,
)
from lightning_lite.strategies.ddp_spawn import _DDP_FORK_ALIASES
from lightning_lite.utilities import _StrategyType, rank_zero_deprecation, rank_zero_info, rank_zero_warn
from lightning_lite.utilities import _StrategyType, rank_zero_info, rank_zero_warn
from lightning_lite.utilities.device_parser import determine_root_gpu_device
from lightning_lite.utilities.imports import _HPU_AVAILABLE, _IPU_AVAILABLE, _IS_INTERACTIVE, _TPU_AVAILABLE
_PLUGIN = Union[Strategy, Precision, ClusterEnvironment, CheckpointIO]
_PLUGIN = Union[Precision, ClusterEnvironment, CheckpointIO]
_PLUGIN_INPUT = Union[_PLUGIN, str]
@ -99,8 +99,6 @@ class _Connector:
num_nodes: int = 1,
precision: Union[int, str] = 32,
plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None,
tpu_cores: Optional[Union[List[int], str, int]] = None, # deprecated
gpus: Optional[Union[List[int], str, int]] = None, # deprecated
) -> None:
# 1. Parsing flags
# Get registered strategies, built-in accelerators and precision plugins
@ -125,9 +123,7 @@ class _Connector:
precision=precision,
plugins=plugins,
)
self._check_device_config_and_set_final_flags(
devices=devices, num_nodes=num_nodes, gpus=gpus, tpu_cores=tpu_cores
)
self._check_device_config_and_set_final_flags(devices=devices, num_nodes=num_nodes)
# 2. Instantiate Accelerator
# handle `auto`, `None` and `gpu`
@ -278,11 +274,7 @@ class _Connector:
self._parallel_devices = self._strategy_flag.parallel_devices
def _check_device_config_and_set_final_flags(
self,
devices: Optional[Union[List[int], str, int]],
num_nodes: int,
gpus: Optional[Union[List[int], str, int]],
tpu_cores: Optional[Union[List[int], str, int]],
self, devices: Optional[Union[List[int], str, int]], num_nodes: int
) -> None:
self._num_nodes_flag = int(num_nodes) if num_nodes is not None else 1
self._devices_flag = devices
@ -298,56 +290,12 @@ class _Connector:
f" using {accelerator_name} accelerator."
)
# TODO: Delete this method when num_processes, gpus, ipus and tpu_cores gets removed
self._map_deprecated_devices_specific_info_to_accelerator_and_device_flag(devices, gpus, tpu_cores)
if self._devices_flag == "auto" and self._accelerator_flag is None:
raise ValueError(
f"You passed `devices={devices}` but haven't specified"
" `accelerator=('auto'|'tpu'|'gpu'|'cpu'|'mps')` for the devices mapping."
)
def _map_deprecated_devices_specific_info_to_accelerator_and_device_flag(
self,
devices: Optional[Union[List[int], str, int]],
gpus: Optional[Union[List[int], str, int]],
tpu_cores: Optional[Union[List[int], str, int]],
) -> None:
"""Emit deprecation warnings for num_processes, gpus, ipus, tpu_cores and set the `devices_flag` and
`accelerator_flag`."""
if gpus is not None:
rank_zero_deprecation(
f"Setting `Lite(gpus={gpus!r})` is deprecated in v1.7 and will be removed"
f" in v2.0. Please use `Lite(accelerator='gpu', devices={gpus!r})` instead."
)
if tpu_cores is not None:
rank_zero_deprecation(
f"Setting `Lite(tpu_cores={tpu_cores!r})` is deprecated in v1.7 and will be removed"
f" in v2.0. Please use `Lite(accelerator='tpu', devices={tpu_cores!r})` instead."
)
self._gpus: Optional[Union[List[int], str, int]] = gpus
self._tpu_cores: Optional[Union[List[int], str, int]] = tpu_cores
deprecated_devices_specific_flag = gpus or tpu_cores
if deprecated_devices_specific_flag and deprecated_devices_specific_flag not in ([], 0, "0"):
if devices:
# TODO: improve error message
rank_zero_warn(
f"The flag `devices={devices}` will be ignored, "
f"instead the device specific number {deprecated_devices_specific_flag} will be used"
)
if [(gpus is not None), (tpu_cores is not None)].count(True) > 1:
# TODO: improve error message
rank_zero_warn("more than one device specific flag has been set")
self._devices_flag = deprecated_devices_specific_flag
if self._accelerator_flag is None:
# set accelerator type based on num_processes, gpus, ipus, tpu_cores
if tpu_cores:
self._accelerator_flag = "tpu"
if gpus:
self._accelerator_flag = "cuda"
def _choose_auto_accelerator(self) -> str:
"""Choose the accelerator type (str) based on availability when ``accelerator='auto'``."""
if self._accelerator_flag == "auto":
@ -392,9 +340,6 @@ class _Connector:
self._set_devices_flag_if_auto_passed()
self._gpus = self._devices_flag if not self._gpus else self._gpus
self._tpu_cores = self._devices_flag if not self._tpu_cores else self._tpu_cores
self._devices_flag = self.accelerator.parse_devices(self._devices_flag)
if not self._parallel_devices:
self._parallel_devices = self.accelerator.get_parallel_devices(self._devices_flag)

View File

@ -64,8 +64,6 @@ class LightningLite(ABC):
precision: Double precision (``64``), full precision (``32``), half precision (``16``),
or bfloat16 precision (``"bf16"``).
plugins: One or several custom plugins
gpus: Provides the same function as the ``devices`` argument but implies ``accelerator="gpu"``.
tpu_cores: Provides the same function as the ``devices`` argument but implies ``accelerator="tpu"``.
"""
def __init__(
@ -76,8 +74,6 @@ class LightningLite(ABC):
num_nodes: int = 1,
precision: Union[int, str] = 32,
plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None,
gpus: Optional[Union[List[int], str, int]] = None,
tpu_cores: Optional[Union[List[int], str, int]] = None,
) -> None:
self._connector = _Connector(
accelerator=accelerator,
@ -86,8 +82,6 @@ class LightningLite(ABC):
num_nodes=num_nodes,
precision=precision,
plugins=plugins,
tpu_cores=tpu_cores,
gpus=gpus,
)
self._strategy: Strategy = self._connector.strategy
self._accelerator: Accelerator = self._connector.accelerator

View File

@ -54,7 +54,6 @@ class XLAStrategy(DDPSpawnStrategy):
parallel_devices: Optional[List[torch.device]] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision_plugin: Optional[Precision] = None,
**_: Any,
) -> None:
super().__init__(
accelerator=accelerator,

View File

@ -261,7 +261,9 @@ class EarlyStopping(Callback):
@staticmethod
def _log_info(trainer: Optional["pl.Trainer"], message: str, log_rank_zero_only: bool) -> None:
rank = _get_rank(strategy=(trainer.strategy if trainer is not None else None)) # type: ignore[arg-type]
rank = _get_rank(
strategy=(trainer.strategy if trainer is not None else None), # type: ignore[arg-type]
)
if trainer is not None and trainer.world_size <= 1:
rank = None
message = rank_prefixed_message(message, rank)

View File

@ -0,0 +1,17 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.lite.lite import LightningLite
__all__ = ["LightningLite"]

View File

@ -0,0 +1,308 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC
from typing import List, Optional, Tuple, Union
from lightning_utilities.core.rank_zero import rank_zero_deprecation, rank_zero_warn
from lightning_lite.connector import _PLUGIN_INPUT as _LITE_PLUGIN_INPUT
from lightning_lite.lite import LightningLite as _NewLightningLite
from lightning_lite.plugins import CheckpointIO, ClusterEnvironment
from lightning_lite.plugins import DeepSpeedPrecision as LiteDeepSpeedPrecision
from lightning_lite.plugins import DoublePrecision as LiteDoublePrecision
from lightning_lite.plugins import NativeMixedPrecision as LiteNativeMixedPrecision
from lightning_lite.plugins import Precision as LitePrecision
from lightning_lite.plugins import TPUBf16Precision as LiteTPUBf16Precision
from lightning_lite.plugins import TPUPrecision as LiteTPUPrecision
from lightning_lite.strategies import DataParallelStrategy as LiteDataParallelStrategy
from lightning_lite.strategies import DDPShardedStrategy as LiteDDPShardedStrategy
from lightning_lite.strategies import DDPSpawnShardedStrategy as LiteDDPSpawnShardedStrategy
from lightning_lite.strategies import DDPSpawnStrategy as LiteDDPSpawnStrategy
from lightning_lite.strategies import DDPStrategy as LiteDDPStrategy
from lightning_lite.strategies import DeepSpeedStrategy as LiteDeepSpeedStrategy
from lightning_lite.strategies import SingleDeviceStrategy as LiteSingleDeviceStrategy
from lightning_lite.strategies import SingleTPUStrategy as LiteSingleTPUStrategy
from lightning_lite.strategies import Strategy as LiteStrategy
from lightning_lite.strategies import XLAStrategy
from pytorch_lightning.accelerators import Accelerator as PLAccelerator
from pytorch_lightning.plugins import DeepSpeedPrecisionPlugin as PLDeepSpeedPrecisionPlugin
from pytorch_lightning.plugins import DoublePrecisionPlugin as PLDoublePrecisionPlugin
from pytorch_lightning.plugins import NativeMixedPrecisionPlugin as PLNativeMixedPrecisionPlugin
from pytorch_lightning.plugins import PrecisionPlugin as PLPrecisionPlugin
from pytorch_lightning.plugins import TPUBf16PrecisionPlugin as PLTPUBf16PrecisionPlugin
from pytorch_lightning.plugins import TPUPrecisionPlugin as PLTPUPrecisionPlugin
from pytorch_lightning.strategies import DataParallelStrategy as PLDataParallelStrategy
from pytorch_lightning.strategies import DDPShardedStrategy as PLDDPShardedStrategy
from pytorch_lightning.strategies import DDPSpawnShardedStrategy as PLDDPSpawnShardedStrategy
from pytorch_lightning.strategies import DDPSpawnStrategy as PLDDPSpawnStrategy
from pytorch_lightning.strategies import DDPStrategy as PLDDPStrategy
from pytorch_lightning.strategies import DeepSpeedStrategy as PLDeepSpeedStrategy
from pytorch_lightning.strategies import SingleDeviceStrategy as PLSingleDeviceStrategy
from pytorch_lightning.strategies import SingleTPUStrategy as PLSingleTPUStrategy
from pytorch_lightning.strategies import Strategy as PLStrategy
from pytorch_lightning.strategies import TPUSpawnStrategy as PLTPUSpawnStrategy
_PL_PLUGIN = Union[PLPrecisionPlugin, ClusterEnvironment, CheckpointIO]
_PL_PLUGIN_INPUT = Union[_PL_PLUGIN, str]
class LightningLite(_NewLightningLite, ABC):
"""Lite accelerates your PyTorch training or inference code with minimal changes required.
- Automatic placement of models and data onto the device.
- Automatic support for mixed and double precision (smaller memory footprint).
- Seamless switching between hardware (CPU, GPU, TPU) and distributed training strategies
(data-parallel training, sharded training, etc.).
- Automated spawning of processes, no launch utilities required.
- Multi-node support.
Args:
accelerator: The hardware to run on. Possible choices are:
``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
strategy: Strategy for how to run across multiple devices. Possible choices are:
``"dp"``, ``"ddp"``, ``"ddp_spawn"``, ``"deepspeed"``, ``"ddp_sharded"``.
devices: Number of devices to train on (``int``), which GPUs to train on (``list`` or ``str``), or ``"auto"``.
The value applies per node.
num_nodes: Number of GPU nodes for distributed training.
precision: Double precision (``64``), full precision (``32``), half precision (``16``),
or bfloat16 precision (``"bf16"``).
plugins: One or several custom plugins
gpus: Provides the same function as the ``devices`` argument but implies ``accelerator="gpu"``.
.. deprecated:: v1.8.0
``gpus`` has been deprecated in v1.8.0 and will be removed in v1.10.0.
Please use ``accelerator='gpu'`` and ``devices=x`` instead.
tpu_cores: Provides the same function as the ``devices`` argument but implies ``accelerator="tpu"``.
.. deprecated:: v1.8.0
``tpu_cores`` has been deprecated in v1.8.0 and will be removed in v1.10.0.
Please use ``accelerator='tpu'`` and ``devices=x`` instead.
"""
def __init__(
self,
accelerator: Optional[Union[str, PLAccelerator]] = None,
strategy: Optional[Union[str, PLStrategy]] = None,
devices: Optional[Union[List[int], str, int]] = None,
num_nodes: int = 1,
precision: Union[int, str] = 32,
plugins: Optional[Union[_PL_PLUGIN_INPUT, List[_PL_PLUGIN_INPUT]]] = None,
gpus: Optional[Union[List[int], str, int]] = None,
tpu_cores: Optional[Union[List[int], str, int]] = None,
) -> None:
if gpus is not None or tpu_cores is not None:
devices, accelerator = _convert_deprecated_device_flags(
accelerator=accelerator,
devices=devices,
gpus=gpus,
tpu_cores=tpu_cores,
)
lite_plugins: Optional[Union[_LITE_PLUGIN_INPUT, List[_LITE_PLUGIN_INPUT]]]
if isinstance(plugins, PLPrecisionPlugin):
lite_plugins = _to_lite_precision_plugin(plugins)
elif isinstance(plugins, list):
lite_plugins = [
_to_lite_precision_plugin(plugin) if isinstance(plugin, PLPrecisionPlugin) else plugin
for plugin in plugins
]
else:
lite_plugins = plugins
super().__init__(
accelerator=accelerator,
strategy=(_to_lite_strategy(strategy) if isinstance(strategy, PLStrategy) else strategy),
devices=devices,
num_nodes=num_nodes,
precision=precision,
plugins=lite_plugins,
)
def _convert_deprecated_device_flags(
accelerator: Optional[Union[str, PLAccelerator]],
devices: Optional[Union[List[int], str, int]],
gpus: Optional[Union[List[int], str, int]],
tpu_cores: Optional[Union[List[int], str, int]],
) -> Tuple[Optional[Union[List[int], str, int]], Optional[Union[str, PLAccelerator]]]:
"""Emit deprecation warnings for gpus and tpu_cores and translate them into the new accelerator and devices.
Similar implementation as in ``pytorch_lightning.trainer.connectors.accelerator_connector``.
"""
if gpus is not None:
rank_zero_deprecation(
f"Setting `Lite(gpus={gpus!r})` is deprecated in v1.8.0 and will be removed"
f" in v1.10.0. Please use `Lite(accelerator='gpu', devices={gpus!r})` instead."
)
if tpu_cores is not None:
rank_zero_deprecation(
f"Setting `Lite(tpu_cores={tpu_cores!r})` is deprecated in v1.8.0 and will be removed"
f" in v1.10.0. Please use `Lite(accelerator='tpu', devices={tpu_cores!r})` instead."
)
deprecated_devices_specific_flag = gpus or tpu_cores
if deprecated_devices_specific_flag and deprecated_devices_specific_flag not in ([], 0, "0"):
if devices:
rank_zero_warn(
f"The option `devices={devices}` will be ignored and the device specific number"
f"{deprecated_devices_specific_flag} will be used instead."
)
if gpus is not None and tpu_cores is not None:
rank_zero_warn(
f"Both `Lite(gpus={gpus!r}, tpu_cores={tpu_cores!r})` were specified. Please choose only one of"
" the two."
)
if accelerator is None:
if tpu_cores:
accelerator = "tpu"
if gpus:
accelerator = "cuda"
return deprecated_devices_specific_flag, accelerator
def _to_lite_strategy(strategy: PLStrategy) -> LiteStrategy:
"""Re-instantiates a PL-Strategy as the corresponding Lite-Strategy."""
if type(strategy) is PLDDPStrategy:
return LiteDDPStrategy(
accelerator=strategy.accelerator,
parallel_devices=strategy.parallel_devices,
cluster_environment=strategy.cluster_environment,
checkpoint_io=strategy.checkpoint_io,
precision_plugin=_to_lite_precision_plugin(strategy.precision_plugin),
process_group_backend=strategy.process_group_backend,
timeout=strategy._timeout,
**strategy._ddp_kwargs,
)
if type(strategy) is PLDDPSpawnStrategy:
return LiteDDPSpawnStrategy(
accelerator=strategy.accelerator,
parallel_devices=strategy.parallel_devices,
cluster_environment=strategy.cluster_environment,
checkpoint_io=strategy.checkpoint_io,
precision_plugin=_to_lite_precision_plugin(strategy.precision_plugin),
process_group_backend=strategy.process_group_backend,
timeout=strategy._timeout,
start_method=strategy._start_method,
**strategy._ddp_kwargs,
)
if type(strategy) is PLTPUSpawnStrategy:
return XLAStrategy(
accelerator=strategy.accelerator,
parallel_devices=strategy.parallel_devices,
checkpoint_io=strategy.checkpoint_io,
precision_plugin=_to_lite_precision_plugin(strategy.precision_plugin),
)
if type(strategy) is PLDeepSpeedStrategy:
return LiteDeepSpeedStrategy(
accelerator=strategy.accelerator,
parallel_devices=strategy.parallel_devices,
cluster_environment=strategy.cluster_environment,
precision_plugin=_to_lite_precision_plugin(strategy.precision_plugin),
process_group_backend=strategy.process_group_backend,
config=strategy.config,
remote_device=strategy.remote_device,
load_full_weights=strategy.load_full_weights,
loss_scale=strategy.loss_scale,
initial_scale_power=strategy.initial_scale_power,
loss_scale_window=strategy.loss_scale_window,
hysteresis=strategy.hysteresis,
min_loss_scale=strategy.min_loss_scale,
)
if type(strategy) is PLDataParallelStrategy:
return LiteDataParallelStrategy(
accelerator=strategy.accelerator,
parallel_devices=strategy.parallel_devices,
checkpoint_io=strategy.checkpoint_io,
precision_plugin=_to_lite_precision_plugin(strategy.precision_plugin),
)
if type(strategy) is PLDDPShardedStrategy:
return LiteDDPShardedStrategy(
accelerator=strategy.accelerator,
parallel_devices=strategy.parallel_devices,
cluster_environment=strategy.cluster_environment,
checkpoint_io=strategy.checkpoint_io,
precision_plugin=_to_lite_precision_plugin(strategy.precision_plugin),
process_group_backend=strategy.process_group_backend,
timeout=strategy._timeout,
**strategy._ddp_kwargs,
)
if type(strategy) is PLDDPSpawnShardedStrategy:
return LiteDDPSpawnShardedStrategy(
accelerator=strategy.accelerator,
parallel_devices=strategy.parallel_devices,
cluster_environment=strategy.cluster_environment,
checkpoint_io=strategy.checkpoint_io,
precision_plugin=_to_lite_precision_plugin(strategy.precision_plugin),
process_group_backend=strategy.process_group_backend,
timeout=strategy._timeout,
start_method=strategy._start_method,
**strategy._ddp_kwargs,
)
if type(strategy) is PLSingleDeviceStrategy:
return LiteSingleDeviceStrategy(
device=strategy.root_device,
accelerator=strategy.accelerator,
checkpoint_io=strategy.checkpoint_io,
precision_plugin=_to_lite_precision_plugin(strategy.precision_plugin),
)
if type(strategy) is PLSingleTPUStrategy:
return LiteSingleTPUStrategy(
device=strategy.root_device.index,
accelerator=strategy.accelerator,
checkpoint_io=strategy.checkpoint_io,
precision_plugin=_to_lite_precision_plugin(strategy.precision_plugin),
)
def _to_lite_precision_plugin(plugin: Optional[PLPrecisionPlugin]) -> LitePrecision:
"""Re-instantiates a PL-PrecisionPlugin as the corresponding Lite-Precision plugin."""
if type(plugin) is PLPrecisionPlugin:
return LitePrecision()
if type(plugin) is PLNativeMixedPrecisionPlugin:
return LiteNativeMixedPrecision(precision=plugin.precision, device=plugin.device, scaler=plugin.scaler)
if type(plugin) is PLDoublePrecisionPlugin:
return LiteDoublePrecision()
if type(plugin) is PLDeepSpeedPrecisionPlugin:
return LiteDeepSpeedPrecision(precision=plugin.precision, amp_type=plugin.amp_type, amp_level=plugin.amp_level)
if type(plugin) is PLTPUPrecisionPlugin:
return LiteTPUPrecision()
if type(plugin) is PLTPUBf16PrecisionPlugin:
return LiteTPUBf16Precision()
# No backward compatibility for custom plugins / subclasses, as we can't re-instantiate these plugins
raise TypeError(
"You passed an unsupported plugin as input to Lite(plugins=...) or to a strategy. If you built a custom plugin,"
" please change it to subclass the `lightning_lite.plugins.precision.Precision` class. Otherwise, please open"
" an issue on the Lightning GitHub repository with your use case."
)

View File

@ -16,9 +16,8 @@ from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
from pytorch_lightning.plugins.precision.tpu import TPUPrecisionPlugin
from pytorch_lightning.plugins.precision.tpu_bf16 import TPUBf16PrecisionPlugin
from pytorch_lightning.strategies.strategy import Strategy
PLUGIN = Union[Strategy, PrecisionPlugin, ClusterEnvironment, CheckpointIO, LayerSync]
PLUGIN = Union[PrecisionPlugin, ClusterEnvironment, CheckpointIO, LayerSync]
PLUGIN_INPUT = Union[PLUGIN, str]
__all__ = [

View File

@ -262,21 +262,11 @@ def test_accelerator_cpu(*_):
connector = _Connector(accelerator="cpu")
assert isinstance(connector.accelerator, CPUAccelerator)
with pytest.raises(
RuntimeError,
match="CUDAAccelerator can not run on your system since the accelerator is not available",
):
with pytest.deprecated_call(match=r"is deprecated in v1.7 and will be removed"):
_Connector(gpus=1)
with pytest.raises(
RuntimeError,
match="CUDAAccelerator can not run on your system since the accelerator is not available.",
):
_Connector(accelerator="cuda")
with pytest.deprecated_call(match=r"is deprecated in v1.7 and will be removed"):
_Connector(accelerator="cpu", gpus=1)
_Connector(accelerator="cuda", devices=1)
@mock.patch("lightning_lite.accelerators.cuda.num_cuda_devices", return_value=2)

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test deprecated functionality which will be removed in v1.10.0."""
from re import escape
from unittest import mock
import numpy
@ -19,10 +20,13 @@ import pytest
import torch
from torch.utils.data import DataLoader
from lightning_lite.accelerators import CUDAAccelerator as LiteCUDAAccelerator
from lightning_lite.accelerators import TPUAccelerator as LiteTPUAccelerator
from pytorch_lightning import Trainer
from pytorch_lightning.accelerators.cpu import CPUAccelerator
from pytorch_lightning.core.mixins.device_dtype_mixin import DeviceDtypeModuleMixin
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
from pytorch_lightning.lite import LightningLite
from pytorch_lightning.overrides import LightningDistributedModule, LightningParallelModule
from pytorch_lightning.overrides.base import unwrap_lightning_module
from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel, unwrap_lightning_module_sharded
@ -260,3 +264,24 @@ def test_v1_10_deprecated_seed_utilities():
def test_v1_10_deprecated_accelerator_setup_environment_method():
with pytest.deprecated_call(match="`Accelerator.setup_environment` has been deprecated in deprecated in v1.8.0"):
CPUAccelerator().setup_environment(torch.device("cpu"))
class EmptyLite(LightningLite):
def run(self):
pass
def test_lite_convert_deprecated_gpus_argument(cuda_count_2):
with pytest.deprecated_call(match=escape("Setting `Lite(gpus=2)` is deprecated in v1.8.0")):
lite = EmptyLite(gpus=2)
assert isinstance(lite._accelerator, LiteCUDAAccelerator)
assert lite._connector._parallel_devices == [torch.device("cuda", 0), torch.device("cuda", 1)]
@RunIf(skip_windows=True)
@mock.patch("lightning_lite.accelerators.TPUAccelerator.is_available", return_value=True)
def test_lite_convert_deprecated_tpus_argument(*_):
with pytest.deprecated_call(match=escape("Setting `Lite(tpu_cores=8)` is deprecated in v1.8.0")):
lite = EmptyLite(tpu_cores=8)
assert isinstance(lite._accelerator, LiteTPUAccelerator)
assert lite._connector._parallel_devices == list(range(8))

View File

View File

@ -0,0 +1,79 @@
from re import escape
import pytest
from lightning_lite.accelerators import CPUAccelerator as LiteCPUAccelerator
from lightning_lite.plugins import DoublePrecision as LiteDoublePrecision
from lightning_lite.plugins import Precision as LitePrecision
from lightning_lite.plugins.environments import SLURMEnvironment
from lightning_lite.strategies import DDPStrategy as LiteDDPStrategy
from lightning_lite.strategies import DeepSpeedStrategy as LiteDeepSpeedStrategy
from lightning_lite.strategies import SingleDeviceStrategy as LiteSingleDeviceStrategy
from pytorch_lightning.accelerators import CUDAAccelerator as PLCUDAAccelerator
from pytorch_lightning.lite import LightningLite
from pytorch_lightning.plugins import DoublePrecisionPlugin as PLDoublePrecisionPlugin
from pytorch_lightning.plugins import PrecisionPlugin as PLPrecisionPlugin
from pytorch_lightning.strategies import DDPStrategy as PLDDPStrategy
from pytorch_lightning.strategies import DeepSpeedStrategy as PLDeepSpeedStrategy
from tests_pytorch.helpers.runif import RunIf
class EmptyLite(LightningLite):
def run(self):
pass
def test_lite_convert_pl_strategies_and_plugins(cuda_count_2):
"""Tests a few examples of passing PL-accelerators/strategies/plugins to the soon deprecated PL version of
Lightning Lite for backward compatibility.
Not all possible combinations of input arguments are tested.
"""
# defaults
lite = EmptyLite()
assert isinstance(lite._accelerator, LiteCPUAccelerator)
assert isinstance(lite._precision_plugin, LitePrecision)
assert isinstance(lite._strategy, LiteSingleDeviceStrategy)
# accelerator and strategy passed separately
lite = EmptyLite(accelerator=PLCUDAAccelerator(), strategy=PLDDPStrategy())
assert isinstance(lite._accelerator, PLCUDAAccelerator)
assert isinstance(lite._precision_plugin, LitePrecision)
assert isinstance(lite._strategy, LiteDDPStrategy)
# accelerator passed to strategy
lite = EmptyLite(strategy=PLDDPStrategy(accelerator=PLCUDAAccelerator()))
assert isinstance(lite._accelerator, PLCUDAAccelerator)
assert isinstance(lite._strategy, LiteDDPStrategy)
# kwargs passed to strategy
lite = EmptyLite(strategy=PLDDPStrategy(find_unused_parameters=False))
assert isinstance(lite._strategy, LiteDDPStrategy)
assert lite._strategy._ddp_kwargs == dict(find_unused_parameters=False)
# plugins = instance
lite = EmptyLite(plugins=PLDoublePrecisionPlugin())
assert isinstance(lite._precision_plugin, LiteDoublePrecision)
# plugins = list
lite = EmptyLite(plugins=[PLDoublePrecisionPlugin(), SLURMEnvironment()], devices=2)
assert isinstance(lite._precision_plugin, LiteDoublePrecision)
assert isinstance(lite._strategy.cluster_environment, SLURMEnvironment)
def test_lite_convert_custom_plugin():
class CustomPrecisionPlugin(PLPrecisionPlugin):
pass
with pytest.raises(TypeError, match=escape("You passed an unsupported plugin as input to Lite(plugins=...)")):
EmptyLite(plugins=CustomPrecisionPlugin())
@RunIf(deepspeed=True)
def test_lite_convert_pl_strategies_deepspeed():
lite = EmptyLite(strategy=PLDeepSpeedStrategy(stage=2, initial_scale_power=32, loss_scale_window=500))
assert isinstance(lite._strategy, LiteDeepSpeedStrategy)
assert lite._strategy.config["zero_optimization"]["stage"] == 2
assert lite._strategy.initial_scale_power == 32
assert lite._strategy.loss_scale_window == 500