Let Accelerator inherit from ABC to make sure abstractmethod takes effect (#11521)

This commit is contained in:
Chunyang Wen 2022-01-24 03:47:43 +08:00 committed by GitHub
parent 623dc974f5
commit 350c88e621
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 5 deletions

View File

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import abstractmethod
from abc import ABC, abstractmethod
from typing import Any, Dict, Union
import torch
@ -19,7 +19,7 @@ import torch
import pytorch_lightning as pl
class Accelerator:
class Accelerator(ABC):
"""The Accelerator Base Class. An Accelerator is meant to deal with one type of Hardware.
Currently there are accelerators for:
@ -45,7 +45,7 @@ class Accelerator:
"""
def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
"""Gets stats for a given device.
"""Get stats for a given device.
Args:
device: device for which to get stats
@ -58,4 +58,4 @@ class Accelerator:
@staticmethod
@abstractmethod
def auto_device_count() -> int:
"""Get the devices when set to auto."""
"""Get the device count when set to auto."""

View File

@ -336,7 +336,9 @@ def test_accelerator_choice_ddp_cpu_custom_cluster(_, tmpdir):
@mock.patch("pytorch_lightning.strategies.DDPStrategy.setup_distributed", autospec=True)
def test_custom_accelerator(device_count_mock, setup_distributed_mock):
class Accel(Accelerator):
pass
@staticmethod
def auto_device_count() -> int:
return 1
class Prec(PrecisionPlugin):
pass