Fix: handle logical CUDA device IDs for GPUStatsMonitor if `CUDA_VISIBLE_DEVICES` set (#8260)
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
parent d5bf518cb0
commit 2c5d94d98b

--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -411,6 +411,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed the `GPUStatsMonitor` callbacks to use the correct GPU IDs if `CUDA_VISIBLE_DEVICES` set ([#8260](https://github.com/PyTorchLightning/pytorch-lightning/pull/8260))
+
 - Fixed `lr_scheduler` checkpointed state by calling `update_lr_schedulers` before saving checkpoints ([#7877](https://github.com/PyTorchLightning/pytorch-lightning/pull/7877))
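
For context, nvidia-smi's `--id` flag expects unmasked physical GPU indices (or UUIDs), while the IDs the Trainer works with are logical, i.e. positions within the `CUDA_VISIBLE_DEVICES` mask. A minimal sketch of the mapping this commit introduces; the env value and device selection are illustrative assumptions, not part of the commit:

# Illustrative sketch only; env value and device selection are assumed.
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '3,2,4'

device_ids = [0, 1]  # logical IDs, e.g. as selected via Trainer(gpus=2)
visible = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
gpu_ids = [visible[i].strip() for i in device_ids]
print(gpu_ids)  # ['3', '2'] - the physical IDs nvidia-smi must be queried with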

--- a/pytorch_lightning/callbacks/gpu_stats_monitor.py
+++ b/pytorch_lightning/callbacks/gpu_stats_monitor.py
@@ -23,12 +23,16 @@ import os
 import shutil
 import subprocess
 import time
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
 
+import pytorch_lightning as pl
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.utilities import DeviceType, rank_zero_only
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.parsing import AttributeDict
+from pytorch_lightning.utilities.types import STEP_OUTPUT
 
 
 class GPUStatsMonitor(Callback):
@@ -101,7 +105,13 @@ class GPUStatsMonitor(Callback):
             'temperature': temperature
         })
 
-    def on_train_start(self, trainer, pl_module) -> None:
+        # The logical device IDs for selected devices
+        self._device_ids: List[int] = []  # will be assigned later in setup()
+
+        # The unmasked real GPU IDs
+        self._gpu_ids: List[str] = []  # will be assigned later in setup()
+
+    def setup(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', stage: Optional[str] = None) -> None:
         if not trainer.logger:
             raise MisconfigurationException('Cannot use GPUStatsMonitor callback with Trainer that has no logger.')
 
@@ -111,14 +121,20 @@ class GPUStatsMonitor(Callback):
                 f' since gpus attribute in Trainer is set to {trainer.gpus}.'
             )
 
-        self._gpu_ids = ','.join(map(str, trainer.data_parallel_device_ids))
+        # The logical device IDs for selected devices
+        self._device_ids: List[int] = sorted(set(trainer.data_parallel_device_ids))
+
+        # The unmasked real GPU IDs
+        self._gpu_ids: List[str] = self._get_gpu_ids(self._device_ids)
 
-    def on_train_epoch_start(self, trainer, pl_module) -> None:
+    def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
         self._snap_intra_step_time = None
         self._snap_inter_step_time = None
 
     @rank_zero_only
-    def on_train_batch_start(self, trainer, pl_module, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+    def on_train_batch_start(
+        self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', batch: Any, batch_idx: int, dataloader_idx: int
+    ) -> None:
         if self._log_stats.intra_step_time:
             self._snap_intra_step_time = time.time()
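
A note on `setup` above (illustrative, not from the commit): `trainer.data_parallel_device_ids` holds logical indices, which are deduplicated and ordered before being translated to physical IDs, so the metric keys stay stable:

# Hypothetical values showing the normalization performed in setup().
data_parallel_device_ids = [2, 0, 2]  # logical IDs as the Trainer may report them
device_ids = sorted(set(data_parallel_device_ids))
print(device_ids)  # [0, 2] - duplicate-free, stable ordering
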
@@ -127,7 +143,7 @@ class GPUStatsMonitor(Callback):
 
         gpu_stat_keys = self._get_gpu_stat_keys()
         gpu_stats = self._get_gpu_stats([k for k, _ in gpu_stat_keys])
-        logs = self._parse_gpu_stats(self._gpu_ids, gpu_stats, gpu_stat_keys)
+        logs = self._parse_gpu_stats(self._device_ids, gpu_stats, gpu_stat_keys)
 
         if self._log_stats.inter_step_time and self._snap_inter_step_time:
             # First log at beginning of second step
@@ -137,7 +153,13 @@ class GPUStatsMonitor(Callback):
 
     @rank_zero_only
     def on_train_batch_end(
-        self, trainer, pl_module, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int
+        self,
+        trainer: 'pl.Trainer',
+        pl_module: 'pl.LightningModule',
+        outputs: STEP_OUTPUT,
+        batch: Any,
+        batch_idx: int,
+        dataloader_idx: int,
     ) -> None:
         if self._log_stats.inter_step_time:
             self._snap_inter_step_time = time.time()
@@ -147,19 +169,28 @@ class GPUStatsMonitor(Callback):
 
         gpu_stat_keys = self._get_gpu_stat_keys() + self._get_gpu_device_stat_keys()
         gpu_stats = self._get_gpu_stats([k for k, _ in gpu_stat_keys])
-        logs = self._parse_gpu_stats(self._gpu_ids, gpu_stats, gpu_stat_keys)
+        logs = self._parse_gpu_stats(self._device_ids, gpu_stats, gpu_stat_keys)
 
         if self._log_stats.intra_step_time and self._snap_intra_step_time:
             logs['batch_time/intra_step (ms)'] = (time.time() - self._snap_intra_step_time) * 1000
 
         trainer.logger.log_metrics(logs, step=trainer.global_step)
 
+    @staticmethod
+    def _get_gpu_ids(device_ids: List[int]) -> List[str]:
+        """Get the unmasked real GPU IDs"""
+        # All devices if `CUDA_VISIBLE_DEVICES` unset
+        default = ','.join(str(i) for i in range(torch.cuda.device_count()))
+        cuda_visible_devices: List[str] = os.getenv('CUDA_VISIBLE_DEVICES', default=default).split(',')
+        return [cuda_visible_devices[device_id].strip() for device_id in device_ids]
+
     def _get_gpu_stats(self, queries: List[str]) -> List[List[float]]:
         """Run nvidia-smi to get the gpu stats"""
         gpu_query = ','.join(queries)
         format = 'csv,nounits,noheader'
+        gpu_ids = ','.join(self._gpu_ids)
         result = subprocess.run(
-            [shutil.which('nvidia-smi'), f'--query-gpu={gpu_query}', f'--format={format}', f'--id={self._gpu_ids}'],
+            [shutil.which('nvidia-smi'), f'--query-gpu={gpu_query}', f'--format={format}', f'--id={gpu_ids}'],
            encoding="utf-8",
             stdout=subprocess.PIPE,
             stderr=subprocess.PIPE,  # for backward compatibility with python version 3.6
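
To make the `--id` handling above concrete, here is a sketch of the call `_get_gpu_stats` effectively assembles once `_gpu_ids` holds physical IDs. The query fields and ID values are illustrative assumptions, not taken from the commit:

# Illustrative sketch of the assembled nvidia-smi invocation.
import shutil
import subprocess

gpu_ids = ['2', '4']  # e.g. _get_gpu_ids([1, 2]) with CUDA_VISIBLE_DEVICES='3,2,4'
result = subprocess.run(
    [
        shutil.which('nvidia-smi'),
        '--query-gpu=utilization.gpu,memory.used',  # example query fields
        '--format=csv,nounits,noheader',
        f"--id={','.join(gpu_ids)}",
    ],
    encoding='utf-8',
    stdout=subprocess.PIPE,
    stderr=subprocess.PIPE,
)
# Each line of result.stdout holds the comma-separated stats for one queried GPU.
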
@@ -177,12 +208,16 @@ class GPUStatsMonitor(Callback):
         return stats
 
     @staticmethod
-    def _parse_gpu_stats(gpu_ids: str, stats: List[List[float]], keys: List[Tuple[str, str]]) -> Dict[str, float]:
+    def _parse_gpu_stats(
+        device_ids: List[int],
+        stats: List[List[float]],
+        keys: List[Tuple[str, str]],
+    ) -> Dict[str, float]:
         """Parse the gpu stats into a loggable dict"""
         logs = {}
-        for i, gpu_id in enumerate(gpu_ids.split(',')):
+        for i, device_id in enumerate(device_ids):
             for j, (x, unit) in enumerate(keys):
-                logs[f'gpu_id: {gpu_id}/{x} ({unit})'] = stats[i][j]
+                logs[f'device_id: {device_id}/{x} ({unit})'] = stats[i][j]
         return logs
 
     def _get_gpu_stat_keys(self) -> List[Tuple[str, str]]:

--- a/tests/callbacks/test_gpu_stats_monitor.py
+++ b/tests/callbacks/test_gpu_stats_monitor.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+from unittest import mock
 
 import numpy as np
 import pytest
@@ -116,6 +117,38 @@ def test_gpu_stats_monitor_no_gpu_warning(tmpdir):
 
 
 def test_gpu_stats_monitor_parse_gpu_stats():
-    logs = GPUStatsMonitor._parse_gpu_stats('1,2', [[3, 4, 5], [6, 7]], [('gpu', 'a'), ('memory', 'b')])
-    expected = {'gpu_id: 1/gpu (a)': 3, 'gpu_id: 1/memory (b)': 4, 'gpu_id: 2/gpu (a)': 6, 'gpu_id: 2/memory (b)': 7}
+    logs = GPUStatsMonitor._parse_gpu_stats([1, 2], [[3, 4, 5], [6, 7]], [('gpu', 'a'), ('memory', 'b')])
+    expected = {
+        'device_id: 1/gpu (a)': 3,
+        'device_id: 1/memory (b)': 4,
+        'device_id: 2/gpu (a)': 6,
+        'device_id: 2/memory (b)': 7
+    }
     assert logs == expected
+
+
+@mock.patch.dict(os.environ, {})
+@mock.patch('torch.cuda.is_available', return_value=True)
+@mock.patch('torch.cuda.device_count', return_value=2)
+def test_gpu_stats_monitor_get_gpu_ids_cuda_visible_devices_unset(device_count_mock, is_available_mock):
+    gpu_ids = GPUStatsMonitor._get_gpu_ids([1, 0])
+    expected = ['1', '0']
+    assert gpu_ids == expected
+
+
+@mock.patch.dict(os.environ, {'CUDA_VISIBLE_DEVICES': '3,2,4'})
+@mock.patch('torch.cuda.is_available', return_value=True)
+@mock.patch('torch.cuda.device_count', return_value=3)
+def test_gpu_stats_monitor_get_gpu_ids_cuda_visible_devices_integers(device_count_mock, is_available_mock):
+    gpu_ids = GPUStatsMonitor._get_gpu_ids([1, 2])
+    expected = ['2', '4']
+    assert gpu_ids == expected
+
+
+@mock.patch.dict(os.environ, {'CUDA_VISIBLE_DEVICES': 'GPU-01a23b4c,GPU-56d78e9f,GPU-02a46c8e'})
+@mock.patch('torch.cuda.is_available', return_value=True)
+@mock.patch('torch.cuda.device_count', return_value=3)
+def test_gpu_stats_monitor_get_gpu_ids_cuda_visible_devices_uuids(device_count_mock, is_available_mock):
+    gpu_ids = GPUStatsMonitor._get_gpu_ids([1, 2])
+    expected = ['GPU-56d78e9f', 'GPU-02a46c8e']
+    assert gpu_ids == expected
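
Finally, a minimal usage sketch; the environment value and `gpus` count are assumptions for illustration, not part of the commit:

# Minimal usage sketch (assumes a CUDA machine with the masked GPUs available).
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '3,2'  # mask to two physical GPUs

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import GPUStatsMonitor

# Logical devices 0 and 1 map to physical GPUs 3 and 2; with this fix the
# monitor queries nvidia-smi with --id=3,2 while logging under the logical
# keys, e.g. 'device_id: 0/utilization.gpu (%)'.
trainer = Trainer(gpus=2, callbacks=[GPUStatsMonitor()])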