Request `torch.cuda` RNG states only if CUDA is available (#19234)

This commit is contained in:
awaelchli 2024-01-10 22:16:29 +01:00 committed by GitHub
parent 1a1b989457
commit 6bc27d54a0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 18 additions and 2 deletions

View File

@ -50,6 +50,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed issue where the `precision="transformer-engine"` argument would not replace layers by default ([#19082](https://github.com/Lightning-AI/lightning/pull/19082))
- Fixed an issue preventing Fabric from running on CPU when the system's CUDA driver is outdated or broken ([#19234](https://github.com/Lightning-AI/lightning/pull/19234))
## [2.1.3] - 2023-12-21
### Fixed

View File

@ -110,7 +110,7 @@ def _collect_rng_states(include_cuda: bool = True) -> Dict[str, Any]:
"python": python_get_rng_state(),
}
if include_cuda:
states["torch.cuda"] = torch.cuda.get_rng_state_all()
states["torch.cuda"] = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else []
return states

View File

@ -68,6 +68,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `Trainer` not expanding the `default_root_dir` if it has the `~` (home) prefix ([#19179](https://github.com/Lightning-AI/lightning/pull/19179))
- Fixed an issue preventing the Trainer from running on CPU when the system's CUDA driver is outdated or broken ([#19234](https://github.com/Lightning-AI/lightning/pull/19234))
- Fixed warning for Dataloader if `num_workers=1` and CPU count is 1 ([#19224](https://github.com/Lightning-AI/lightning/pull/19224))

View File

@ -1,5 +1,6 @@
import os
from unittest import mock
from unittest.mock import Mock
import lightning.fabric.utilities
import pytest
@ -85,3 +86,12 @@ def test_backward_compatibility_rng_states_dict():
assert "torch.cuda" in states
states.pop("torch.cuda")
_set_rng_states(states)
@mock.patch("lightning.fabric.utilities.seed.torch.cuda.is_available", Mock(return_value=False))
@mock.patch("lightning.fabric.utilities.seed.torch.cuda.get_rng_state_all")
def test_collect_rng_states_if_cuda_init_fails(get_rng_state_all_mock):
    """Check that no CUDA RNG state is requested while CUDA reports itself as unavailable.

    ``torch.cuda.is_available`` is patched to return ``False`` and ``torch.cuda.get_rng_state_all`` is patched to
    raise when called, so an unguarded call to it surfaces as a ``RuntimeError``.

    """
    # Make every call to the patched `get_rng_state_all` raise.
    driver_error = RuntimeError("The NVIDIA driver on your system is too old")
    get_rng_state_all_mock.side_effect = driver_error

    collected = _collect_rng_states()

    # The "torch.cuda" key is still present, but it holds an empty list.
    assert collected["torch.cuda"] == []

View File

@ -47,5 +47,5 @@ def test_isolate_rng_cuda(get_cuda_rng, set_cuda_rng):
set_cuda_rng.assert_not_called()
with isolate_rng(include_cuda=True):
get_cuda_rng.assert_called_once()
assert get_cuda_rng.call_count == int(torch.cuda.is_available())
set_cuda_rng.assert_called_once()