Remove `List[int]` as input type for Trainer when `accelerator="cpu"` (#20399)

Co-authored-by: Alan Chu <alanchu@Alans-Air.lan>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
This commit is contained in:
Alan Chu 2024-11-13 11:42:14 -08:00 committed by GitHub
parent e1b172c62e
commit 20d19d2f57
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 6 additions and 6 deletions

View File

@ -39,13 +39,13 @@ class CPUAccelerator(Accelerator):
@staticmethod
@override
def parse_devices(devices: Union[int, str, List[int]]) -> int:
def parse_devices(devices: Union[int, str]) -> int:
"""Accelerator device parsing logic."""
return _parse_cpu_cores(devices)
@staticmethod
@override
def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]:
def get_parallel_devices(devices: Union[int, str]) -> List[torch.device]:
"""Gets parallel devices for the Accelerator."""
devices = _parse_cpu_cores(devices)
return [torch.device("cpu")] * devices
@ -72,12 +72,12 @@ class CPUAccelerator(Accelerator):
)
def _parse_cpu_cores(cpu_cores: Union[int, str, List[int]]) -> int:
def _parse_cpu_cores(cpu_cores: Union[int, str]) -> int:
"""Parses the cpu_cores given in the format as accepted by the ``devices`` argument in the
:class:`~lightning.pytorch.trainer.trainer.Trainer`.
Args:
cpu_cores: An int > 0.
cpu_cores: An int > 0 or a string that can be converted to an int > 0.
Returns:
An int representing the number of processes

View File

@ -48,13 +48,13 @@ class CPUAccelerator(Accelerator):
@staticmethod
@override
def parse_devices(devices: Union[int, str, List[int]]) -> int:
def parse_devices(devices: Union[int, str]) -> int:
"""Accelerator device parsing logic."""
return _parse_cpu_cores(devices)
@staticmethod
@override
def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]:
def get_parallel_devices(devices: Union[int, str]) -> List[torch.device]:
"""Gets parallel devices for the Accelerator."""
devices = _parse_cpu_cores(devices)
return [torch.device("cpu")] * devices