From e0b4bb2ea34ea2ade517da2b9a4cdbb7d97e3de0 Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Thu, 25 Nov 2021 21:11:03 +0530 Subject: [PATCH] Deprecate `DeviceType` in favor of `_AcceleratorType` (#10503) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- CHANGELOG.md | 3 + .../callbacks/gpu_stats_monitor.py | 4 +- .../callbacks/xla_stats_monitor.py | 4 +- pytorch_lightning/lite/lite.py | 10 +- .../loops/optimization/optimizer_loop.py | 4 +- .../connectors/accelerator_connector.py | 98 +++++++++---------- .../logger_connector/logger_connector.py | 4 +- pytorch_lightning/trainer/trainer.py | 20 ++-- pytorch_lightning/utilities/__init__.py | 2 +- pytorch_lightning/utilities/enums.py | 40 ++++++-- pytorch_lightning/utilities/model_summary.py | 8 +- .../test_accelerator_connector.py | 4 +- tests/accelerators/test_ipu.py | 4 +- tests/deprecated_api/test_remove_1-8.py | 8 +- tests/models/test_tpu.py | 4 +- tests/trainer/test_trainer.py | 98 ++++++++++--------- tests/utilities/test_enums.py | 14 +-- 17 files changed, 185 insertions(+), 144 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a87b3b95f0..43c0c6ab14 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -78,6 +78,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated the `precision_plugin` constructor argument from `Accelerator` ([#10570](https://github.com/PyTorchLightning/pytorch-lightning/pull/10570)) +- Deprecated `DeviceType` in favor of `_AcceleratorType` ([#10503](https://github.com/PyTorchLightning/pytorch-lightning/pull/10503)) + + - Deprecated the property `Trainer.slurm_job_id` in favor of the new `SLURMEnvironment.job_id()` method ([#10622](https://github.com/PyTorchLightning/pytorch-lightning/pull/10622)) diff --git a/pytorch_lightning/callbacks/gpu_stats_monitor.py b/pytorch_lightning/callbacks/gpu_stats_monitor.py index 7ee6771056..088c8e6500 100644 --- a/pytorch_lightning/callbacks/gpu_stats_monitor.py +++ b/pytorch_lightning/callbacks/gpu_stats_monitor.py @@ -29,7 +29,7 @@ import torch import pytorch_lightning as pl from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.utilities import DeviceType, rank_zero_deprecation, rank_zero_only +from pytorch_lightning.utilities import _AcceleratorType, rank_zero_deprecation, rank_zero_only from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.parsing import AttributeDict from pytorch_lightning.utilities.types import STEP_OUTPUT @@ -126,7 +126,7 @@ class GPUStatsMonitor(Callback): if not trainer.logger: raise MisconfigurationException("Cannot use GPUStatsMonitor callback with Trainer that has no logger.") - if trainer._device_type != DeviceType.GPU: + if trainer._device_type != _AcceleratorType.GPU: raise MisconfigurationException( "You are using GPUStatsMonitor but are not running on GPU" f" since gpus attribute in Trainer is set to {trainer.gpus}." diff --git a/pytorch_lightning/callbacks/xla_stats_monitor.py b/pytorch_lightning/callbacks/xla_stats_monitor.py index 20d3f1b8ba..9c4f09c08a 100644 --- a/pytorch_lightning/callbacks/xla_stats_monitor.py +++ b/pytorch_lightning/callbacks/xla_stats_monitor.py @@ -21,7 +21,7 @@ Monitor and logs XLA stats during training. import time from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.utilities import _TPU_AVAILABLE, DeviceType, rank_zero_deprecation, rank_zero_info +from pytorch_lightning.utilities import _AcceleratorType, _TPU_AVAILABLE, rank_zero_deprecation, rank_zero_info from pytorch_lightning.utilities.exceptions import MisconfigurationException if _TPU_AVAILABLE: @@ -70,7 +70,7 @@ class XLAStatsMonitor(Callback): if not trainer.logger: raise MisconfigurationException("Cannot use XLAStatsMonitor callback with Trainer that has no logger.") - if trainer._device_type != DeviceType.TPU: + if trainer._device_type != _AcceleratorType.TPU: raise MisconfigurationException( "You are using XLAStatsMonitor but are not running on TPU" f" since `tpu_cores` attribute in Trainer is set to {trainer.tpu_cores}." diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index b2adeeac4b..9073f5dd54 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -28,7 +28,7 @@ from pytorch_lightning.accelerators.accelerator import Accelerator from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer from pytorch_lightning.plugins import DDPSpawnPlugin, DeepSpeedPlugin, PLUGIN_INPUT, TPUSpawnPlugin, TrainingTypePlugin from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector -from pytorch_lightning.utilities import _StrategyType, DeviceType, move_data_to_device +from pytorch_lightning.utilities import _AcceleratorType, _StrategyType, move_data_to_device from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors from pytorch_lightning.utilities.data import ( _auto_add_worker_init_fn, @@ -448,11 +448,11 @@ class LightningLite(ABC): ) @staticmethod - def _supported_device_types() -> Sequence[DeviceType]: + def _supported_device_types() -> Sequence[_AcceleratorType]: return ( - DeviceType.CPU, - DeviceType.GPU, - DeviceType.TPU, + _AcceleratorType.CPU, + _AcceleratorType.GPU, + _AcceleratorType.TPU, ) @staticmethod diff --git a/pytorch_lightning/loops/optimization/optimizer_loop.py b/pytorch_lightning/loops/optimization/optimizer_loop.py index 7050ac75de..b6bc1c3c25 100644 --- a/pytorch_lightning/loops/optimization/optimizer_loop.py +++ b/pytorch_lightning/loops/optimization/optimizer_loop.py @@ -30,7 +30,7 @@ from pytorch_lightning.loops.utilities import ( ) from pytorch_lightning.profiler import BaseProfiler, PassThroughProfiler from pytorch_lightning.trainer.progress import OptimizationProgress -from pytorch_lightning.utilities import AMPType, DeviceType +from pytorch_lightning.utilities import _AcceleratorType, AMPType from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.finite_checks import detect_nan_parameters from pytorch_lightning.utilities.imports import _TPU_AVAILABLE @@ -378,7 +378,7 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]): optimizer, opt_idx, train_step_and_backward_closure, - on_tpu=(self.trainer._device_type == DeviceType.TPU and _TPU_AVAILABLE), + on_tpu=(self.trainer._device_type == _AcceleratorType.TPU and _TPU_AVAILABLE), using_native_amp=(self.trainer.amp_backend is not None and self.trainer.amp_backend == AMPType.NATIVE), using_lbfgs=is_lbfgs, ) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index c95d46e77b..ba1166a019 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -61,10 +61,10 @@ from pytorch_lightning.plugins.environments import ( TorchElasticEnvironment, ) from pytorch_lightning.utilities import ( + _AcceleratorType, _StrategyType, AMPType, device_parser, - DeviceType, rank_zero_deprecation, rank_zero_info, rank_zero_warn, @@ -106,7 +106,7 @@ class AcceleratorConnector: plugins, ): # initialization - self._device_type = DeviceType.CPU + self._device_type = _AcceleratorType.CPU self._distrib_type = None self._accelerator_type = None @@ -199,32 +199,32 @@ class AcceleratorConnector: def select_accelerator_type(self) -> None: if self.distributed_backend == "auto": if self.has_tpu: - self._accelerator_type = DeviceType.TPU + self._accelerator_type = _AcceleratorType.TPU elif self.has_ipu: - self._accelerator_type = DeviceType.IPU + self._accelerator_type = _AcceleratorType.IPU elif self.has_gpu: - self._accelerator_type = DeviceType.GPU + self._accelerator_type = _AcceleratorType.GPU else: self._set_devices_to_cpu_num_processes() - self._accelerator_type = DeviceType.CPU - elif self.distributed_backend == DeviceType.TPU: + self._accelerator_type = _AcceleratorType.CPU + elif self.distributed_backend == _AcceleratorType.TPU: if not self.has_tpu: msg = "TPUs are not available" if not _TPU_AVAILABLE else "you didn't pass `tpu_cores` to `Trainer`" raise MisconfigurationException(f"You passed `accelerator='tpu'`, but {msg}.") - self._accelerator_type = DeviceType.TPU - elif self.distributed_backend == DeviceType.IPU: + self._accelerator_type = _AcceleratorType.TPU + elif self.distributed_backend == _AcceleratorType.IPU: if not self.has_ipu: msg = "IPUs are not available" if not _IPU_AVAILABLE else "you didn't pass `ipus` to `Trainer`" raise MisconfigurationException(f"You passed `accelerator='ipu'`, but {msg}.") - self._accelerator_type = DeviceType.IPU - elif self.distributed_backend == DeviceType.GPU: + self._accelerator_type = _AcceleratorType.IPU + elif self.distributed_backend == _AcceleratorType.GPU: if not self.has_gpu: msg = "you didn't pass `gpus` to `Trainer`" if torch.cuda.is_available() else "GPUs are not available" raise MisconfigurationException(f"You passed `accelerator='gpu'`, but {msg}.") - self._accelerator_type = DeviceType.GPU - elif self.distributed_backend == DeviceType.CPU: + self._accelerator_type = _AcceleratorType.GPU + elif self.distributed_backend == _AcceleratorType.CPU: self._set_devices_to_cpu_num_processes() - self._accelerator_type = DeviceType.CPU + self._accelerator_type = _AcceleratorType.CPU if self.distributed_backend in self.accelerator_types: self.distributed_backend = None @@ -250,29 +250,29 @@ class AcceleratorConnector: if self.devices is None: return devices_warning = f"The flag `devices={self.devices}` will be ignored, as you have set" - if self.distributed_backend in ("auto", DeviceType.TPU): + if self.distributed_backend in ("auto", _AcceleratorType.TPU): if self.tpu_cores is not None: rank_zero_warn(f"{devices_warning} `tpu_cores={self.tpu_cores}`") - elif self.distributed_backend in ("auto", DeviceType.IPU): + elif self.distributed_backend in ("auto", _AcceleratorType.IPU): if self.ipus is not None: rank_zero_warn(f"{devices_warning} `ipus={self.ipus}`") - elif self.distributed_backend in ("auto", DeviceType.GPU): + elif self.distributed_backend in ("auto", _AcceleratorType.GPU): if self.gpus is not None: rank_zero_warn(f"{devices_warning} `gpus={self.gpus}`") - elif self.distributed_backend in ("auto", DeviceType.CPU): + elif self.distributed_backend in ("auto", _AcceleratorType.CPU): if self.num_processes != 1: rank_zero_warn(f"{devices_warning} `num_processes={self.num_processes}`") def _set_devices_if_none(self) -> None: if self.devices is not None: return - if self._accelerator_type == DeviceType.TPU: + if self._accelerator_type == _AcceleratorType.TPU: self.devices = self.tpu_cores - elif self._accelerator_type == DeviceType.IPU: + elif self._accelerator_type == _AcceleratorType.IPU: self.devices = self.ipus - elif self._accelerator_type == DeviceType.GPU: + elif self._accelerator_type == _AcceleratorType.GPU: self.devices = self.gpus - elif self._accelerator_type == DeviceType.CPU: + elif self._accelerator_type == _AcceleratorType.CPU: self.devices = self.num_processes def _handle_accelerator_and_strategy(self) -> None: @@ -386,7 +386,7 @@ class AcceleratorConnector: @property def accelerator_types(self) -> List[str]: - return ["auto"] + list(DeviceType) + return ["auto"] + list(_AcceleratorType) @property def precision_plugin(self) -> PrecisionPlugin: @@ -424,7 +424,7 @@ class AcceleratorConnector: @property def use_cpu(self) -> bool: - return self._accelerator_type == DeviceType.CPU + return self._accelerator_type == _AcceleratorType.CPU @property def has_gpu(self) -> bool: @@ -433,11 +433,11 @@ class AcceleratorConnector: gpus = self.parallel_device_ids if gpus is not None and len(gpus) > 0: return True - return self._map_devices_to_accelerator(DeviceType.GPU) + return self._map_devices_to_accelerator(_AcceleratorType.GPU) @property def use_gpu(self) -> bool: - return self._accelerator_type == DeviceType.GPU and self.has_gpu + return self._accelerator_type == _AcceleratorType.GPU and self.has_gpu @property def has_tpu(self) -> bool: @@ -445,11 +445,11 @@ class AcceleratorConnector: # `tpu_cores` to Trainer for training. if self.tpu_cores is not None: return True - return self._map_devices_to_accelerator(DeviceType.TPU) + return self._map_devices_to_accelerator(_AcceleratorType.TPU) @property def use_tpu(self) -> bool: - return self._accelerator_type == DeviceType.TPU and self.has_tpu + return self._accelerator_type == _AcceleratorType.TPU and self.has_tpu @property def tpu_id(self) -> Optional[int]: @@ -463,36 +463,36 @@ class AcceleratorConnector: # `ipus` to Trainer for training. if self.ipus is not None or isinstance(self._training_type_plugin, IPUPlugin): return True - return self._map_devices_to_accelerator(DeviceType.IPU) + return self._map_devices_to_accelerator(_AcceleratorType.IPU) @property def use_ipu(self) -> bool: - return self._accelerator_type == DeviceType.IPU and self.has_ipu + return self._accelerator_type == _AcceleratorType.IPU and self.has_ipu def _set_devices_to_cpu_num_processes(self) -> None: if self.num_processes == 1: - self._map_devices_to_accelerator(DeviceType.CPU) + self._map_devices_to_accelerator(_AcceleratorType.CPU) def _map_devices_to_accelerator(self, accelerator: str) -> bool: if self.devices is None: return False - if accelerator == DeviceType.TPU and _TPU_AVAILABLE: + if accelerator == _AcceleratorType.TPU and _TPU_AVAILABLE: if self.devices == "auto": self.devices = TPUAccelerator.auto_device_count() self.tpu_cores = device_parser.parse_tpu_cores(self.devices) return True - if accelerator == DeviceType.IPU and _IPU_AVAILABLE: + if accelerator == _AcceleratorType.IPU and _IPU_AVAILABLE: if self.devices == "auto": self.devices = IPUAccelerator.auto_device_count() self.ipus = self.devices return True - if accelerator == DeviceType.GPU and torch.cuda.is_available(): + if accelerator == _AcceleratorType.GPU and torch.cuda.is_available(): if self.devices == "auto": self.devices = GPUAccelerator.auto_device_count() self.gpus = self.devices self.parallel_device_ids = device_parser.parse_gpu_ids(self.devices) return True - if accelerator == DeviceType.CPU: + if accelerator == _AcceleratorType.CPU: if self.devices == "auto": self.devices = CPUAccelerator.auto_device_count() if not isinstance(self.devices, int): @@ -829,7 +829,7 @@ class AcceleratorConnector: if isinstance(self.distributed_backend, Accelerator): return - is_cpu_accelerator_type = self._accelerator_type and self._accelerator_type == DeviceType.CPU + is_cpu_accelerator_type = self._accelerator_type and self._accelerator_type == _AcceleratorType.CPU _use_cpu = is_cpu_accelerator_type or self.distributed_backend and "cpu" in self.distributed_backend if self.distributed_backend is None: @@ -867,16 +867,16 @@ class AcceleratorConnector: self.num_processes = os.cpu_count() # special case with TPUs elif self.has_tpu and not _use_cpu: - self._device_type = DeviceType.TPU + self._device_type = _AcceleratorType.TPU if isinstance(self.tpu_cores, int): self._distrib_type = _StrategyType.TPU_SPAWN elif self.has_ipu and not _use_cpu: - self._device_type = DeviceType.IPU + self._device_type = _AcceleratorType.IPU elif self.distributed_backend and self._distrib_type is None: self._distrib_type = _StrategyType(self.distributed_backend) if self.num_gpus > 0 and not _use_cpu: - self._device_type = DeviceType.GPU + self._device_type = _AcceleratorType.GPU _gpu_distrib_types = (_StrategyType.DP, _StrategyType.DDP, _StrategyType.DDP_SPAWN, _StrategyType.DDP2) # DP and DDP2 cannot run without GPU @@ -896,13 +896,13 @@ class AcceleratorConnector: self.check_interactive_compatibility() # for DDP overwrite nb processes by requested GPUs - if self._device_type == DeviceType.GPU and self._distrib_type in ( + if self._device_type == _AcceleratorType.GPU and self._distrib_type in ( _StrategyType.DDP, _StrategyType.DDP_SPAWN, ): self.num_processes = self.num_gpus - if self._device_type == DeviceType.GPU and self._distrib_type == _StrategyType.DDP2: + if self._device_type == _AcceleratorType.GPU and self._distrib_type == _StrategyType.DDP2: self.num_processes = self.num_nodes # Horovod is an extra case... @@ -965,8 +965,8 @@ class AcceleratorConnector: def update_device_type_if_ipu_plugin(self) -> None: # This allows the poptorch.Options that are passed into the IPUPlugin to be the source of truth, # which gives users the flexibility to not have to pass `ipus` flag directly to Trainer - if isinstance(self._training_type_plugin, IPUPlugin) and self._device_type != DeviceType.IPU: - self._device_type = DeviceType.IPU + if isinstance(self._training_type_plugin, IPUPlugin) and self._device_type != _AcceleratorType.IPU: + self._device_type = _AcceleratorType.IPU def update_device_type_if_training_type_plugin_passed(self) -> None: if isinstance(self.strategy, TrainingTypePlugin) or any( @@ -974,18 +974,18 @@ class AcceleratorConnector: ): if self._accelerator_type is not None: if self.use_ipu: - self._device_type = DeviceType.IPU + self._device_type = _AcceleratorType.IPU elif self.use_tpu: - self._device_type = DeviceType.TPU + self._device_type = _AcceleratorType.TPU elif self.use_gpu: - self._device_type = DeviceType.GPU + self._device_type = _AcceleratorType.GPU else: if self.has_ipu: - self._device_type = DeviceType.IPU + self._device_type = _AcceleratorType.IPU elif self.has_tpu: - self._device_type = DeviceType.TPU + self._device_type = _AcceleratorType.TPU elif self.has_gpu: - self._device_type = DeviceType.GPU + self._device_type = _AcceleratorType.GPU def _set_distrib_type_if_training_type_plugin_passed(self): # This is required as when `TrainingTypePlugin` instance is passed to either `strategy` diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index b98f13138b..ecd32f11df 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -21,7 +21,7 @@ from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, Ten from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment from pytorch_lightning.trainer.connectors.logger_connector.result import _METRICS, _OUT_DICT, _PBAR_DICT from pytorch_lightning.trainer.states import RunningStage, TrainerFn -from pytorch_lightning.utilities import DeviceType, memory +from pytorch_lightning.utilities import _AcceleratorType, memory from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device from pytorch_lightning.utilities.metrics import metrics_to_scalars from pytorch_lightning.utilities.warnings import rank_zero_deprecation @@ -329,7 +329,7 @@ class LoggerConnector: .. deprecated:: v1.5 Will be removed in v1.7. """ - if self.trainer._device_type == DeviceType.GPU and self.log_gpu_memory: + if self.trainer._device_type == _AcceleratorType.GPU and self.log_gpu_memory: mem_map = memory.get_memory_profile(self.log_gpu_memory) self._gpus_metrics.update(mem_map) return self._gpus_metrics diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 667b57fd1c..18f13a75bf 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -63,11 +63,11 @@ from pytorch_lightning.trainer.states import RunningStage, TrainerFn, TrainerSta from pytorch_lightning.tuner.lr_finder import _LRFinder from pytorch_lightning.tuner.tuning import Tuner from pytorch_lightning.utilities import ( + _AcceleratorType, _IPU_AVAILABLE, _StrategyType, _TPU_AVAILABLE, device_parser, - DeviceType, GradClipAlgorithmType, parsing, rank_zero_deprecation, @@ -1519,26 +1519,32 @@ class Trainer( self.profiler.setup(stage=self.state.fn._setup_fn, local_rank=local_rank, log_dir=self.log_dir) def _log_device_info(self) -> None: - rank_zero_info(f"GPU available: {torch.cuda.is_available()}, used: {self._device_type == DeviceType.GPU}") + rank_zero_info(f"GPU available: {torch.cuda.is_available()}, used: {self._device_type == _AcceleratorType.GPU}") - num_tpu_cores = self.tpu_cores if self.tpu_cores is not None and self._device_type == DeviceType.TPU else 0 + num_tpu_cores = ( + self.tpu_cores if self.tpu_cores is not None and self._device_type == _AcceleratorType.TPU else 0 + ) rank_zero_info(f"TPU available: {_TPU_AVAILABLE}, using: {num_tpu_cores} TPU cores") num_ipus = self.ipus if self.ipus is not None else 0 rank_zero_info(f"IPU available: {_IPU_AVAILABLE}, using: {num_ipus} IPUs") - if torch.cuda.is_available() and self._device_type != DeviceType.GPU: + if torch.cuda.is_available() and self._device_type != _AcceleratorType.GPU: rank_zero_warn( "GPU available but not used. Set the gpus flag in your trainer `Trainer(gpus=1)` or script `--gpus=1`." ) - if _TPU_AVAILABLE and self._device_type != DeviceType.TPU: + if _TPU_AVAILABLE and self._device_type != _AcceleratorType.TPU: rank_zero_warn( "TPU available but not used. Set the `tpu_cores` flag in your trainer" " `Trainer(tpu_cores=8)` or script `--tpu_cores=8`." ) - if _IPU_AVAILABLE and self._device_type != DeviceType.IPU and not isinstance(self.accelerator, IPUAccelerator): + if ( + _IPU_AVAILABLE + and self._device_type != _AcceleratorType.IPU + and not isinstance(self.accelerator, IPUAccelerator) + ): rank_zero_warn( "IPU available but not used. Set the `ipus` flag in your trainer" " `Trainer(ipus=8)` or script `--ipus=8`." @@ -1595,7 +1601,7 @@ class Trainer( return self._accelerator_connector._distrib_type @property - def _device_type(self) -> DeviceType: + def _device_type(self) -> _AcceleratorType: return self._accelerator_connector._device_type @property diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index 22164908a3..48a18db121 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -18,9 +18,9 @@ import numpy from pytorch_lightning.utilities.apply_func import move_data_to_device # noqa: F401 from pytorch_lightning.utilities.distributed import AllGatherGrad, rank_zero_info, rank_zero_only # noqa: F401 from pytorch_lightning.utilities.enums import ( # noqa: F401 + _AcceleratorType, _StrategyType, AMPType, - DeviceType, DistributedType, GradClipAlgorithmType, LightningEnum, diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py index 1d7a6e3fa5..51eb02c018 100644 --- a/pytorch_lightning/utilities/enums.py +++ b/pytorch_lightning/utilities/enums.py @@ -143,17 +143,12 @@ class DistributedType(LightningEnum, metaclass=_OnAccessEnumMeta): ) -class DeviceType(LightningEnum): - """Define Device type by its nature - acceleatrors. +class DeviceType(LightningEnum, metaclass=_OnAccessEnumMeta): + """Define Device type by its nature - accelerators. - >>> DeviceType.CPU == DeviceType.from_str('cpu') - True - >>> # you can match the type with string - >>> DeviceType.GPU == 'GPU' - True - >>> # which is case invariant - >>> DeviceType.TPU in ('tpu', 'CPU') - True + Deprecated since v1.6.0 and will be removed in v1.8.0. + + Use `_AcceleratorType` instead. """ CPU = "CPU" @@ -161,6 +156,12 @@ class DeviceType(LightningEnum): IPU = "IPU" TPU = "TPU" + def deprecate(self) -> None: + rank_zero_deprecation( + "`DeviceType` Enum has been deprecated in v1.6 and will be removed in v1.8." + " Use the string value `{self.value!r}` instead." + ) + class GradClipAlgorithmType(LightningEnum): """Define gradient_clip_algorithm types - training-tricks. @@ -260,6 +261,25 @@ class _StrategyType(LightningEnum): return self in _StrategyType.interactive_compatible_types() +class _AcceleratorType(LightningEnum): + """Define Accelerator type by its nature. + + >>> _AcceleratorType.CPU == _AcceleratorType.from_str('cpu') + True + >>> # you can match the type with string + >>> _AcceleratorType.GPU == 'GPU' + True + >>> # which is case invariant + >>> _AcceleratorType.TPU in ('tpu', 'CPU') + True + """ + + CPU = "CPU" + GPU = "GPU" + IPU = "IPU" + TPU = "TPU" + + class _FaultTolerantMode(LightningEnum): DISABLED = "disabled" diff --git a/pytorch_lightning/utilities/model_summary.py b/pytorch_lightning/utilities/model_summary.py index bab6da5368..37ff258436 100644 --- a/pytorch_lightning/utilities/model_summary.py +++ b/pytorch_lightning/utilities/model_summary.py @@ -23,7 +23,7 @@ from torch import Tensor from torch.utils.hooks import RemovableHandle import pytorch_lightning as pl -from pytorch_lightning.utilities import AMPType, DeviceType +from pytorch_lightning.utilities import _AcceleratorType, AMPType from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8 from pytorch_lightning.utilities.warnings import WarningCache @@ -261,7 +261,11 @@ class ModelSummary: input_ = model.example_input_array input_ = model._apply_batch_transfer_handler(input_) - if trainer is not None and trainer.amp_backend == AMPType.NATIVE and trainer._device_type != DeviceType.TPU: + if ( + trainer is not None + and trainer.amp_backend == AMPType.NATIVE + and trainer._device_type != _AcceleratorType.TPU + ): model.forward = torch.cuda.amp.autocast()(model.forward) mode = model.training diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py index a9c9c50d80..c95c7dc517 100644 --- a/tests/accelerators/test_accelerator_connector.py +++ b/tests/accelerators/test_accelerator_connector.py @@ -43,7 +43,7 @@ from pytorch_lightning.plugins.environments import ( SLURMEnvironment, TorchElasticEnvironment, ) -from pytorch_lightning.utilities import _StrategyType, DeviceType +from pytorch_lightning.utilities import _AcceleratorType, _StrategyType from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel from tests.helpers.runif import RunIf @@ -729,7 +729,7 @@ def test_device_type_when_training_plugin_gpu_passed(tmpdir, plugin): trainer = Trainer(strategy=plugin(), gpus=2) assert isinstance(trainer.training_type_plugin, plugin) - assert trainer._device_type == DeviceType.GPU + assert trainer._device_type == _AcceleratorType.GPU assert isinstance(trainer.accelerator, GPUAccelerator) diff --git a/tests/accelerators/test_ipu.py b/tests/accelerators/test_ipu.py index be2e597c9a..524e122478 100644 --- a/tests/accelerators/test_ipu.py +++ b/tests/accelerators/test_ipu.py @@ -24,7 +24,7 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.plugins import IPUPlugin, IPUPrecisionPlugin from pytorch_lightning.trainer.states import RunningStage, TrainerFn from pytorch_lightning.trainer.supporters import CombinedLoader -from pytorch_lightning.utilities import _IPU_AVAILABLE, DeviceType +from pytorch_lightning.utilities import _AcceleratorType, _IPU_AVAILABLE from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel from tests.helpers.datamodules import ClassifDataModule @@ -571,7 +571,7 @@ def test_device_type_when_training_plugin_ipu_passed(tmpdir): trainer = Trainer(strategy=IPUPlugin(), ipus=8) assert isinstance(trainer.training_type_plugin, IPUPlugin) - assert trainer._device_type == DeviceType.IPU + assert trainer._device_type == _AcceleratorType.IPU assert isinstance(trainer.accelerator, IPUAccelerator) diff --git a/tests/deprecated_api/test_remove_1-8.py b/tests/deprecated_api/test_remove_1-8.py index f668f63b9f..0c32773b56 100644 --- a/tests/deprecated_api/test_remove_1-8.py +++ b/tests/deprecated_api/test_remove_1-8.py @@ -14,10 +14,16 @@ """Test deprecated functionality which will be removed in v1.8.0.""" import pytest -from pytorch_lightning.utilities.enums import DistributedType +from pytorch_lightning.utilities.enums import DeviceType, DistributedType def test_v1_8_0_deprecated_distributed_type_enum(): with pytest.deprecated_call(match="has been deprecated in v1.6 and will be removed in v1.8."): _ = DistributedType.DDP + + +def test_v1_8_0_deprecated_device_type_enum(): + + with pytest.deprecated_call(match="has been deprecated in v1.6 and will be removed in v1.8."): + _ = DeviceType.CPU diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py index bb4c1d017d..ea8d430918 100644 --- a/tests/models/test_tpu.py +++ b/tests/models/test_tpu.py @@ -26,7 +26,7 @@ from pytorch_lightning.accelerators import TPUAccelerator from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.plugins import TPUSpawnPlugin from pytorch_lightning.trainer.connectors.logger_connector.result import _Sync -from pytorch_lightning.utilities import _TPU_AVAILABLE, DeviceType +from pytorch_lightning.utilities import _AcceleratorType, _TPU_AVAILABLE from pytorch_lightning.utilities.distributed import ReduceOp from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers import BoringModel, RandomDataset @@ -474,5 +474,5 @@ def test_device_type_when_training_plugin_tpu_passed(tmpdir): trainer = Trainer(strategy=TPUSpawnPlugin(), tpu_cores=8) assert isinstance(trainer.training_type_plugin, TPUSpawnPlugin) - assert trainer._device_type == DeviceType.TPU + assert trainer._device_type == _AcceleratorType.TPU assert isinstance(trainer.accelerator, TPUAccelerator) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 2d39d83ec3..6004d4540a 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -48,7 +48,7 @@ from pytorch_lightning.plugins import ( DDPSpawnShardedPlugin, ) from pytorch_lightning.trainer.states import TrainerFn -from pytorch_lightning.utilities import _StrategyType, DeviceType +from pytorch_lightning.utilities import _AcceleratorType, _StrategyType from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException from pytorch_lightning.utilities.seed import seed_everything @@ -1149,75 +1149,75 @@ def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches): [ ( dict(accelerator=None, gpus=None), - dict(_distrib_type=None, _device_type=DeviceType.CPU, num_gpus=0, num_processes=1), + dict(_distrib_type=None, _device_type=_AcceleratorType.CPU, num_gpus=0, num_processes=1), ), ( dict(accelerator="dp", gpus=None), - dict(_distrib_type=None, _device_type=DeviceType.CPU, num_gpus=0, num_processes=1), + dict(_distrib_type=None, _device_type=_AcceleratorType.CPU, num_gpus=0, num_processes=1), ), ( dict(accelerator="ddp", gpus=None), - dict(_distrib_type=None, _device_type=DeviceType.CPU, num_gpus=0, num_processes=1), + dict(_distrib_type=None, _device_type=_AcceleratorType.CPU, num_gpus=0, num_processes=1), ), ( dict(accelerator="ddp", num_processes=2, gpus=None), - dict(_distrib_type=_StrategyType.DDP, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2), + dict(_distrib_type=_StrategyType.DDP, _device_type=_AcceleratorType.CPU, num_gpus=0, num_processes=2), ), ( dict(accelerator="ddp", num_nodes=2, gpus=None), - dict(_distrib_type=_StrategyType.DDP, _device_type=DeviceType.CPU, num_gpus=0, num_processes=1), + dict(_distrib_type=_StrategyType.DDP, _device_type=_AcceleratorType.CPU, num_gpus=0, num_processes=1), ), ( dict(accelerator="ddp_cpu", num_processes=2, gpus=None), - dict(_distrib_type=_StrategyType.DDP_SPAWN, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2), + dict(_distrib_type=_StrategyType.DDP_SPAWN, _device_type=_AcceleratorType.CPU, num_gpus=0, num_processes=2), ), ( dict(accelerator="ddp2", gpus=None), - dict(_distrib_type=None, _device_type=DeviceType.CPU, num_gpus=0, num_processes=1), + dict(_distrib_type=None, _device_type=_AcceleratorType.CPU, num_gpus=0, num_processes=1), ), ( dict(accelerator=None, gpus=1), - dict(_distrib_type=None, _device_type=DeviceType.GPU, num_gpus=1, num_processes=1), + dict(_distrib_type=None, _device_type=_AcceleratorType.GPU, num_gpus=1, num_processes=1), ), ( dict(accelerator="dp", gpus=1), - dict(_distrib_type=_StrategyType.DP, _device_type=DeviceType.GPU, num_gpus=1, num_processes=1), + dict(_distrib_type=_StrategyType.DP, _device_type=_AcceleratorType.GPU, num_gpus=1, num_processes=1), ), ( dict(accelerator="ddp", gpus=1), - dict(_distrib_type=_StrategyType.DDP, _device_type=DeviceType.GPU, num_gpus=1, num_processes=1), + dict(_distrib_type=_StrategyType.DDP, _device_type=_AcceleratorType.GPU, num_gpus=1, num_processes=1), ), ( dict(accelerator="ddp_cpu", num_processes=2, gpus=1), - dict(_distrib_type=_StrategyType.DDP_SPAWN, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2), + dict(_distrib_type=_StrategyType.DDP_SPAWN, _device_type=_AcceleratorType.CPU, num_gpus=0, num_processes=2), ), ( dict(accelerator="ddp2", gpus=1), - dict(_distrib_type=_StrategyType.DDP2, _device_type=DeviceType.GPU, num_gpus=1, num_processes=1), + dict(_distrib_type=_StrategyType.DDP2, _device_type=_AcceleratorType.GPU, num_gpus=1, num_processes=1), ), ( dict(accelerator=None, gpus=2), - dict(_distrib_type=_StrategyType.DDP_SPAWN, _device_type=DeviceType.GPU, num_gpus=2, num_processes=2), + dict(_distrib_type=_StrategyType.DDP_SPAWN, _device_type=_AcceleratorType.GPU, num_gpus=2, num_processes=2), ), ( dict(accelerator="dp", gpus=2), - dict(_distrib_type=_StrategyType.DP, _device_type=DeviceType.GPU, num_gpus=2, num_processes=1), + dict(_distrib_type=_StrategyType.DP, _device_type=_AcceleratorType.GPU, num_gpus=2, num_processes=1), ), ( dict(accelerator="ddp", gpus=2), - dict(_distrib_type=_StrategyType.DDP, _device_type=DeviceType.GPU, num_gpus=2, num_processes=2), + dict(_distrib_type=_StrategyType.DDP, _device_type=_AcceleratorType.GPU, num_gpus=2, num_processes=2), ), ( dict(accelerator="ddp2", gpus=2), - dict(_distrib_type=_StrategyType.DDP2, _device_type=DeviceType.GPU, num_gpus=2, num_processes=1), + dict(_distrib_type=_StrategyType.DDP2, _device_type=_AcceleratorType.GPU, num_gpus=2, num_processes=1), ), ( dict(accelerator="ddp2", num_processes=2, gpus=None), - dict(_distrib_type=_StrategyType.DDP, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2), + dict(_distrib_type=_StrategyType.DDP, _device_type=_AcceleratorType.CPU, num_gpus=0, num_processes=2), ), ( dict(accelerator="dp", num_processes=2, gpus=None), - dict(_distrib_type=_StrategyType.DDP, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2), + dict(_distrib_type=_StrategyType.DDP, _device_type=_AcceleratorType.CPU, num_gpus=0, num_processes=2), ), ], ) @@ -2091,118 +2091,118 @@ def test_detect_anomaly_nan(tmpdir): [ ( dict(strategy=None, gpus=None), - dict(_distrib_type=None, _device_type=DeviceType.CPU, num_gpus=0, num_processes=1), + dict(_distrib_type=None, _device_type=_AcceleratorType.CPU, num_gpus=0, num_processes=1), ), ( dict(strategy="dp", gpus=None), - dict(_distrib_type=None, _device_type=DeviceType.CPU, num_gpus=0, num_processes=1), + dict(_distrib_type=None, _device_type=_AcceleratorType.CPU, num_gpus=0, num_processes=1), ), ( dict(strategy="ddp", gpus=None), - dict(_distrib_type=None, _device_type=DeviceType.CPU, num_gpus=0, num_processes=1), + dict(_distrib_type=None, _device_type=_AcceleratorType.CPU, num_gpus=0, num_processes=1), ), ( dict(strategy="ddp", num_processes=2, gpus=None), - dict(_distrib_type=_StrategyType.DDP, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2), + dict(_distrib_type=_StrategyType.DDP, _device_type=_AcceleratorType.CPU, num_gpus=0, num_processes=2), ), ( dict(strategy="ddp", num_nodes=2, gpus=None), - dict(_distrib_type=_StrategyType.DDP, _device_type=DeviceType.CPU, num_gpus=0, num_processes=1), + dict(_distrib_type=_StrategyType.DDP, _device_type=_AcceleratorType.CPU, num_gpus=0, num_processes=1), ), ( dict(strategy="ddp2", gpus=None), - dict(_distrib_type=None, _device_type=DeviceType.CPU, num_gpus=0, num_processes=1), + dict(_distrib_type=None, _device_type=_AcceleratorType.CPU, num_gpus=0, num_processes=1), ), ( dict(strategy=None, gpus=1), - dict(_distrib_type=None, _device_type=DeviceType.GPU, num_gpus=1, num_processes=1), + dict(_distrib_type=None, _device_type=_AcceleratorType.GPU, num_gpus=1, num_processes=1), ), ( dict(strategy="dp", gpus=1), - dict(_distrib_type=_StrategyType.DP, _device_type=DeviceType.GPU, num_gpus=1, num_processes=1), + dict(_distrib_type=_StrategyType.DP, _device_type=_AcceleratorType.GPU, num_gpus=1, num_processes=1), ), ( dict(strategy="ddp", gpus=1), - dict(_distrib_type=_StrategyType.DDP, _device_type=DeviceType.GPU, num_gpus=1, num_processes=1), + dict(_distrib_type=_StrategyType.DDP, _device_type=_AcceleratorType.GPU, num_gpus=1, num_processes=1), ), ( dict(strategy="ddp_spawn", gpus=1), - dict(_distrib_type=_StrategyType.DDP_SPAWN, _device_type=DeviceType.GPU, num_gpus=1, num_processes=1), + dict(_distrib_type=_StrategyType.DDP_SPAWN, _device_type=_AcceleratorType.GPU, num_gpus=1, num_processes=1), ), ( dict(strategy="ddp2", gpus=1), - dict(_distrib_type=_StrategyType.DDP2, _device_type=DeviceType.GPU, num_gpus=1, num_processes=1), + dict(_distrib_type=_StrategyType.DDP2, _device_type=_AcceleratorType.GPU, num_gpus=1, num_processes=1), ), ( dict(strategy=None, gpus=2), - dict(_distrib_type=_StrategyType.DDP_SPAWN, _device_type=DeviceType.GPU, num_gpus=2, num_processes=2), + dict(_distrib_type=_StrategyType.DDP_SPAWN, _device_type=_AcceleratorType.GPU, num_gpus=2, num_processes=2), ), ( dict(strategy="dp", gpus=2), - dict(_distrib_type=_StrategyType.DP, _device_type=DeviceType.GPU, num_gpus=2, num_processes=1), + dict(_distrib_type=_StrategyType.DP, _device_type=_AcceleratorType.GPU, num_gpus=2, num_processes=1), ), ( dict(strategy="ddp", gpus=2), - dict(_distrib_type=_StrategyType.DDP, _device_type=DeviceType.GPU, num_gpus=2, num_processes=2), + dict(_distrib_type=_StrategyType.DDP, _device_type=_AcceleratorType.GPU, num_gpus=2, num_processes=2), ), ( dict(strategy="ddp2", gpus=2), - dict(_distrib_type=_StrategyType.DDP2, _device_type=DeviceType.GPU, num_gpus=2, num_processes=1), + dict(_distrib_type=_StrategyType.DDP2, _device_type=_AcceleratorType.GPU, num_gpus=2, num_processes=1), ), ( dict(strategy="ddp2", num_processes=2, gpus=None), - dict(_distrib_type=_StrategyType.DDP, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2), + dict(_distrib_type=_StrategyType.DDP, _device_type=_AcceleratorType.CPU, num_gpus=0, num_processes=2), ), ( dict(strategy="dp", num_processes=2, gpus=None), - dict(_distrib_type=_StrategyType.DDP, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2), + dict(_distrib_type=_StrategyType.DDP, _device_type=_AcceleratorType.CPU, num_gpus=0, num_processes=2), ), ( dict(strategy="ddp_spawn", num_processes=2, gpus=None), - dict(_distrib_type=_StrategyType.DDP_SPAWN, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2), + dict(_distrib_type=_StrategyType.DDP_SPAWN, _device_type=_AcceleratorType.CPU, num_gpus=0, num_processes=2), ), ( dict(strategy="ddp_spawn", num_processes=1, gpus=None), - dict(_distrib_type=None, _device_type=DeviceType.CPU, num_gpus=0, num_processes=1), + dict(_distrib_type=None, _device_type=_AcceleratorType.CPU, num_gpus=0, num_processes=1), ), ( dict(strategy="ddp_fully_sharded", gpus=1), dict( _distrib_type=_StrategyType.DDP_FULLY_SHARDED, - _device_type=DeviceType.GPU, + _device_type=_AcceleratorType.GPU, num_gpus=1, num_processes=1, ), ), ( dict(strategy=DDPSpawnPlugin(), num_processes=2, gpus=None), - dict(_distrib_type=_StrategyType.DDP_SPAWN, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2), + dict(_distrib_type=_StrategyType.DDP_SPAWN, _device_type=_AcceleratorType.CPU, num_gpus=0, num_processes=2), ), ( dict(strategy=DDPSpawnPlugin(), gpus=2), - dict(_distrib_type=_StrategyType.DDP_SPAWN, _device_type=DeviceType.GPU, num_gpus=2, num_processes=1), + dict(_distrib_type=_StrategyType.DDP_SPAWN, _device_type=_AcceleratorType.GPU, num_gpus=2, num_processes=1), ), ( dict(strategy=DDPPlugin(), num_processes=2, gpus=None), - dict(_distrib_type=_StrategyType.DDP, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2), + dict(_distrib_type=_StrategyType.DDP, _device_type=_AcceleratorType.CPU, num_gpus=0, num_processes=2), ), ( dict(strategy=DDPPlugin(), gpus=2), - dict(_distrib_type=_StrategyType.DDP, _device_type=DeviceType.GPU, num_gpus=2, num_processes=1), + dict(_distrib_type=_StrategyType.DDP, _device_type=_AcceleratorType.GPU, num_gpus=2, num_processes=1), ), ( dict(strategy=DDP2Plugin(), gpus=2), - dict(_distrib_type=_StrategyType.DDP2, _device_type=DeviceType.GPU, num_gpus=2, num_processes=1), + dict(_distrib_type=_StrategyType.DDP2, _device_type=_AcceleratorType.GPU, num_gpus=2, num_processes=1), ), ( dict(strategy=DataParallelPlugin(), gpus=2), - dict(_distrib_type=_StrategyType.DP, _device_type=DeviceType.GPU, num_gpus=2, num_processes=1), + dict(_distrib_type=_StrategyType.DP, _device_type=_AcceleratorType.GPU, num_gpus=2, num_processes=1), ), ( dict(strategy=DDPFullyShardedPlugin(), gpus=2), dict( _distrib_type=_StrategyType.DDP_FULLY_SHARDED, - _device_type=DeviceType.GPU, + _device_type=_AcceleratorType.GPU, num_gpus=2, num_processes=1, ), @@ -2211,14 +2211,16 @@ def test_detect_anomaly_nan(tmpdir): dict(strategy=DDPSpawnShardedPlugin(), gpus=2), dict( _distrib_type=_StrategyType.DDP_SHARDED_SPAWN, - _device_type=DeviceType.GPU, + _device_type=_AcceleratorType.GPU, num_gpus=2, num_processes=1, ), ), ( dict(strategy=DDPShardedPlugin(), gpus=2), - dict(_distrib_type=_StrategyType.DDP_SHARDED, _device_type=DeviceType.GPU, num_gpus=2, num_processes=1), + dict( + _distrib_type=_StrategyType.DDP_SHARDED, _device_type=_AcceleratorType.GPU, num_gpus=2, num_processes=1 + ), ), ], ) diff --git a/tests/utilities/test_enums.py b/tests/utilities/test_enums.py index 4f902e2238..99158e2e83 100644 --- a/tests/utilities/test_enums.py +++ b/tests/utilities/test_enums.py @@ -13,17 +13,17 @@ # limitations under the License. import pytest -from pytorch_lightning.utilities.enums import DeviceType, GradClipAlgorithmType, ModelSummaryMode, PrecisionType +from pytorch_lightning.utilities.enums import _AcceleratorType, GradClipAlgorithmType, ModelSummaryMode, PrecisionType def test_consistency(): - assert DeviceType.TPU not in ("GPU", "CPU") - assert DeviceType.TPU in ("TPU", "CPU") - assert DeviceType.TPU in ("tpu", "CPU") - assert DeviceType.TPU not in {"GPU", "CPU"} + assert _AcceleratorType.TPU not in ("GPU", "CPU") + assert _AcceleratorType.TPU in ("TPU", "CPU") + assert _AcceleratorType.TPU in ("tpu", "CPU") + assert _AcceleratorType.TPU not in {"GPU", "CPU"} # hash cannot be case invariant - assert DeviceType.TPU not in {"TPU", "CPU"} - assert DeviceType.TPU in {"tpu", "CPU"} + assert _AcceleratorType.TPU not in {"TPU", "CPU"} + assert _AcceleratorType.TPU in {"tpu", "CPU"} def test_precision_supported_types():