Let Accelerator inherit from ABC to make sure abstractmethod takes effect (#11521)
This commit is contained in:
parent
623dc974f5
commit
350c88e621
|
@ -11,7 +11,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from abc import abstractmethod
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
import torch
|
||||
|
@ -19,7 +19,7 @@ import torch
|
|||
import pytorch_lightning as pl
|
||||
|
||||
|
||||
class Accelerator:
|
||||
class Accelerator(ABC):
|
||||
"""The Accelerator Base Class. An Accelerator is meant to deal with one type of Hardware.
|
||||
|
||||
Currently there are accelerators for:
|
||||
|
@ -45,7 +45,7 @@ class Accelerator:
|
|||
"""
|
||||
|
||||
def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
|
||||
"""Gets stats for a given device.
|
||||
"""Get stats for a given device.
|
||||
|
||||
Args:
|
||||
device: device for which to get stats
|
||||
|
@ -58,4 +58,4 @@ class Accelerator:
|
|||
@staticmethod
|
||||
@abstractmethod
|
||||
def auto_device_count() -> int:
|
||||
"""Get the devices when set to auto."""
|
||||
"""Get the device count when set to auto."""
|
||||
|
|
|
@ -336,7 +336,9 @@ def test_accelerator_choice_ddp_cpu_custom_cluster(_, tmpdir):
|
|||
@mock.patch("pytorch_lightning.strategies.DDPStrategy.setup_distributed", autospec=True)
|
||||
def test_custom_accelerator(device_count_mock, setup_distributed_mock):
|
||||
class Accel(Accelerator):
|
||||
pass
|
||||
@staticmethod
|
||||
def auto_device_count() -> int:
|
||||
return 1
|
||||
|
||||
class Prec(PrecisionPlugin):
|
||||
pass
|
||||
|
|
Loading…
Reference in New Issue