From 20d19d2f5728f7049272f2db77a9748ff4cf5ccd Mon Sep 17 00:00:00 2001 From: Alan Chu <30797645+chualanagit@users.noreply.github.com> Date: Wed, 13 Nov 2024 11:42:14 -0800 Subject: [PATCH] Remove `List[int]` as input type for Trainer when `accelerator="cpu"` (#20399) Co-authored-by: Alan Chu Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> --- src/lightning/fabric/accelerators/cpu.py | 8 ++++---- src/lightning/pytorch/accelerators/cpu.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/lightning/fabric/accelerators/cpu.py b/src/lightning/fabric/accelerators/cpu.py index 1bcec1b2ac..0334210ecd 100644 --- a/src/lightning/fabric/accelerators/cpu.py +++ b/src/lightning/fabric/accelerators/cpu.py @@ -39,13 +39,13 @@ class CPUAccelerator(Accelerator): @staticmethod @override - def parse_devices(devices: Union[int, str, List[int]]) -> int: + def parse_devices(devices: Union[int, str]) -> int: """Accelerator device parsing logic.""" return _parse_cpu_cores(devices) @staticmethod @override - def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]: + def get_parallel_devices(devices: Union[int, str]) -> List[torch.device]: """Gets parallel devices for the Accelerator.""" devices = _parse_cpu_cores(devices) return [torch.device("cpu")] * devices @@ -72,12 +72,12 @@ class CPUAccelerator(Accelerator): ) -def _parse_cpu_cores(cpu_cores: Union[int, str, List[int]]) -> int: +def _parse_cpu_cores(cpu_cores: Union[int, str]) -> int: """Parses the cpu_cores given in the format as accepted by the ``devices`` argument in the :class:`~lightning.pytorch.trainer.trainer.Trainer`. Args: - cpu_cores: An int > 0. + cpu_cores: An int > 0 or a string that can be converted to an int > 0. Returns: An int representing the number of processes diff --git a/src/lightning/pytorch/accelerators/cpu.py b/src/lightning/pytorch/accelerators/cpu.py index 735312b363..a85a959ab6 100644 --- a/src/lightning/pytorch/accelerators/cpu.py +++ b/src/lightning/pytorch/accelerators/cpu.py @@ -48,13 +48,13 @@ class CPUAccelerator(Accelerator): @staticmethod @override - def parse_devices(devices: Union[int, str, List[int]]) -> int: + def parse_devices(devices: Union[int, str]) -> int: """Accelerator device parsing logic.""" return _parse_cpu_cores(devices) @staticmethod @override - def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]: + def get_parallel_devices(devices: Union[int, str]) -> List[torch.device]: """Gets parallel devices for the Accelerator.""" devices = _parse_cpu_cores(devices) return [torch.device("cpu")] * devices