diff --git a/src/lightning_lite/strategies/launchers/multiprocessing.py b/src/lightning_lite/strategies/launchers/multiprocessing.py
index 20cf765f76..7fb161b2ed 100644
--- a/src/lightning_lite/strategies/launchers/multiprocessing.py
+++ b/src/lightning_lite/strategies/launchers/multiprocessing.py
@@ -24,7 +24,7 @@ from typing_extensions import Literal
 from lightning_lite.strategies.launchers.base import _Launcher
 from lightning_lite.strategies.strategy import Strategy
 from lightning_lite.utilities.apply_func import move_data_to_device
-from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_11
+from lightning_lite.utilities.imports import _IS_INTERACTIVE, _TORCH_GREATER_EQUAL_1_11
 from lightning_lite.utilities.seed import _collect_rng_states, _set_rng_states


@@ -82,6 +82,9 @@ class _MultiProcessingLauncher(_Launcher):
             *args: Optional positional arguments to be passed to the given function.
             **kwargs: Optional keyword arguments to be passed to the given function.
         """
+        if self._start_method in ("fork", "forkserver"):
+            _check_bad_cuda_fork()
+
         # The default cluster environment in Lightning chooses a random free port number
         # This needs to be done in the main process here before starting processes to ensure each rank will connect
         # through the same port
@@ -166,3 +169,22 @@ class _GlobalStateSnapshot:
         torch.use_deterministic_algorithms(self.use_deterministic_algorithms)
         torch.backends.cudnn.benchmark = self.cudnn_benchmark
         _set_rng_states(self.rng_states)
+
+
+def _check_bad_cuda_fork() -> None:
+    """Checks whether it is safe to fork and initialize CUDA in the new processes, and raises an exception if not.
+
+    The error message replaces PyTorch's 'Cannot re-initialize CUDA in forked subprocess' with helpful advice for
+    Lightning users.
+    """
+    if not torch.cuda.is_initialized():
+        return
+
+    message = (
+        "Lightning can't create new processes if CUDA is already initialized. Did you manually call"
+        " `torch.cuda.*` functions, move the model to the device, or allocate memory on the GPU any"
+        " other way? Please remove any such calls, or change the selected strategy."
+    )
+    if _IS_INTERACTIVE:
+        message += " You will have to restart the Python kernel."
+    raise RuntimeError(message)
diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md
index 408efc3147..89d4847352 100644
--- a/src/pytorch_lightning/CHANGELOG.md
+++ b/src/pytorch_lightning/CHANGELOG.md
@@ -52,6 +52,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).



+- Added a more descriptive error message when attempting to fork processes with a pre-initialized CUDA context ([#14709](https://github.com/Lightning-AI/lightning/issues/14709))
+
+
+
 ### Changed

 - The `Trainer.{fit,validate,test,predict,tune}` methods now raise a useful error message if the input is not a `LightningModule` ([#13892](https://github.com/Lightning-AI/lightning/pull/13892))
diff --git a/src/pytorch_lightning/strategies/launchers/multiprocessing.py b/src/pytorch_lightning/strategies/launchers/multiprocessing.py
index dc5916e3e2..de41b8ff2b 100644
--- a/src/pytorch_lightning/strategies/launchers/multiprocessing.py
+++ b/src/pytorch_lightning/strategies/launchers/multiprocessing.py
@@ -27,6 +27,7 @@ from typing_extensions import Literal

 import pytorch_lightning as pl
 from lightning_lite.strategies.launchers.base import _Launcher
+from lightning_lite.strategies.launchers.multiprocessing import _check_bad_cuda_fork
 from lightning_lite.utilities import move_data_to_device
 from lightning_lite.utilities.seed import _collect_rng_states, _set_rng_states
 from lightning_lite.utilities.types import _PATH
@@ -90,6 +91,9 @@ class _MultiProcessingLauncher(_Launcher):
             **kwargs: Optional keyword arguments to be passed to the given function.
         """
         self._check_torchdistx_support()
+        if self._start_method in ("fork", "forkserver"):
+            _check_bad_cuda_fork()
+
         # The default cluster environment in Lightning chooses a random free port number
         # This needs to be done in the main process here before starting processes to ensure each rank will connect
         # through the same port
diff --git a/tests/tests_lite/strategies/launchers/test_multiprocessing.py b/tests/tests_lite/strategies/launchers/test_multiprocessing.py
index fef19f0671..5bdff3cb62 100644
--- a/tests/tests_lite/strategies/launchers/test_multiprocessing.py
+++ b/tests/tests_lite/strategies/launchers/test_multiprocessing.py
@@ -84,3 +84,13 @@ def test_global_state_snapshot():
     assert torch.are_deterministic_algorithms_enabled()
     assert not torch.backends.cudnn.benchmark
     assert torch.initial_seed() == 123
+
+
+@pytest.mark.parametrize("start_method", ["fork", "forkserver"])
+@mock.patch("torch.cuda.is_initialized", return_value=True)
+@mock.patch("lightning_lite.strategies.launchers.multiprocessing.mp")
+def test_multiprocessing_launcher_check_for_bad_cuda_fork(mp_mock, _, start_method):
+    mp_mock.get_all_start_methods.return_value = [start_method]
+    launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
+    with pytest.raises(RuntimeError, match="Lightning can't create new processes if CUDA is already initialized"):
+        launcher.launch(function=Mock())
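
For reference, a minimal standalone sketch of the failure mode this guard intercepts, using plain `torch.multiprocessing` rather than Lightning (the `_worker` helper is illustrative only, not part of this PR, and the script assumes a CUDA-capable machine):

```python
import torch
import torch.multiprocessing as mp


def _worker(rank: int) -> None:
    # Touching CUDA in a fork-started child whose parent already holds a CUDA
    # context fails with "Cannot re-initialize CUDA in forked subprocess".
    torch.zeros(1, device="cuda")


if __name__ == "__main__":
    torch.zeros(1, device="cuda")  # parent implicitly initializes CUDA
    assert torch.cuda.is_initialized()
    # The children inherit the parent's CUDA state and crash on first CUDA use.
    # With this PR, the equivalent Lightning launch raises in the parent process
    # instead, with the clearer "Lightning can't create new processes ..." error.
    mp.start_processes(_worker, nprocs=2, start_method="fork")
```

Note that the check is gated on `self._start_method in ("fork", "forkserver")`: children created with `"spawn"` start a fresh interpreter and can initialize CUDA normally, so no guard is needed there.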