diff --git a/CHANGELOG.md b/CHANGELOG.md index 95625ac33a..842c41f179 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -411,6 +411,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Fixed the `GPUStatsMonitor` callbacks to use the correct GPU IDs if `CUDA_VISIBLE_DEVICES` set ([#8260](https://github.com/PyTorchLightning/pytorch-lightning/pull/8260)) + - Fixed `lr_scheduler` checkpointed state by calling `update_lr_schedulers` before saving checkpoints ([#7877](https://github.com/PyTorchLightning/pytorch-lightning/pull/7877)) diff --git a/pytorch_lightning/callbacks/gpu_stats_monitor.py b/pytorch_lightning/callbacks/gpu_stats_monitor.py index 794165fe60..56cb076927 100644 --- a/pytorch_lightning/callbacks/gpu_stats_monitor.py +++ b/pytorch_lightning/callbacks/gpu_stats_monitor.py @@ -23,12 +23,16 @@ import os import shutil import subprocess import time -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Optional, Tuple +import torch + +import pytorch_lightning as pl from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.utilities import DeviceType, rank_zero_only from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.parsing import AttributeDict +from pytorch_lightning.utilities.types import STEP_OUTPUT class GPUStatsMonitor(Callback): @@ -101,7 +105,13 @@ class GPUStatsMonitor(Callback): 'temperature': temperature }) - def on_train_start(self, trainer, pl_module) -> None: + # The logical device IDs for selected devices + self._device_ids: List[int] = [] # will be assigned later in setup() + + # The unmasked real GPU IDs + self._gpu_ids: List[str] = [] # will be assigned later in setup() + + def setup(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', stage: Optional[str] = None) -> None: if not trainer.logger: raise MisconfigurationException('Cannot use GPUStatsMonitor callback with Trainer that has no logger.') @@ -111,14 +121,20 @@ class GPUStatsMonitor(Callback): f' since gpus attribute in Trainer is set to {trainer.gpus}.' ) - self._gpu_ids = ','.join(map(str, trainer.data_parallel_device_ids)) + # The logical device IDs for selected devices + self._device_ids: List[int] = sorted(set(trainer.data_parallel_device_ids)) - def on_train_epoch_start(self, trainer, pl_module) -> None: + # The unmasked real GPU IDs + self._gpu_ids: List[int] = self._get_gpu_ids(self._device_ids) + + def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: self._snap_intra_step_time = None self._snap_inter_step_time = None @rank_zero_only - def on_train_batch_start(self, trainer, pl_module, batch: Any, batch_idx: int, dataloader_idx: int) -> None: + def on_train_batch_start( + self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', batch: Any, batch_idx: int, dataloader_idx: int + ) -> None: if self._log_stats.intra_step_time: self._snap_intra_step_time = time.time() @@ -127,7 +143,7 @@ class GPUStatsMonitor(Callback): gpu_stat_keys = self._get_gpu_stat_keys() gpu_stats = self._get_gpu_stats([k for k, _ in gpu_stat_keys]) - logs = self._parse_gpu_stats(self._gpu_ids, gpu_stats, gpu_stat_keys) + logs = self._parse_gpu_stats(self._device_ids, gpu_stats, gpu_stat_keys) if self._log_stats.inter_step_time and self._snap_inter_step_time: # First log at beginning of second step @@ -137,7 +153,13 @@ class GPUStatsMonitor(Callback): @rank_zero_only def on_train_batch_end( - self, trainer, pl_module, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int + self, + trainer: 'pl.Trainer', + pl_module: 'pl.LightningModule', + outputs: STEP_OUTPUT, + batch: Any, + batch_idx: int, + dataloader_idx: int, ) -> None: if self._log_stats.inter_step_time: self._snap_inter_step_time = time.time() @@ -147,19 +169,28 @@ class GPUStatsMonitor(Callback): gpu_stat_keys = self._get_gpu_stat_keys() + self._get_gpu_device_stat_keys() gpu_stats = self._get_gpu_stats([k for k, _ in gpu_stat_keys]) - logs = self._parse_gpu_stats(self._gpu_ids, gpu_stats, gpu_stat_keys) + logs = self._parse_gpu_stats(self._device_ids, gpu_stats, gpu_stat_keys) if self._log_stats.intra_step_time and self._snap_intra_step_time: logs['batch_time/intra_step (ms)'] = (time.time() - self._snap_intra_step_time) * 1000 trainer.logger.log_metrics(logs, step=trainer.global_step) + @staticmethod + def _get_gpu_ids(device_ids: List[int]) -> List[str]: + """Get the unmasked real GPU IDs""" + # All devices if `CUDA_VISIBLE_DEVICES` unset + default = ','.join(str(i) for i in range(torch.cuda.device_count())) + cuda_visible_devices: List[str] = os.getenv('CUDA_VISIBLE_DEVICES', default=default).split(',') + return [cuda_visible_devices[device_id].strip() for device_id in device_ids] + def _get_gpu_stats(self, queries: List[str]) -> List[List[float]]: """Run nvidia-smi to get the gpu stats""" gpu_query = ','.join(queries) format = 'csv,nounits,noheader' + gpu_ids = ','.join(self._gpu_ids) result = subprocess.run( - [shutil.which('nvidia-smi'), f'--query-gpu={gpu_query}', f'--format={format}', f'--id={self._gpu_ids}'], + [shutil.which('nvidia-smi'), f'--query-gpu={gpu_query}', f'--format={format}', f'--id={gpu_ids}'], encoding="utf-8", stdout=subprocess.PIPE, stderr=subprocess.PIPE, # for backward compatibility with python version 3.6 @@ -177,12 +208,16 @@ class GPUStatsMonitor(Callback): return stats @staticmethod - def _parse_gpu_stats(gpu_ids: str, stats: List[List[float]], keys: List[Tuple[str, str]]) -> Dict[str, float]: + def _parse_gpu_stats( + device_ids: List[int], + stats: List[List[float]], + keys: List[Tuple[str, str]], + ) -> Dict[str, float]: """Parse the gpu stats into a loggable dict""" logs = {} - for i, gpu_id in enumerate(gpu_ids.split(',')): + for i, device_id in enumerate(device_ids): for j, (x, unit) in enumerate(keys): - logs[f'gpu_id: {gpu_id}/{x} ({unit})'] = stats[i][j] + logs[f'device_id: {device_id}/{x} ({unit})'] = stats[i][j] return logs def _get_gpu_stat_keys(self) -> List[Tuple[str, str]]: diff --git a/tests/callbacks/test_gpu_stats_monitor.py b/tests/callbacks/test_gpu_stats_monitor.py index e7b5ce1727..85a184e1e9 100644 --- a/tests/callbacks/test_gpu_stats_monitor.py +++ b/tests/callbacks/test_gpu_stats_monitor.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +from unittest import mock import numpy as np import pytest @@ -116,6 +117,38 @@ def test_gpu_stats_monitor_no_gpu_warning(tmpdir): def test_gpu_stats_monitor_parse_gpu_stats(): - logs = GPUStatsMonitor._parse_gpu_stats('1,2', [[3, 4, 5], [6, 7]], [('gpu', 'a'), ('memory', 'b')]) - expected = {'gpu_id: 1/gpu (a)': 3, 'gpu_id: 1/memory (b)': 4, 'gpu_id: 2/gpu (a)': 6, 'gpu_id: 2/memory (b)': 7} + logs = GPUStatsMonitor._parse_gpu_stats([1, 2], [[3, 4, 5], [6, 7]], [('gpu', 'a'), ('memory', 'b')]) + expected = { + 'device_id: 1/gpu (a)': 3, + 'device_id: 1/memory (b)': 4, + 'device_id: 2/gpu (a)': 6, + 'device_id: 2/memory (b)': 7 + } assert logs == expected + + +@mock.patch.dict(os.environ, {}) +@mock.patch('torch.cuda.is_available', return_value=True) +@mock.patch('torch.cuda.device_count', return_value=2) +def test_gpu_stats_monitor_get_gpu_ids_cuda_visible_devices_unset(device_count_mock, is_available_mock): + gpu_ids = GPUStatsMonitor._get_gpu_ids([1, 0]) + expected = ['1', '0'] + assert gpu_ids == expected + + +@mock.patch.dict(os.environ, {'CUDA_VISIBLE_DEVICES': '3,2,4'}) +@mock.patch('torch.cuda.is_available', return_value=True) +@mock.patch('torch.cuda.device_count', return_value=3) +def test_gpu_stats_monitor_get_gpu_ids_cuda_visible_devices_integers(device_count_mock, is_available_mock): + gpu_ids = GPUStatsMonitor._get_gpu_ids([1, 2]) + expected = ['2', '4'] + assert gpu_ids == expected + + +@mock.patch.dict(os.environ, {'CUDA_VISIBLE_DEVICES': 'GPU-01a23b4c,GPU-56d78e9f,GPU-02a46c8e'}) +@mock.patch('torch.cuda.is_available', return_value=True) +@mock.patch('torch.cuda.device_count', return_value=3) +def test_gpu_stats_monitor_get_gpu_ids_cuda_visible_devices_uuids(device_count_mock, is_available_mock): + gpu_ids = GPUStatsMonitor._get_gpu_ids([1, 2]) + expected = ['GPU-56d78e9f', 'GPU-02a46c8e'] + assert gpu_ids == expected