Better error message when trying to re-initialize CUDA in forked subprocess (#14709)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
Adrian Wälchli 2022-09-28 11:07:33 +02:00 committed by GitHub
parent 9fc4ff3278
commit ea5e817973
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 41 additions and 1 deletions

View File

@ -24,7 +24,7 @@ from typing_extensions import Literal
from lightning_lite.strategies.launchers.base import _Launcher
from lightning_lite.strategies.strategy import Strategy
from lightning_lite.utilities.apply_func import move_data_to_device
from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_11
from lightning_lite.utilities.imports import _IS_INTERACTIVE, _TORCH_GREATER_EQUAL_1_11
from lightning_lite.utilities.seed import _collect_rng_states, _set_rng_states
@ -82,6 +82,9 @@ class _MultiProcessingLauncher(_Launcher):
*args: Optional positional arguments to be passed to the given function.
**kwargs: Optional keyword arguments to be passed to the given function.
"""
if self._start_method in ("fork", "forkserver"):
_check_bad_cuda_fork()
# The default cluster environment in Lightning chooses a random free port number
# This needs to be done in the main process here before starting processes to ensure each rank will connect
# through the same port
@ -166,3 +169,22 @@ class _GlobalStateSnapshot:
torch.use_deterministic_algorithms(self.use_deterministic_algorithms)
torch.backends.cudnn.benchmark = self.cudnn_benchmark
_set_rng_states(self.rng_states)
def _check_bad_cuda_fork() -> None:
"""Checks whether it is safe to fork and initialize CUDA in the new processes, and raises an exception if not.
The error message replaces PyTorch's 'Cannot re-initialize CUDA in forked subprocess' with helpful advice for
Lightning users.
"""
if not torch.cuda.is_initialized():
return
message = (
"Lightning can't create new processes if CUDA is already initialized. Did you manually call"
" `torch.cuda.*` functions, have moved the model to the device, or allocated memory on the GPU any"
" other way? Please remove any such calls, or change the selected strategy."
)
if _IS_INTERACTIVE:
message += " You will have to restart the Python kernel."
raise RuntimeError(message)

View File

@ -52,6 +52,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a more descriptive error message when attempting to fork processes with pre-initialized CUDA context ([#14709](https://github.com/Lightning-AI/lightning/issues/14709))
### Changed
- The `Trainer.{fit,validate,test,predict,tune}` methods now raise a useful error message if the input is not a `LightningModule` ([#13892](https://github.com/Lightning-AI/lightning/pull/13892))

View File

@ -27,6 +27,7 @@ from typing_extensions import Literal
import pytorch_lightning as pl
from lightning_lite.strategies.launchers.base import _Launcher
from lightning_lite.strategies.launchers.multiprocessing import _check_bad_cuda_fork
from lightning_lite.utilities import move_data_to_device
from lightning_lite.utilities.seed import _collect_rng_states, _set_rng_states
from lightning_lite.utilities.types import _PATH
@ -90,6 +91,9 @@ class _MultiProcessingLauncher(_Launcher):
**kwargs: Optional keyword arguments to be passed to the given function.
"""
self._check_torchdistx_support()
if self._start_method in ("fork", "forkserver"):
_check_bad_cuda_fork()
# The default cluster environment in Lightning chooses a random free port number
# This needs to be done in the main process here before starting processes to ensure each rank will connect
# through the same port

View File

@ -84,3 +84,13 @@ def test_global_state_snapshot():
assert torch.are_deterministic_algorithms_enabled()
assert not torch.backends.cudnn.benchmark
assert torch.initial_seed() == 123
# Verify that launching with a fork-based start method raises the improved error message
# when CUDA has already been initialized in the main process.
@pytest.mark.parametrize("start_method", ["fork", "forkserver"])
# Simulate a pre-initialized CUDA context without needing a GPU.
@mock.patch("torch.cuda.is_initialized", return_value=True)
# Patch the multiprocessing module so no real processes get spawned.
@mock.patch("lightning_lite.strategies.launchers.multiprocessing.mp")
def test_multiprocessing_launcher_check_for_bad_cuda_fork(mp_mock, _, start_method):
    # Make the patched mp module report the parametrized start method as available,
    # so the launcher accepts it in its constructor.
    mp_mock.get_all_start_methods.return_value = [start_method]
    launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
    # launch() must raise before creating any worker process.
    with pytest.raises(RuntimeError, match="Lightning can't create new processes if CUDA is already initialized"):
        launcher.launch(function=Mock())