lightning/tests/accelerators/test_gpu.py

37 lines
1.4 KiB
Python
Raw Normal View History

import torch
from pytorch_lightning.accelerators import GPUAccelerator
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin
from tests.helpers.runif import RunIf
@RunIf(min_torch="1.8")
@RunIf(min_gpus=1)
def test_get_torch_gpu_stats(tmpdir):
"""Test GPU get_device_stats with Pytorch >= 1.8.0."""
current_device = torch.device(f"cuda:{torch.cuda.current_device()}")
GPUAccel = GPUAccelerator(
training_type_plugin=DataParallelPlugin(parallel_devices=[current_device]), precision_plugin=PrecisionPlugin()
)
gpu_stats = GPUAccel.get_device_stats(current_device)
fields = ["allocated_bytes.all.freed", "inactive_split.all.peak", "reserved_bytes.large_pool.peak"]
for f in fields:
assert any(f in h for h in gpu_stats.keys())
@RunIf(max_torch="1.7")
@RunIf(min_gpus=1)
def test_get_nvidia_gpu_stats(tmpdir):
"""Test GPU get_device_stats with Pytorch < 1.8.0."""
current_device = torch.device(f"cuda:{torch.cuda.current_device()}")
GPUAccel = GPUAccelerator(
training_type_plugin=DataParallelPlugin(parallel_devices=[current_device]), precision_plugin=PrecisionPlugin()
)
gpu_stats = GPUAccel.get_device_stats(current_device)
fields = ["utilization.gpu", "memory.used", "memory.free", "utilization.memory"]
for f in fields:
assert any(f in h for h in gpu_stats.keys())