Fix: handle logical CUDA device IDs for GPUStatsMonitor if `CUDA_VISIBLE_DEVICES` set (#8260)

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
Authored by Xuehai Pan on 2021-07-19 19:42:43 +08:00, committed by GitHub
parent d5bf518cb0
commit 2c5d94d98b
3 changed files with 84 additions and 14 deletions
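
Background for the change: when `CUDA_VISIBLE_DEVICES` is set, PyTorch's logical device indices (what `trainer.data_parallel_device_ids` holds) no longer match the physical IDs that `nvidia-smi --id=...` expects, so the callback previously queried the wrong GPUs. A minimal sketch of the remapping this commit introduces, mirroring the new `_get_gpu_ids` helper; the example masking value is illustrative:

import os
from typing import List

def logical_to_physical(device_ids: List[int]) -> List[str]:
    # Map logical CUDA device indices to the unmasked IDs that nvidia-smi expects.
    visible = os.getenv('CUDA_VISIBLE_DEVICES')
    if visible is None:
        # No masking: logical and physical indices coincide.
        return [str(i) for i in device_ids]
    masked = [entry.strip() for entry in visible.split(',')]
    return [masked[i] for i in device_ids]

os.environ['CUDA_VISIBLE_DEVICES'] = '3,2,4'  # illustrative masking
print(logical_to_physical([1, 2]))  # -> ['2', '4'], matching the new tests below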

CHANGELOG.md

@@ -411,6 +411,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Fixed
+- Fixed the `GPUStatsMonitor` callbacks to use the correct GPU IDs if `CUDA_VISIBLE_DEVICES` set ([#8260](https://github.com/PyTorchLightning/pytorch-lightning/pull/8260))
- Fixed `lr_scheduler` checkpointed state by calling `update_lr_schedulers` before saving checkpoints ([#7877](https://github.com/PyTorchLightning/pytorch-lightning/pull/7877))

pytorch_lightning/callbacks/gpu_stats_monitor.py

@@ -23,12 +23,16 @@ import os
import shutil
import subprocess
import time
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Optional, Tuple
+import torch
+import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import DeviceType, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import AttributeDict
+from pytorch_lightning.utilities.types import STEP_OUTPUT
class GPUStatsMonitor(Callback):
@@ -101,7 +105,13 @@ class GPUStatsMonitor(Callback):
            'temperature': temperature
        })
-    def on_train_start(self, trainer, pl_module) -> None:
+        # The logical device IDs for selected devices
+        self._device_ids: List[int] = []  # will be assigned later in setup()
+        # The unmasked real GPU IDs
+        self._gpu_ids: List[str] = []  # will be assigned later in setup()
+    def setup(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', stage: Optional[str] = None) -> None:
        if not trainer.logger:
            raise MisconfigurationException('Cannot use GPUStatsMonitor callback with Trainer that has no logger.')
@@ -111,14 +121,20 @@ class GPUStatsMonitor(Callback):
                f' since gpus attribute in Trainer is set to {trainer.gpus}.'
            )
-        self._gpu_ids = ','.join(map(str, trainer.data_parallel_device_ids))
+        # The logical device IDs for selected devices
+        self._device_ids: List[int] = sorted(set(trainer.data_parallel_device_ids))
-    def on_train_epoch_start(self, trainer, pl_module) -> None:
+        # The unmasked real GPU IDs
+        self._gpu_ids: List[str] = self._get_gpu_ids(self._device_ids)
+    def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
        self._snap_intra_step_time = None
        self._snap_inter_step_time = None
    @rank_zero_only
-    def on_train_batch_start(self, trainer, pl_module, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+    def on_train_batch_start(
+        self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', batch: Any, batch_idx: int, dataloader_idx: int
+    ) -> None:
        if self._log_stats.intra_step_time:
            self._snap_intra_step_time = time.time()
@@ -127,7 +143,7 @@ class GPUStatsMonitor(Callback):
        gpu_stat_keys = self._get_gpu_stat_keys()
        gpu_stats = self._get_gpu_stats([k for k, _ in gpu_stat_keys])
-        logs = self._parse_gpu_stats(self._gpu_ids, gpu_stats, gpu_stat_keys)
+        logs = self._parse_gpu_stats(self._device_ids, gpu_stats, gpu_stat_keys)
        if self._log_stats.inter_step_time and self._snap_inter_step_time:
            # First log at beginning of second step
@@ -137,7 +153,13 @@ class GPUStatsMonitor(Callback):
    @rank_zero_only
    def on_train_batch_end(
-        self, trainer, pl_module, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int
+        self,
+        trainer: 'pl.Trainer',
+        pl_module: 'pl.LightningModule',
+        outputs: STEP_OUTPUT,
+        batch: Any,
+        batch_idx: int,
+        dataloader_idx: int,
    ) -> None:
        if self._log_stats.inter_step_time:
            self._snap_inter_step_time = time.time()
@@ -147,19 +169,28 @@ class GPUStatsMonitor(Callback):
        gpu_stat_keys = self._get_gpu_stat_keys() + self._get_gpu_device_stat_keys()
        gpu_stats = self._get_gpu_stats([k for k, _ in gpu_stat_keys])
-        logs = self._parse_gpu_stats(self._gpu_ids, gpu_stats, gpu_stat_keys)
+        logs = self._parse_gpu_stats(self._device_ids, gpu_stats, gpu_stat_keys)
        if self._log_stats.intra_step_time and self._snap_intra_step_time:
            logs['batch_time/intra_step (ms)'] = (time.time() - self._snap_intra_step_time) * 1000
        trainer.logger.log_metrics(logs, step=trainer.global_step)
+    @staticmethod
+    def _get_gpu_ids(device_ids: List[int]) -> List[str]:
+        """Get the unmasked real GPU IDs"""
+        # All devices if `CUDA_VISIBLE_DEVICES` unset
+        default = ','.join(str(i) for i in range(torch.cuda.device_count()))
+        cuda_visible_devices: List[str] = os.getenv('CUDA_VISIBLE_DEVICES', default=default).split(',')
+        return [cuda_visible_devices[device_id].strip() for device_id in device_ids]
    def _get_gpu_stats(self, queries: List[str]) -> List[List[float]]:
        """Run nvidia-smi to get the gpu stats"""
        gpu_query = ','.join(queries)
        format = 'csv,nounits,noheader'
+        gpu_ids = ','.join(self._gpu_ids)
        result = subprocess.run(
-            [shutil.which('nvidia-smi'), f'--query-gpu={gpu_query}', f'--format={format}', f'--id={self._gpu_ids}'],
+            [shutil.which('nvidia-smi'), f'--query-gpu={gpu_query}', f'--format={format}', f'--id={gpu_ids}'],
            encoding="utf-8",
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,  # for backward compatibility with python version 3.6
@@ -177,12 +208,16 @@ class GPUStatsMonitor(Callback):
        return stats
    @staticmethod
-    def _parse_gpu_stats(gpu_ids: str, stats: List[List[float]], keys: List[Tuple[str, str]]) -> Dict[str, float]:
+    def _parse_gpu_stats(
+        device_ids: List[int],
+        stats: List[List[float]],
+        keys: List[Tuple[str, str]],
+    ) -> Dict[str, float]:
        """Parse the gpu stats into a loggable dict"""
        logs = {}
-        for i, gpu_id in enumerate(gpu_ids.split(',')):
+        for i, device_id in enumerate(device_ids):
            for j, (x, unit) in enumerate(keys):
-                logs[f'gpu_id: {gpu_id}/{x} ({unit})'] = stats[i][j]
+                logs[f'device_id: {device_id}/{x} ({unit})'] = stats[i][j]
        return logs
    def _get_gpu_stat_keys(self) -> List[Tuple[str, str]]:
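
For illustration, a sketch of the query the callback now issues when logical devices [1, 2] are selected under `CUDA_VISIBLE_DEVICES=3,2,4`; the stat fields shown are typical `nvidia-smi --query-gpu` fields and stand in for whatever `_get_gpu_stat_keys` actually returns:

import shutil

queries = ['utilization.gpu', 'memory.used']  # illustrative stat fields
gpu_ids = ['2', '4']  # unmasked IDs produced by _get_gpu_ids([1, 2])
cmd = [
    shutil.which('nvidia-smi'),
    f"--query-gpu={','.join(queries)}",
    '--format=csv,nounits,noheader',
    f"--id={','.join(gpu_ids)}",
]
print(cmd)  # only the two masked-in physical GPUs are queried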

tests/callbacks/test_gpu_stats_monitor.py

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
+from unittest import mock
import numpy as np
import pytest
@@ -116,6 +117,38 @@ def test_gpu_stats_monitor_no_gpu_warning(tmpdir):
def test_gpu_stats_monitor_parse_gpu_stats():
-    logs = GPUStatsMonitor._parse_gpu_stats('1,2', [[3, 4, 5], [6, 7]], [('gpu', 'a'), ('memory', 'b')])
-    expected = {'gpu_id: 1/gpu (a)': 3, 'gpu_id: 1/memory (b)': 4, 'gpu_id: 2/gpu (a)': 6, 'gpu_id: 2/memory (b)': 7}
+    logs = GPUStatsMonitor._parse_gpu_stats([1, 2], [[3, 4, 5], [6, 7]], [('gpu', 'a'), ('memory', 'b')])
+    expected = {
+        'device_id: 1/gpu (a)': 3,
+        'device_id: 1/memory (b)': 4,
+        'device_id: 2/gpu (a)': 6,
+        'device_id: 2/memory (b)': 7
+    }
    assert logs == expected
+@mock.patch.dict(os.environ, {})
+@mock.patch('torch.cuda.is_available', return_value=True)
+@mock.patch('torch.cuda.device_count', return_value=2)
+def test_gpu_stats_monitor_get_gpu_ids_cuda_visible_devices_unset(device_count_mock, is_available_mock):
+    gpu_ids = GPUStatsMonitor._get_gpu_ids([1, 0])
+    expected = ['1', '0']
+    assert gpu_ids == expected
+@mock.patch.dict(os.environ, {'CUDA_VISIBLE_DEVICES': '3,2,4'})
+@mock.patch('torch.cuda.is_available', return_value=True)
+@mock.patch('torch.cuda.device_count', return_value=3)
+def test_gpu_stats_monitor_get_gpu_ids_cuda_visible_devices_integers(device_count_mock, is_available_mock):
+    gpu_ids = GPUStatsMonitor._get_gpu_ids([1, 2])
+    expected = ['2', '4']
+    assert gpu_ids == expected
+@mock.patch.dict(os.environ, {'CUDA_VISIBLE_DEVICES': 'GPU-01a23b4c,GPU-56d78e9f,GPU-02a46c8e'})
+@mock.patch('torch.cuda.is_available', return_value=True)
+@mock.patch('torch.cuda.device_count', return_value=3)
+def test_gpu_stats_monitor_get_gpu_ids_cuda_visible_devices_uuids(device_count_mock, is_available_mock):
+    gpu_ids = GPUStatsMonitor._get_gpu_ids([1, 2])
+    expected = ['GPU-56d78e9f', 'GPU-02a46c8e']
+    assert gpu_ids == expected
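
Finally, a minimal usage sketch of the fixed callback; `Trainer(gpus=...)` is the GPU selection API of this release line, and the masking value is illustrative:

# launch with masked devices, e.g.:  CUDA_VISIBLE_DEVICES=3,2,4 python train.py
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import GPUStatsMonitor

# GPUStatsMonitor now resolves logical devices 0 and 1 to physical GPUs 3 and 2
# before calling nvidia-smi, while logged keys keep the logical IDs
# ('device_id: 0/...', 'device_id: 1/...').
trainer = Trainer(gpus=2, callbacks=[GPUStatsMonitor()])
# trainer.fit(model)  # attach any LightningModule as usual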