[1/4] Add get_device_stats to accelerator interface (#9586)
parent 83d83abc9d
commit ab069876cb
@@ -147,6 +147,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added `PL_RECONCILE_PROCESS` environment variable to enable process reconciliation regardless of cluster environment settings ([#9389](https://github.com/PyTorchLightning/pytorch-lightning/pull/9389))

- Added `get_device_stats` to the Accelerator interface and added its implementation for GPU and TPU ([#9586](https://github.com/PyTorchLightning/pytorch-lightning/pull/9586))

- Added `multifile` option to `LightningCLI` to enable/disable config save to preserve multiple files structure ([#9073](https://github.com/PyTorchLightning/pytorch-lightning/pull/9073))
@@ -41,6 +41,7 @@ class Accelerator:

    - CPU
    - GPU
    - TPU
    - IPU

    Each Accelerator gets two plugins upon initialization:
    One to handle differences from the training routine and one to handle different precisions.
@@ -422,6 +423,17 @@ class Accelerator:
        """
        return self.training_type_plugin.restore_checkpoint_after_pre_dispatch

    def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
        """Gets stats for a given device.

        Args:
            device: device for which to get stats

        Returns:
            Dictionary of device stats
        """
        raise NotImplementedError

    def on_train_start(self) -> None:
        """Called when train begins."""
        return self.training_type_plugin.on_train_start()
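Since the base-class hook just raises `NotImplementedError`, every concrete accelerator is expected to override it. A minimal sketch of what a third-party subclass could look like (the class name and metric key are hypothetical, not part of this PR):

import torch
from typing import Any, Dict, Union

from pytorch_lightning.accelerators.accelerator import Accelerator


class MyAccelerator(Accelerator):
    # Hypothetical subclass, for illustration only.
    def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
        # Keys are free-form strings; callers simply receive the dict of metrics.
        return {"example_metric (unitless)": 0.0}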
@@ -11,6 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Union

import torch

import pytorch_lightning as pl
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -29,3 +33,7 @@ class CPUAccelerator(Accelerator):
            raise MisconfigurationException(f"Device should be CPU, got {self.root_device} instead.")

        return super().setup(trainer)

    def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
        """CPU device stats are not yet implemented; returns an empty dict for now."""
        return {}
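The CPU implementation is an explicit stub in this commit. For comparison, a real implementation could pull metrics from the third-party `psutil` package; this sketch is an assumption, not part of the PR:

# Hypothetical sketch, not part of this commit: CPU stats via the
# third-party `psutil` package.
import psutil
from typing import Any, Dict


def _get_cpu_stats() -> Dict[str, Any]:
    vm = psutil.virtual_memory()
    return {
        "cpu_percent (%)": psutil.cpu_percent(),        # system-wide CPU utilization
        "memory.used (MB)": vm.used / 1024**2,          # memory currently in use
        "memory.available (MB)": vm.available / 1024**2,
    }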
@@ -13,12 +13,16 @@
# limitations under the License.
import logging
import os
import shutil
import subprocess
from typing import Any, Dict, List, Union

import torch

import pytorch_lightning as pl
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8

_log = logging.getLogger(__name__)
@@ -53,6 +57,83 @@ class GPUAccelerator(Accelerator):
        devices = os.getenv("CUDA_VISIBLE_DEVICES", all_gpu_ids)
        _log.info(f"LOCAL_RANK: {local_rank} - CUDA_VISIBLE_DEVICES: [{devices}]")

    def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
        """Gets stats for the given GPU device.

        Args:
            device: GPU device for which to get stats

        Returns:
            A dictionary mapping the metrics to their values.

        Raises:
            FileNotFoundError:
                If the nvidia-smi installation is not found
        """
        if _TORCH_GREATER_EQUAL_1_8:
            return torch.cuda.memory_stats(device)
        return _get_nvidia_gpu_stats(device)

    def teardown(self) -> None:
        super().teardown()
        self._move_optimizer_state(torch.device("cpu"))


def _get_nvidia_gpu_stats(device: torch.device) -> Dict[str, float]:
    """Get GPU stats including memory, fan speed, and temperature from nvidia-smi.

    Args:
        device: GPU device for which to get stats

    Returns:
        A dictionary mapping the metrics to their values.

    Raises:
        FileNotFoundError:
            If the nvidia-smi installation is not found
    """
    gpu_stat_metrics = [
        ("utilization.gpu", "%"),
        ("memory.used", "MB"),
        ("memory.free", "MB"),
        ("utilization.memory", "%"),
        ("fan.speed", "%"),
        ("temperature.gpu", "°C"),
        ("temperature.memory", "°C"),
    ]
    gpu_stat_keys = [k for k, _ in gpu_stat_metrics]
    gpu_query = ",".join(gpu_stat_keys)

    gpu_id = _get_gpu_id(device.index)
    nvidia_smi_path = shutil.which("nvidia-smi")
    if nvidia_smi_path is None:
        raise FileNotFoundError("nvidia-smi: command not found")
    result = subprocess.run(
        [nvidia_smi_path, f"--query-gpu={gpu_query}", "--format=csv,nounits,noheader", f"--id={gpu_id}"],
        encoding="utf-8",
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,  # for backward compatibility with Python 3.6
        check=True,
    )

    def _to_float(x: str) -> float:
        try:
            return float(x)
        except ValueError:
            return 0.0

    s = result.stdout.strip()
    stats = [_to_float(x) for x in s.split(", ")]

    gpu_stats = {}
    for i, (x, unit) in enumerate(gpu_stat_metrics):
        gpu_stats[f"{x} ({unit})"] = stats[i]
    return gpu_stats


def _get_gpu_id(device_id: int) -> str:
    """Get the unmasked real GPU ID."""
    # All devices if `CUDA_VISIBLE_DEVICES` unset
    default = ",".join(str(i) for i in range(torch.cuda.device_count()))
    cuda_visible_devices: List[str] = os.getenv("CUDA_VISIBLE_DEVICES", default=default).split(",")
    return cuda_visible_devices[device_id].strip()
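`_get_gpu_id` is needed because torch device indices are relative to the `CUDA_VISIBLE_DEVICES` mask, while nvidia-smi's `--id` flag expects the physical GPU ID. A quick worked example of the mapping (assumes `_get_gpu_id` from the module above is in scope):

import os

# Mask the process to physical GPUs 2 and 3; torch now sees them as cuda:0 and cuda:1.
os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"

assert _get_gpu_id(0) == "2"  # masked index 0 -> physical GPU 2
assert _get_gpu_id(1) == "3"  # masked index 1 -> physical GPU 3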
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Dict, Optional, Union

import torch
from torch.optim import Optimizer
@@ -61,3 +61,21 @@ class TPUAccelerator(Accelerator):
        for opt in self.optimizers:
            for p, v in opt.state.items():
                opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, self.root_device)

    def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
        """Gets stats for the given TPU device.

        Args:
            device: TPU device for which to get stats

        Returns:
            A dictionary mapping the metrics (free memory and peak memory) to their values.
        """
        memory_info = xm.get_memory_info(device)
        # `xm.get_memory_info` reports kilobytes; convert so the MB unit in the keys is accurate.
        free_memory = memory_info["kb_free"] / 1024
        peak_memory = memory_info["kb_total"] / 1024 - free_memory
        device_stats = {
            "avg. free memory (MB)": free_memory,
            "avg. peak memory (MB)": peak_memory,
        }
        return device_stats
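Note that `xm` does not appear in the import hunk above; presumably it is bound elsewhere in this module behind the usual guarded torch_xla import, along these lines (an assumption, not shown in this diff):

from pytorch_lightning.utilities import _TPU_AVAILABLE

if _TPU_AVAILABLE:
    import torch_xla.core.xla_model as xm  # provides xm.get_memory_info(device)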
@@ -0,0 +1,36 @@
import torch

from pytorch_lightning.accelerators import GPUAccelerator
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin
from tests.helpers.runif import RunIf


@RunIf(min_torch="1.8")
@RunIf(min_gpus=1)
def test_get_torch_gpu_stats(tmpdir):
    """Test GPU get_device_stats with PyTorch >= 1.8.0."""
    current_device = torch.device(f"cuda:{torch.cuda.current_device()}")
    GPUAccel = GPUAccelerator(
        training_type_plugin=DataParallelPlugin(parallel_devices=[current_device]), precision_plugin=PrecisionPlugin()
    )
    gpu_stats = GPUAccel.get_device_stats(current_device)
    fields = ["allocated_bytes.all.freed", "inactive_split.all.peak", "reserved_bytes.large_pool.peak"]

    for f in fields:
        assert any(f in h for h in gpu_stats.keys())


@RunIf(max_torch="1.7")
@RunIf(min_gpus=1)
def test_get_nvidia_gpu_stats(tmpdir):
    """Test GPU get_device_stats with PyTorch < 1.8.0."""
    current_device = torch.device(f"cuda:{torch.cuda.current_device()}")
    GPUAccel = GPUAccelerator(
        training_type_plugin=DataParallelPlugin(parallel_devices=[current_device]), precision_plugin=PrecisionPlugin()
    )
    gpu_stats = GPUAccel.get_device_stats(current_device)
    fields = ["utilization.gpu", "memory.used", "memory.free", "utilization.memory"]

    for f in fields:
        assert any(f in h for h in gpu_stats.keys())
@@ -0,0 +1,16 @@
from pytorch_lightning.accelerators import TPUAccelerator
from pytorch_lightning.plugins import PrecisionPlugin
from pytorch_lightning.plugins.training_type import TPUSpawnPlugin
from tests.helpers.runif import RunIf


@RunIf(tpu=True)
def test_device_stats_tpu(tmpdir):
    """Test TPU get_device_stats."""
    TPUAccel = TPUAccelerator(training_type_plugin=TPUSpawnPlugin(), precision_plugin=PrecisionPlugin())
    tpu_stats = TPUAccel.get_device_stats("1")
    fields = ["avg. free memory (MB)", "avg. peak memory (MB)"]

    for f in fields:
        assert any(f in h for h in tpu_stats.keys())