diff --git a/pyproject.toml b/pyproject.toml index 05c641dc30..bd151cb468 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,7 +42,6 @@ warn_no_return = "False" # the list can be generated with: # mypy | tr ':' ' ' | awk '{print $1}' | sort | uniq | sed 's/\.py//g' | sed 's|\/|\.|g' | xargs -I {} echo '"{}",' module = [ - "pytorch_lightning.accelerators.gpu", "pytorch_lightning.callbacks.finetuning", "pytorch_lightning.callbacks.model_checkpoint", "pytorch_lightning.callbacks.progress.rich_progress", diff --git a/pytorch_lightning/accelerators/cpu.py b/pytorch_lightning/accelerators/cpu.py index 40c9a3c2b9..75c55fdf5f 100644 --- a/pytorch_lightning/accelerators/cpu.py +++ b/pytorch_lightning/accelerators/cpu.py @@ -11,12 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Union +from __future__ import annotations + +from typing import Any import torch from pytorch_lightning.accelerators.accelerator import Accelerator from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.types import _DEVICE class CPUAccelerator(Accelerator): @@ -28,10 +31,10 @@ class CPUAccelerator(Accelerator): MisconfigurationException: If the selected device is not CPU. """ - if "cpu" not in str(root_device): + if root_device.type != "cpu": raise MisconfigurationException(f"Device should be CPU, got {root_device} instead.") - def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]: + def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: """CPU device stats aren't supported yet.""" return {} diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py index f3d8680c79..3ccf2e4a7f 100644 --- a/pytorch_lightning/accelerators/gpu.py +++ b/pytorch_lightning/accelerators/gpu.py @@ -11,11 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import logging import os import shutil import subprocess -from typing import Any, Dict, List, Union +from typing import Any import torch @@ -23,6 +25,7 @@ import pytorch_lightning as pl from pytorch_lightning.accelerators.accelerator import Accelerator from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8 +from pytorch_lightning.utilities.types import _DEVICE _log = logging.getLogger(__name__) @@ -36,11 +39,11 @@ class GPUAccelerator(Accelerator): MisconfigurationException: If the selected device is not GPU. """ - if "cuda" not in str(root_device): - raise MisconfigurationException(f"Device should be GPU, got {self.root_device} instead") + if root_device.type != "cuda": + raise MisconfigurationException(f"Device should be GPU, got {root_device} instead") torch.cuda.set_device(root_device) - def setup(self, trainer: "pl.Trainer") -> None: + def setup(self, trainer: pl.Trainer) -> None: # TODO refactor input from trainer to local_rank @four4fish self.set_nvidia_flags(trainer.local_rank) # clear cache before training @@ -54,7 +57,7 @@ class GPUAccelerator(Accelerator): devices = os.getenv("CUDA_VISIBLE_DEVICES", all_gpu_ids) _log.info(f"LOCAL_RANK: {local_rank} - CUDA_VISIBLE_DEVICES: [{devices}]") - def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]: + def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: """Gets stats for the given GPU device. Args: @@ -77,7 +80,7 @@ class GPUAccelerator(Accelerator): return torch.cuda.device_count() -def get_nvidia_gpu_stats(device: torch.device) -> Dict[str, float]: +def get_nvidia_gpu_stats(device: _DEVICE) -> dict[str, float]: """Get GPU stats including memory, fan speed, and temperature from nvidia-smi. Args: @@ -106,7 +109,8 @@ def get_nvidia_gpu_stats(device: torch.device) -> Dict[str, float]: gpu_stat_keys = [k for k, _ in gpu_stat_metrics] gpu_query = ",".join(gpu_stat_keys) - gpu_id = _get_gpu_id(device.index) + index = torch._utils._get_device_index(device) + gpu_id = _get_gpu_id(index) result = subprocess.run( [nvidia_smi_path, f"--query-gpu={gpu_query}", "--format=csv,nounits,noheader", f"--id={gpu_id}"], encoding="utf-8", @@ -130,5 +134,5 @@ def _get_gpu_id(device_id: int) -> str: """Get the unmasked real GPU IDs.""" # All devices if `CUDA_VISIBLE_DEVICES` unset default = ",".join(str(i) for i in range(torch.cuda.device_count())) - cuda_visible_devices: List[str] = os.getenv("CUDA_VISIBLE_DEVICES", default=default).split(",") + cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", default=default).split(",") return cuda_visible_devices[device_id].strip() diff --git a/pytorch_lightning/strategies/single_device.py b/pytorch_lightning/strategies/single_device.py index bb6c1b097e..bccbfa13fa 100644 --- a/pytorch_lightning/strategies/single_device.py +++ b/pytorch_lightning/strategies/single_device.py @@ -11,7 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Union +from __future__ import annotations + +from typing import Any import torch @@ -20,6 +22,7 @@ from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.strategies.strategy import Strategy from pytorch_lightning.utilities import _XLA_AVAILABLE +from pytorch_lightning.utilities.types import _DEVICE class SingleDeviceStrategy(Strategy): @@ -27,13 +30,13 @@ class SingleDeviceStrategy(Strategy): def __init__( self, - device: torch.device, - accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None, - checkpoint_io: Optional[CheckpointIO] = None, - precision_plugin: Optional[PrecisionPlugin] = None, + device: _DEVICE, + accelerator: pl.accelerators.accelerator.Accelerator | None = None, + checkpoint_io: CheckpointIO | None = None, + precision_plugin: PrecisionPlugin | None = None, ): super().__init__(accelerator=accelerator, checkpoint_io=checkpoint_io, precision_plugin=precision_plugin) - self.device: torch.device = device + self._root_device = torch.device(device) self.global_rank = 0 self.local_rank = 0 self.world_size = 1 @@ -46,7 +49,7 @@ class SingleDeviceStrategy(Strategy): def on_gpu(self) -> bool: return self.root_device.type == "cuda" and torch.cuda.is_available() - def reduce(self, tensor: Union[Any, torch.Tensor], *args: Any, **kwargs: Any) -> Union[Any, torch.Tensor]: + def reduce(self, tensor: Any | torch.Tensor, *args: Any, **kwargs: Any) -> Any | torch.Tensor: """Reduces a tensor from several distributed processes to one aggregated tensor. As this plugin only operates with a single device, the reduction is simply the identity. @@ -60,18 +63,18 @@ class SingleDeviceStrategy(Strategy): """ return tensor - def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: + def all_gather(self, tensor: torch.Tensor, group: Any | None = None, sync_grads: bool = False) -> torch.Tensor: """Perform a all_gather on all processes.""" return tensor @property def root_device(self) -> torch.device: - return self.device + return self._root_device def model_to_device(self) -> None: self.model.to(self.root_device) - def setup(self, trainer: "pl.Trainer") -> None: + def setup(self, trainer: pl.Trainer) -> None: self.model_to_device() super().setup(trainer) diff --git a/pytorch_lightning/strategies/single_tpu.py b/pytorch_lightning/strategies/single_tpu.py index 4895879ffa..8465656f03 100644 --- a/pytorch_lightning/strategies/single_tpu.py +++ b/pytorch_lightning/strategies/single_tpu.py @@ -36,10 +36,12 @@ class SingleTPUStrategy(SingleDeviceStrategy): precision_plugin: Optional[PrecisionPlugin] = None, debug: bool = False, ): - device = xm.xla_device(device) checkpoint_io = checkpoint_io or XLACheckpointIO() super().__init__( - accelerator=accelerator, device=device, checkpoint_io=checkpoint_io, precision_plugin=precision_plugin + accelerator=accelerator, + device=xm.xla_device(device), + checkpoint_io=checkpoint_io, + precision_plugin=precision_plugin, ) self.debug = debug @@ -60,9 +62,6 @@ class SingleTPUStrategy(SingleDeviceStrategy): super().setup(trainer) - if isinstance(self.device, int): - self.device = xm.xla_device(self.device) - if self.debug: os.environ["PT_XLA_DEBUG"] = str(1) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 7b1c07b0d6..d476bc5f0c 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -752,7 +752,7 @@ class AcceleratorConnector: plugin = IPUStrategy(parallel_devices=self.parallel_devices) else: single_gpu_ordinal = device_parser.determine_root_gpu_device(self.parallel_device_ids) - plugin = SingleDeviceStrategy(device=torch.device(f"cuda:{single_gpu_ordinal}" if self.use_gpu else "cpu")) + plugin = SingleDeviceStrategy(device=single_gpu_ordinal if self.use_gpu else "cpu") return plugin def resolve_strategy(self, training_type: Strategy) -> Strategy: diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index 1d5cd27226..3e12629a94 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -44,6 +44,7 @@ TRAIN_DATALOADERS = Union[ Dict[str, Sequence[DataLoader]], ] EVAL_DATALOADERS = Union[DataLoader, Sequence[DataLoader]] +_DEVICE = Union[torch.device, str, int] @runtime_checkable diff --git a/tests/deprecated_api/test_remove_1-8.py b/tests/deprecated_api/test_remove_1-8.py index 52c57a1f70..640d8744f2 100644 --- a/tests/deprecated_api/test_remove_1-8.py +++ b/tests/deprecated_api/test_remove_1-8.py @@ -326,7 +326,7 @@ def test_v1_8_0_deprecated_single_device_plugin_class(): " Use `.*SingleDeviceStrategy` instead." ) ): - SingleDevicePlugin(Mock()) + SingleDevicePlugin("cpu") @RunIf(tpu=True) diff --git a/tests/plugins/test_checkpoint_io_plugin.py b/tests/plugins/test_checkpoint_io_plugin.py index 0ce073a34d..7a1352804b 100644 --- a/tests/plugins/test_checkpoint_io_plugin.py +++ b/tests/plugins/test_checkpoint_io_plugin.py @@ -44,10 +44,9 @@ def test_checkpoint_plugin_called(tmpdir): ck = ModelCheckpoint(dirpath=tmpdir, save_last=True) model = BoringModel() - device = torch.device("cpu") trainer = Trainer( default_root_dir=tmpdir, - strategy=SingleDeviceStrategy(device, checkpoint_io=checkpoint_plugin), + strategy=SingleDeviceStrategy("cpu", checkpoint_io=checkpoint_plugin), callbacks=ck, max_epochs=2, ) @@ -63,10 +62,9 @@ def test_checkpoint_plugin_called(tmpdir): ck = ModelCheckpoint(dirpath=tmpdir, save_last=True) model = BoringModel() - device = torch.device("cpu") trainer = Trainer( default_root_dir=tmpdir, - strategy=SingleDeviceStrategy(device), + strategy=SingleDeviceStrategy("cpu"), plugins=[checkpoint_plugin], callbacks=ck, max_epochs=2,