Let Accelerator inherit from ABC to make sure abstractmethod takes effect (#11521)

This commit is contained in:
Chunyang Wen 2022-01-24 03:47:43 +08:00 committed by GitHub
parent 623dc974f5
commit 350c88e621
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 5 deletions

View File

@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from abc import abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, Union from typing import Any, Dict, Union
import torch import torch
@@ -19,7 +19,7 @@ import torch
import pytorch_lightning as pl import pytorch_lightning as pl
class Accelerator: class Accelerator(ABC):
"""The Accelerator Base Class. An Accelerator is meant to deal with one type of Hardware. """The Accelerator Base Class. An Accelerator is meant to deal with one type of Hardware.
Currently there are accelerators for: Currently there are accelerators for:
@@ -45,7 +45,7 @@ class Accelerator:
""" """
def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]: def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
"""Gets stats for a given device. """Get stats for a given device.
Args: Args:
device: device for which to get stats device: device for which to get stats
@@ -58,4 +58,4 @@ class Accelerator:
@staticmethod @staticmethod
@abstractmethod @abstractmethod
def auto_device_count() -> int: def auto_device_count() -> int:
"""Get the devices when set to auto.""" """Get the device count when set to auto."""

View File

@@ -336,7 +336,9 @@ def test_accelerator_choice_ddp_cpu_custom_cluster(_, tmpdir):
@mock.patch("pytorch_lightning.strategies.DDPStrategy.setup_distributed", autospec=True) @mock.patch("pytorch_lightning.strategies.DDPStrategy.setup_distributed", autospec=True)
def test_custom_accelerator(device_count_mock, setup_distributed_mock): def test_custom_accelerator(device_count_mock, setup_distributed_mock):
class Accel(Accelerator): class Accel(Accelerator):
pass @staticmethod
def auto_device_count() -> int:
return 1
class Prec(PrecisionPlugin): class Prec(PrecisionPlugin):
pass pass