Remove `List[int]` as input type for Trainer when `accelerator="cpu"` (#20399)
Co-authored-by: Alan Chu <alanchu@Alans-Air.lan> Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
This commit is contained in:
parent
e1b172c62e
commit
20d19d2f57
|
@ -39,13 +39,13 @@ class CPUAccelerator(Accelerator):
|
|||
|
||||
@staticmethod
|
||||
@override
|
||||
def parse_devices(devices: Union[int, str, List[int]]) -> int:
|
||||
def parse_devices(devices: Union[int, str]) -> int:
|
||||
"""Accelerator device parsing logic."""
|
||||
return _parse_cpu_cores(devices)
|
||||
|
||||
@staticmethod
|
||||
@override
|
||||
def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]:
|
||||
def get_parallel_devices(devices: Union[int, str]) -> List[torch.device]:
|
||||
"""Gets parallel devices for the Accelerator."""
|
||||
devices = _parse_cpu_cores(devices)
|
||||
return [torch.device("cpu")] * devices
|
||||
|
@ -72,12 +72,12 @@ class CPUAccelerator(Accelerator):
|
|||
)
|
||||
|
||||
|
||||
def _parse_cpu_cores(cpu_cores: Union[int, str, List[int]]) -> int:
|
||||
def _parse_cpu_cores(cpu_cores: Union[int, str]) -> int:
|
||||
"""Parses the cpu_cores given in the format as accepted by the ``devices`` argument in the
|
||||
:class:`~lightning.pytorch.trainer.trainer.Trainer`.
|
||||
|
||||
Args:
|
||||
cpu_cores: An int > 0.
|
||||
cpu_cores: An int > 0 or a string that can be converted to an int > 0.
|
||||
|
||||
Returns:
|
||||
An int representing the number of processes
|
||||
|
|
|
@ -48,13 +48,13 @@ class CPUAccelerator(Accelerator):
|
|||
|
||||
@staticmethod
|
||||
@override
|
||||
def parse_devices(devices: Union[int, str, List[int]]) -> int:
|
||||
def parse_devices(devices: Union[int, str]) -> int:
|
||||
"""Accelerator device parsing logic."""
|
||||
return _parse_cpu_cores(devices)
|
||||
|
||||
@staticmethod
|
||||
@override
|
||||
def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]:
|
||||
def get_parallel_devices(devices: Union[int, str]) -> List[torch.device]:
|
||||
"""Gets parallel devices for the Accelerator."""
|
||||
devices = _parse_cpu_cores(devices)
|
||||
return [torch.device("cpu")] * devices
|
||||
|
|
Loading…
Reference in New Issue