Let Accelerator inherit from ABC to make sure abstractmethod takes effect (#11521)
This commit is contained in:
parent
623dc974f5
commit
350c88e621
|
@ -11,7 +11,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from abc import abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Dict, Union
|
from typing import Any, Dict, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
@ -19,7 +19,7 @@ import torch
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
|
|
||||||
|
|
||||||
class Accelerator:
|
class Accelerator(ABC):
|
||||||
"""The Accelerator Base Class. An Accelerator is meant to deal with one type of Hardware.
|
"""The Accelerator Base Class. An Accelerator is meant to deal with one type of Hardware.
|
||||||
|
|
||||||
Currently there are accelerators for:
|
Currently there are accelerators for:
|
||||||
|
@ -45,7 +45,7 @@ class Accelerator:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
|
def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
|
||||||
"""Gets stats for a given device.
|
"""Get stats for a given device.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
device: device for which to get stats
|
device: device for which to get stats
|
||||||
|
@ -58,4 +58,4 @@ class Accelerator:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def auto_device_count() -> int:
|
def auto_device_count() -> int:
|
||||||
"""Get the devices when set to auto."""
|
"""Get the device count when set to auto."""
|
||||||
|
|
|
@ -336,7 +336,9 @@ def test_accelerator_choice_ddp_cpu_custom_cluster(_, tmpdir):
|
||||||
@mock.patch("pytorch_lightning.strategies.DDPStrategy.setup_distributed", autospec=True)
|
@mock.patch("pytorch_lightning.strategies.DDPStrategy.setup_distributed", autospec=True)
|
||||||
def test_custom_accelerator(device_count_mock, setup_distributed_mock):
|
def test_custom_accelerator(device_count_mock, setup_distributed_mock):
|
||||||
class Accel(Accelerator):
|
class Accel(Accelerator):
|
||||||
pass
|
@staticmethod
|
||||||
|
def auto_device_count() -> int:
|
||||||
|
return 1
|
||||||
|
|
||||||
class Prec(PrecisionPlugin):
|
class Prec(PrecisionPlugin):
|
||||||
pass
|
pass
|
||||||
|
|
Loading…
Reference in New Issue