diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md
index 675d2ca0dc..0cff349fdc 100644
--- a/src/lightning/fabric/CHANGELOG.md
+++ b/src/lightning/fabric/CHANGELOG.md
@@ -50,6 +50,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed issue where the `precision="transformer-engine"` argument would not replace layers by default ([#19082](https://github.com/Lightning-AI/lightning/pull/19082))
 
+- Fixed an issue preventing Fabric from running on CPU when the system's CUDA driver is outdated or broken ([#19234](https://github.com/Lightning-AI/lightning/pull/19234))
+
+
 ## [2.1.3] - 2023-12-21
 
 ### Fixed
diff --git a/src/lightning/fabric/utilities/seed.py b/src/lightning/fabric/utilities/seed.py
index 2c3169fc8f..b274bce88f 100644
--- a/src/lightning/fabric/utilities/seed.py
+++ b/src/lightning/fabric/utilities/seed.py
@@ -110,7 +110,7 @@ def _collect_rng_states(include_cuda: bool = True) -> Dict[str, Any]:
         "python": python_get_rng_state(),
     }
     if include_cuda:
-        states["torch.cuda"] = torch.cuda.get_rng_state_all()
+        states["torch.cuda"] = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else []
     return states
 
 
diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index f74e11f30d..274e647bd4 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -68,6 +68,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed `Trainer` not expanding the `default_root_dir` if it has the `~` (home) prefix ([#19179](https://github.com/Lightning-AI/lightning/pull/19179))
 
+- Fixed an issue preventing the Trainer from running on CPU when the system's CUDA driver is outdated or broken ([#19234](https://github.com/Lightning-AI/lightning/pull/19234))
+
+
 - Fixed warning for Dataloader if `num_workers=1` and CPU count is 1 ([#19224](https://github.com/Lightning-AI/lightning/pull/19224))
 
 
diff --git a/tests/tests_fabric/utilities/test_seed.py b/tests/tests_fabric/utilities/test_seed.py
index f371afe96c..351f6a47b7 100644
--- a/tests/tests_fabric/utilities/test_seed.py
+++ b/tests/tests_fabric/utilities/test_seed.py
@@ -1,5 +1,6 @@
 import os
 from unittest import mock
+from unittest.mock import Mock
 
 import lightning.fabric.utilities
 import pytest
@@ -85,3 +86,12 @@ def test_backward_compatibility_rng_states_dict():
     assert "torch.cuda" in states
     states.pop("torch.cuda")
     _set_rng_states(states)
+
+
+@mock.patch("lightning.fabric.utilities.seed.torch.cuda.is_available", Mock(return_value=False))
+@mock.patch("lightning.fabric.utilities.seed.torch.cuda.get_rng_state_all")
+def test_collect_rng_states_if_cuda_init_fails(get_rng_state_all_mock):
+    """Test that the `torch.cuda` RNG states are only requested if CUDA is available."""
+    get_rng_state_all_mock.side_effect = RuntimeError("The NVIDIA driver on your system is too old")
+    states = _collect_rng_states()
+    assert states["torch.cuda"] == []
diff --git a/tests/tests_pytorch/utilities/test_seed.py b/tests/tests_pytorch/utilities/test_seed.py
index dc104afa5a..282009f6b9 100644
--- a/tests/tests_pytorch/utilities/test_seed.py
+++ b/tests/tests_pytorch/utilities/test_seed.py
@@ -47,5 +47,5 @@ def test_isolate_rng_cuda(get_cuda_rng, set_cuda_rng):
     set_cuda_rng.assert_not_called()
 
     with isolate_rng(include_cuda=True):
-        get_cuda_rng.assert_called_once()
+        assert get_cuda_rng.call_count == int(torch.cuda.is_available())
     set_cuda_rng.assert_called_once()
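
For readers skimming the patch, a minimal standalone sketch of the guarded pattern it introduces: the CUDA RNG states are only queried when `torch.cuda.is_available()` reports a usable device, so a machine with an outdated or broken NVIDIA driver falls back to an empty list instead of raising a `RuntimeError`. The function name `collect_rng_states_sketch` and the reduced set of collected states below are illustrative only, not the library's actual API.

```python
from typing import Any, Dict

import torch


def collect_rng_states_sketch(include_cuda: bool = True) -> Dict[str, Any]:
    """Illustrative reduction of the patched `_collect_rng_states` behavior."""
    states: Dict[str, Any] = {"torch": torch.get_rng_state()}
    if include_cuda:
        # On a machine whose CUDA driver is outdated or broken,
        # `torch.cuda.get_rng_state_all()` raises a RuntimeError, so the call is
        # guarded by `torch.cuda.is_available()` and falls back to an empty list.
        states["torch.cuda"] = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else []
    return states


if __name__ == "__main__":
    # On a CPU-only machine this prints [] instead of crashing.
    print(collect_rng_states_sketch()["torch.cuda"])
```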