diff --git a/src/lightning/fabric/accelerators/cpu.py b/src/lightning/fabric/accelerators/cpu.py index 1bcec1b2ac..0334210ecd 100644 --- a/src/lightning/fabric/accelerators/cpu.py +++ b/src/lightning/fabric/accelerators/cpu.py @@ -39,13 +39,13 @@ class CPUAccelerator(Accelerator): @staticmethod @override - def parse_devices(devices: Union[int, str, List[int]]) -> int: + def parse_devices(devices: Union[int, str]) -> int: """Accelerator device parsing logic.""" return _parse_cpu_cores(devices) @staticmethod @override - def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]: + def get_parallel_devices(devices: Union[int, str]) -> List[torch.device]: """Gets parallel devices for the Accelerator.""" devices = _parse_cpu_cores(devices) return [torch.device("cpu")] * devices @@ -72,12 +72,12 @@ class CPUAccelerator(Accelerator): ) -def _parse_cpu_cores(cpu_cores: Union[int, str, List[int]]) -> int: +def _parse_cpu_cores(cpu_cores: Union[int, str]) -> int: """Parses the cpu_cores given in the format as accepted by the ``devices`` argument in the :class:`~lightning.pytorch.trainer.trainer.Trainer`. Args: - cpu_cores: An int > 0. + cpu_cores: An int > 0 or a string that can be converted to an int > 0. Returns: An int representing the number of processes diff --git a/src/lightning/pytorch/accelerators/cpu.py b/src/lightning/pytorch/accelerators/cpu.py index 735312b363..a85a959ab6 100644 --- a/src/lightning/pytorch/accelerators/cpu.py +++ b/src/lightning/pytorch/accelerators/cpu.py @@ -48,13 +48,13 @@ class CPUAccelerator(Accelerator): @staticmethod @override - def parse_devices(devices: Union[int, str, List[int]]) -> int: + def parse_devices(devices: Union[int, str]) -> int: """Accelerator device parsing logic.""" return _parse_cpu_cores(devices) @staticmethod @override - def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]: + def get_parallel_devices(devices: Union[int, str]) -> List[torch.device]: """Gets parallel devices for the Accelerator.""" devices = _parse_cpu_cores(devices) return [torch.device("cpu")] * devices