Add typing to accelerators/gpu.py (#11333)
parent 00d1758bac
commit 5914fb748f

pyproject.toml
@@ -42,7 +42,6 @@ warn_no_return = "False"
 # the list can be generated with:
 # mypy | tr ':' ' ' | awk '{print $1}' | sort | uniq | sed 's/\.py//g' | sed 's|\/|\.|g' | xargs -I {} echo '"{}",'
 module = [
-    "pytorch_lightning.accelerators.gpu",
     "pytorch_lightning.callbacks.finetuning",
     "pytorch_lightning.callbacks.model_checkpoint",
     "pytorch_lightning.callbacks.progress.rich_progress",

pytorch_lightning/accelerators/cpu.py
@@ -11,12 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, Union
+from __future__ import annotations
+
+from typing import Any

 import torch

 from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.types import _DEVICE


 class CPUAccelerator(Accelerator):
@@ -28,10 +31,10 @@ class CPUAccelerator(Accelerator):
         MisconfigurationException:
             If the selected device is not CPU.
         """
-        if "cpu" not in str(root_device):
+        if root_device.type != "cpu":
             raise MisconfigurationException(f"Device should be CPU, got {root_device} instead.")

-    def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
+    def get_device_stats(self, device: _DEVICE) -> dict[str, Any]:
         """CPU device stats aren't supported yet."""
         return {}
 
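Note on the import swap above: `from __future__ import annotations` (PEP 563) makes every annotation in the module a lazily evaluated string, which is what lets the rest of this diff use builtin generics like `dict[str, Any]` and the `X | None` union syntax while Lightning still supported pre-3.10 interpreters (the future import itself needs Python 3.7+). A minimal sketch of the effect, not taken from the diff:

    from __future__ import annotations

    from typing import Any

    def get_stats(device: str | int) -> dict[str, Any]:
        # Without the future import, `dict[str, Any]` fails at definition
        # time on Python <= 3.8 and `str | int` fails on Python <= 3.9;
        # with it, both are stored as plain strings and never evaluated.
        return {}

    print(get_stats.__annotations__)
    # -> {'device': 'str | int', 'return': 'dict[str, Any]'}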

pytorch_lightning/accelerators/gpu.py
@@ -11,11 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from __future__ import annotations
+
 import logging
 import os
 import shutil
 import subprocess
-from typing import Any, Dict, List, Union
+from typing import Any

 import torch
 
@@ -23,6 +25,7 @@ import pytorch_lightning as pl
 from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8
+from pytorch_lightning.utilities.types import _DEVICE

 _log = logging.getLogger(__name__)
 
@@ -36,11 +39,11 @@ class GPUAccelerator(Accelerator):
         MisconfigurationException:
             If the selected device is not GPU.
         """
-        if "cuda" not in str(root_device):
-            raise MisconfigurationException(f"Device should be GPU, got {self.root_device} instead")
+        if root_device.type != "cuda":
+            raise MisconfigurationException(f"Device should be GPU, got {root_device} instead")
         torch.cuda.set_device(root_device)

-    def setup(self, trainer: "pl.Trainer") -> None:
+    def setup(self, trainer: pl.Trainer) -> None:
         # TODO refactor input from trainer to local_rank @four4fish
         self.set_nvidia_flags(trainer.local_rank)
         # clear cache before training
@@ -54,7 +57,7 @@ class GPUAccelerator(Accelerator):
         devices = os.getenv("CUDA_VISIBLE_DEVICES", all_gpu_ids)
         _log.info(f"LOCAL_RANK: {local_rank} - CUDA_VISIBLE_DEVICES: [{devices}]")

-    def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
+    def get_device_stats(self, device: _DEVICE) -> dict[str, Any]:
         """Gets stats for the given GPU device.

         Args:
@@ -77,7 +80,7 @@ class GPUAccelerator(Accelerator):
         return torch.cuda.device_count()


-def get_nvidia_gpu_stats(device: torch.device) -> Dict[str, float]:
+def get_nvidia_gpu_stats(device: _DEVICE) -> dict[str, float]:
     """Get GPU stats including memory, fan speed, and temperature from nvidia-smi.

     Args:
@@ -106,7 +109,8 @@ def get_nvidia_gpu_stats(device: torch.device) -> Dict[str, float]:
     gpu_stat_keys = [k for k, _ in gpu_stat_metrics]
     gpu_query = ",".join(gpu_stat_keys)

-    gpu_id = _get_gpu_id(device.index)
+    index = torch._utils._get_device_index(device)
+    gpu_id = _get_gpu_id(index)
     result = subprocess.run(
         [nvidia_smi_path, f"--query-gpu={gpu_query}", "--format=csv,nounits,noheader", f"--id={gpu_id}"],
         encoding="utf-8",
@@ -130,5 +134,5 @@ def _get_gpu_id(device_id: int) -> str:
     """Get the unmasked real GPU IDs."""
     # All devices if `CUDA_VISIBLE_DEVICES` unset
     default = ",".join(str(i) for i in range(torch.cuda.device_count()))
-    cuda_visible_devices: List[str] = os.getenv("CUDA_VISIBLE_DEVICES", default=default).split(",")
+    cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", default=default).split(",")
     return cuda_visible_devices[device_id].strip()
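Because `get_nvidia_gpu_stats` now takes any `_DEVICE`, the old `device.index` lookup no longer works (a plain int or string has no `.index` attribute), hence the switch to PyTorch's private `torch._utils._get_device_index` helper, which normalizes all three forms to an ordinal. A rough illustration, relying on that private API as it behaved in the PyTorch 1.x line:

    import torch
    from torch._utils import _get_device_index

    # All three _DEVICE forms resolve to the ordinal 1:
    _get_device_index(1)                         # int is returned as-is
    _get_device_index("cuda:1")                  # string is parsed via torch.device first
    _get_device_index(torch.device("cuda", 1))   # index is read off the device object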

pytorch_lightning/strategies/single_device.py
@@ -11,7 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Optional, Union
+from __future__ import annotations
+
+from typing import Any

 import torch
 
@@ -20,6 +22,7 @@ from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.strategies.strategy import Strategy
 from pytorch_lightning.utilities import _XLA_AVAILABLE
+from pytorch_lightning.utilities.types import _DEVICE


 class SingleDeviceStrategy(Strategy):
@@ -27,13 +30,13 @@ class SingleDeviceStrategy(Strategy):

     def __init__(
         self,
-        device: torch.device,
-        accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
-        checkpoint_io: Optional[CheckpointIO] = None,
-        precision_plugin: Optional[PrecisionPlugin] = None,
+        device: _DEVICE,
+        accelerator: pl.accelerators.accelerator.Accelerator | None = None,
+        checkpoint_io: CheckpointIO | None = None,
+        precision_plugin: PrecisionPlugin | None = None,
     ):
         super().__init__(accelerator=accelerator, checkpoint_io=checkpoint_io, precision_plugin=precision_plugin)
-        self.device: torch.device = device
+        self._root_device = torch.device(device)
         self.global_rank = 0
         self.local_rank = 0
         self.world_size = 1
@@ -46,7 +49,7 @@ class SingleDeviceStrategy(Strategy):
     def on_gpu(self) -> bool:
         return self.root_device.type == "cuda" and torch.cuda.is_available()

-    def reduce(self, tensor: Union[Any, torch.Tensor], *args: Any, **kwargs: Any) -> Union[Any, torch.Tensor]:
+    def reduce(self, tensor: Any | torch.Tensor, *args: Any, **kwargs: Any) -> Any | torch.Tensor:
         """Reduces a tensor from several distributed processes to one aggregated tensor. As this plugin only
         operates with a single device, the reduction is simply the identity.

@@ -60,18 +63,18 @@ class SingleDeviceStrategy(Strategy):
         """
         return tensor

-    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
+    def all_gather(self, tensor: torch.Tensor, group: Any | None = None, sync_grads: bool = False) -> torch.Tensor:
         """Perform a all_gather on all processes."""
         return tensor

     @property
     def root_device(self) -> torch.device:
-        return self.device
+        return self._root_device

     def model_to_device(self) -> None:
         self.model.to(self.root_device)

-    def setup(self, trainer: "pl.Trainer") -> None:
+    def setup(self, trainer: pl.Trainer) -> None:
         self.model_to_device()
         super().setup(trainer)
 
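The public `self.device` attribute gives way to a private `_root_device` that the constructor normalizes with `torch.device(device)`, which accepts a string, a bare ordinal, or an existing `torch.device`; `root_device` therefore always yields a proper device object no matter which `_DEVICE` form the caller passed. A quick sketch of that normalization (bare ints follow CUDA-ordinal semantics):

    import torch

    torch.device("cpu")                 # device(type='cpu')
    torch.device("cuda:0")              # device(type='cuda', index=0)
    torch.device(0)                     # int = CUDA ordinal -> device(type='cuda', index=0)
    torch.device(torch.device("cpu"))   # passing a device through is a no-op

This is also why the tests further down can drop the explicit `torch.device("cpu")` and hand `SingleDeviceStrategy` the plain string `"cpu"`, and why the deprecation test stops constructing the plugin with a `Mock()`: a Mock is not a valid argument to `torch.device`.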

pytorch_lightning/strategies/single_tpu.py
@@ -36,10 +36,12 @@ class SingleTPUStrategy(SingleDeviceStrategy):
         precision_plugin: Optional[PrecisionPlugin] = None,
         debug: bool = False,
     ):
-        device = xm.xla_device(device)
         checkpoint_io = checkpoint_io or XLACheckpointIO()
         super().__init__(
-            accelerator=accelerator, device=device, checkpoint_io=checkpoint_io, precision_plugin=precision_plugin
+            accelerator=accelerator,
+            device=xm.xla_device(device),
+            checkpoint_io=checkpoint_io,
+            precision_plugin=precision_plugin,
         )

         self.debug = debug
@@ -60,9 +62,6 @@ class SingleTPUStrategy(SingleDeviceStrategy):

         super().setup(trainer)

-        if isinstance(self.device, int):
-            self.device = xm.xla_device(self.device)
-
         if self.debug:
             os.environ["PT_XLA_DEBUG"] = str(1)
 
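In the TPU strategy the ordinal-to-device conversion moves into the constructor call (`device=xm.xla_device(device)`), so the `isinstance(self.device, int)` fix-up that used to run in `setup` is dead code and is deleted. For reference, the torch_xla call doing the work (requires an XLA runtime to actually execute):

    import torch_xla.core.xla_model as xm

    device = xm.xla_device(0)  # resolves the ordinal to an XLA device, e.g. xla:0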

pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -752,7 +752,7 @@ class AcceleratorConnector:
             plugin = IPUStrategy(parallel_devices=self.parallel_devices)
         else:
             single_gpu_ordinal = device_parser.determine_root_gpu_device(self.parallel_device_ids)
-            plugin = SingleDeviceStrategy(device=torch.device(f"cuda:{single_gpu_ordinal}" if self.use_gpu else "cpu"))
+            plugin = SingleDeviceStrategy(device=single_gpu_ordinal if self.use_gpu else "cpu")
         return plugin

     def resolve_strategy(self, training_type: Strategy) -> Strategy:

pytorch_lightning/utilities/types.py
@@ -44,6 +44,7 @@ TRAIN_DATALOADERS = Union[
     Dict[str, Sequence[DataLoader]],
 ]
 EVAL_DATALOADERS = Union[DataLoader, Sequence[DataLoader]]
+_DEVICE = Union[torch.device, str, int]


 @runtime_checkable
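The new `_DEVICE` alias is the thread running through all of the signatures above: wherever a device argument appears, a `torch.device`, a device string, or a bare ordinal now type-checks. A hypothetical caller (not part of the diff) showing the three accepted forms:

    import torch

    from pytorch_lightning.utilities.types import _DEVICE

    def describe(device: _DEVICE) -> str:
        # Collapse whichever form was passed into a canonical torch.device.
        return str(torch.device(device))

    describe("cuda:0")             # 'cuda:0'
    describe(0)                    # 'cuda:0'
    describe(torch.device("cpu"))  # 'cpu'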

tests/deprecated_api/test_remove_1-8.py
@@ -326,7 +326,7 @@ def test_v1_8_0_deprecated_single_device_plugin_class():
             " Use `.*SingleDeviceStrategy` instead."
         )
     ):
-        SingleDevicePlugin(Mock())
+        SingleDevicePlugin("cpu")


 @RunIf(tpu=True)

tests/plugins/test_checkpoint_io_plugin.py
@@ -44,10 +44,9 @@ def test_checkpoint_plugin_called(tmpdir):
     ck = ModelCheckpoint(dirpath=tmpdir, save_last=True)

     model = BoringModel()
-    device = torch.device("cpu")
     trainer = Trainer(
         default_root_dir=tmpdir,
-        strategy=SingleDeviceStrategy(device, checkpoint_io=checkpoint_plugin),
+        strategy=SingleDeviceStrategy("cpu", checkpoint_io=checkpoint_plugin),
         callbacks=ck,
         max_epochs=2,
     )
@@ -63,10 +62,9 @@ def test_checkpoint_plugin_called(tmpdir):
     ck = ModelCheckpoint(dirpath=tmpdir, save_last=True)

     model = BoringModel()
-    device = torch.device("cpu")
     trainer = Trainer(
         default_root_dir=tmpdir,
-        strategy=SingleDeviceStrategy(device),
+        strategy=SingleDeviceStrategy("cpu"),
         plugins=[checkpoint_plugin],
         callbacks=ck,
         max_epochs=2,