Add typing to accelerators/gpu.py (#11333)

This commit is contained in:
Carlos Mocholí 2022-01-12 20:44:51 +01:00 committed by GitHub
parent 00d1758bac
commit 5914fb748f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 40 additions and 33 deletions

View File

@ -42,7 +42,6 @@ warn_no_return = "False"
# the list can be generated with:
# mypy | tr ':' ' ' | awk '{print $1}' | sort | uniq | sed 's/\.py//g' | sed 's|\/|\.|g' | xargs -I {} echo '"{}",'
module = [
"pytorch_lightning.accelerators.gpu",
"pytorch_lightning.callbacks.finetuning",
"pytorch_lightning.callbacks.model_checkpoint",
"pytorch_lightning.callbacks.progress.rich_progress",

View File

@ -11,12 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Union
from __future__ import annotations
from typing import Any
import torch
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import _DEVICE
class CPUAccelerator(Accelerator):
@ -28,10 +31,10 @@ class CPUAccelerator(Accelerator):
MisconfigurationException:
If the selected device is not CPU.
"""
if "cpu" not in str(root_device):
if root_device.type != "cpu":
raise MisconfigurationException(f"Device should be CPU, got {root_device} instead.")
def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
def get_device_stats(self, device: _DEVICE) -> dict[str, Any]:
"""CPU device stats aren't supported yet."""
return {}

View File

@ -11,11 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
import os
import shutil
import subprocess
from typing import Any, Dict, List, Union
from typing import Any
import torch
@ -23,6 +25,7 @@ import pytorch_lightning as pl
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8
from pytorch_lightning.utilities.types import _DEVICE
_log = logging.getLogger(__name__)
@ -36,11 +39,11 @@ class GPUAccelerator(Accelerator):
MisconfigurationException:
If the selected device is not GPU.
"""
if "cuda" not in str(root_device):
raise MisconfigurationException(f"Device should be GPU, got {self.root_device} instead")
if root_device.type != "cuda":
raise MisconfigurationException(f"Device should be GPU, got {root_device} instead")
torch.cuda.set_device(root_device)
def setup(self, trainer: "pl.Trainer") -> None:
def setup(self, trainer: pl.Trainer) -> None:
# TODO refactor input from trainer to local_rank @four4fish
self.set_nvidia_flags(trainer.local_rank)
# clear cache before training
@ -54,7 +57,7 @@ class GPUAccelerator(Accelerator):
devices = os.getenv("CUDA_VISIBLE_DEVICES", all_gpu_ids)
_log.info(f"LOCAL_RANK: {local_rank} - CUDA_VISIBLE_DEVICES: [{devices}]")
def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
def get_device_stats(self, device: _DEVICE) -> dict[str, Any]:
"""Gets stats for the given GPU device.
Args:
@ -77,7 +80,7 @@ class GPUAccelerator(Accelerator):
return torch.cuda.device_count()
def get_nvidia_gpu_stats(device: torch.device) -> Dict[str, float]:
def get_nvidia_gpu_stats(device: _DEVICE) -> dict[str, float]:
"""Get GPU stats including memory, fan speed, and temperature from nvidia-smi.
Args:
@ -106,7 +109,8 @@ def get_nvidia_gpu_stats(device: torch.device) -> Dict[str, float]:
gpu_stat_keys = [k for k, _ in gpu_stat_metrics]
gpu_query = ",".join(gpu_stat_keys)
gpu_id = _get_gpu_id(device.index)
index = torch._utils._get_device_index(device)
gpu_id = _get_gpu_id(index)
result = subprocess.run(
[nvidia_smi_path, f"--query-gpu={gpu_query}", "--format=csv,nounits,noheader", f"--id={gpu_id}"],
encoding="utf-8",
@ -130,5 +134,5 @@ def _get_gpu_id(device_id: int) -> str:
"""Get the unmasked real GPU IDs."""
# All devices if `CUDA_VISIBLE_DEVICES` unset
default = ",".join(str(i) for i in range(torch.cuda.device_count()))
cuda_visible_devices: List[str] = os.getenv("CUDA_VISIBLE_DEVICES", default=default).split(",")
cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", default=default).split(",")
return cuda_visible_devices[device_id].strip()

View File

@ -11,7 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Union
from __future__ import annotations
from typing import Any
import torch
@ -20,6 +22,7 @@ from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.strategy import Strategy
from pytorch_lightning.utilities import _XLA_AVAILABLE
from pytorch_lightning.utilities.types import _DEVICE
class SingleDeviceStrategy(Strategy):
@ -27,13 +30,13 @@ class SingleDeviceStrategy(Strategy):
def __init__(
self,
device: torch.device,
accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision_plugin: Optional[PrecisionPlugin] = None,
device: _DEVICE,
accelerator: pl.accelerators.accelerator.Accelerator | None = None,
checkpoint_io: CheckpointIO | None = None,
precision_plugin: PrecisionPlugin | None = None,
):
super().__init__(accelerator=accelerator, checkpoint_io=checkpoint_io, precision_plugin=precision_plugin)
self.device: torch.device = device
self._root_device = torch.device(device)
self.global_rank = 0
self.local_rank = 0
self.world_size = 1
@ -46,7 +49,7 @@ class SingleDeviceStrategy(Strategy):
def on_gpu(self) -> bool:
return self.root_device.type == "cuda" and torch.cuda.is_available()
def reduce(self, tensor: Union[Any, torch.Tensor], *args: Any, **kwargs: Any) -> Union[Any, torch.Tensor]:
def reduce(self, tensor: Any | torch.Tensor, *args: Any, **kwargs: Any) -> Any | torch.Tensor:
"""Reduces a tensor from several distributed processes to one aggregated tensor. As this plugin only
operates with a single device, the reduction is simply the identity.
@ -60,18 +63,18 @@ class SingleDeviceStrategy(Strategy):
"""
return tensor
def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
def all_gather(self, tensor: torch.Tensor, group: Any | None = None, sync_grads: bool = False) -> torch.Tensor:
"""Perform an all_gather on all processes."""
return tensor
@property
def root_device(self) -> torch.device:
return self.device
return self._root_device
def model_to_device(self) -> None:
self.model.to(self.root_device)
def setup(self, trainer: "pl.Trainer") -> None:
def setup(self, trainer: pl.Trainer) -> None:
self.model_to_device()
super().setup(trainer)

View File

@ -36,10 +36,12 @@ class SingleTPUStrategy(SingleDeviceStrategy):
precision_plugin: Optional[PrecisionPlugin] = None,
debug: bool = False,
):
device = xm.xla_device(device)
checkpoint_io = checkpoint_io or XLACheckpointIO()
super().__init__(
accelerator=accelerator, device=device, checkpoint_io=checkpoint_io, precision_plugin=precision_plugin
accelerator=accelerator,
device=xm.xla_device(device),
checkpoint_io=checkpoint_io,
precision_plugin=precision_plugin,
)
self.debug = debug
@ -60,9 +62,6 @@ class SingleTPUStrategy(SingleDeviceStrategy):
super().setup(trainer)
if isinstance(self.device, int):
self.device = xm.xla_device(self.device)
if self.debug:
os.environ["PT_XLA_DEBUG"] = str(1)

View File

@ -752,7 +752,7 @@ class AcceleratorConnector:
plugin = IPUStrategy(parallel_devices=self.parallel_devices)
else:
single_gpu_ordinal = device_parser.determine_root_gpu_device(self.parallel_device_ids)
plugin = SingleDeviceStrategy(device=torch.device(f"cuda:{single_gpu_ordinal}" if self.use_gpu else "cpu"))
plugin = SingleDeviceStrategy(device=single_gpu_ordinal if self.use_gpu else "cpu")
return plugin
def resolve_strategy(self, training_type: Strategy) -> Strategy:

View File

@ -44,6 +44,7 @@ TRAIN_DATALOADERS = Union[
Dict[str, Sequence[DataLoader]],
]
EVAL_DATALOADERS = Union[DataLoader, Sequence[DataLoader]]
_DEVICE = Union[torch.device, str, int]
@runtime_checkable

View File

@ -326,7 +326,7 @@ def test_v1_8_0_deprecated_single_device_plugin_class():
" Use `.*SingleDeviceStrategy` instead."
)
):
SingleDevicePlugin(Mock())
SingleDevicePlugin("cpu")
@RunIf(tpu=True)

View File

@ -44,10 +44,9 @@ def test_checkpoint_plugin_called(tmpdir):
ck = ModelCheckpoint(dirpath=tmpdir, save_last=True)
model = BoringModel()
device = torch.device("cpu")
trainer = Trainer(
default_root_dir=tmpdir,
strategy=SingleDeviceStrategy(device, checkpoint_io=checkpoint_plugin),
strategy=SingleDeviceStrategy("cpu", checkpoint_io=checkpoint_plugin),
callbacks=ck,
max_epochs=2,
)
@ -63,10 +62,9 @@ def test_checkpoint_plugin_called(tmpdir):
ck = ModelCheckpoint(dirpath=tmpdir, save_last=True)
model = BoringModel()
device = torch.device("cpu")
trainer = Trainer(
default_root_dir=tmpdir,
strategy=SingleDeviceStrategy(device),
strategy=SingleDeviceStrategy("cpu"),
plugins=[checkpoint_plugin],
callbacks=ck,
max_epochs=2,