diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 093065394b..724b5b6f24 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from abc import abstractmethod +from abc import ABC, abstractmethod from typing import Any, Dict, Union import torch @@ -19,7 +19,7 @@ import torch import pytorch_lightning as pl -class Accelerator: +class Accelerator(ABC): """The Accelerator Base Class. An Accelerator is meant to deal with one type of Hardware. Currently there are accelerators for: @@ -45,7 +45,7 @@ class Accelerator: """ def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]: - """Gets stats for a given device. + """Get stats for a given device. Args: device: device for which to get stats @@ -58,4 +58,4 @@ class Accelerator: @staticmethod @abstractmethod def auto_device_count() -> int: - """Get the devices when set to auto.""" + """Get the device count when set to auto.""" diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py index 10d1caec5d..3e2ec15216 100644 --- a/tests/accelerators/test_accelerator_connector.py +++ b/tests/accelerators/test_accelerator_connector.py @@ -336,7 +336,9 @@ def test_accelerator_choice_ddp_cpu_custom_cluster(_, tmpdir): @mock.patch("pytorch_lightning.strategies.DDPStrategy.setup_distributed", autospec=True) def test_custom_accelerator(device_count_mock, setup_distributed_mock): class Accel(Accelerator): - pass + @staticmethod + def auto_device_count() -> int: + return 1 class Prec(PrecisionPlugin): pass