Better error message when trying to re-initialize CUDA in forked subprocess (#14709)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
parent
9fc4ff3278
commit
ea5e817973
|
@ -24,7 +24,7 @@ from typing_extensions import Literal
|
|||
from lightning_lite.strategies.launchers.base import _Launcher
|
||||
from lightning_lite.strategies.strategy import Strategy
|
||||
from lightning_lite.utilities.apply_func import move_data_to_device
|
||||
from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_11
|
||||
from lightning_lite.utilities.imports import _IS_INTERACTIVE, _TORCH_GREATER_EQUAL_1_11
|
||||
from lightning_lite.utilities.seed import _collect_rng_states, _set_rng_states
|
||||
|
||||
|
||||
|
@ -82,6 +82,9 @@ class _MultiProcessingLauncher(_Launcher):
|
|||
*args: Optional positional arguments to be passed to the given function.
|
||||
**kwargs: Optional keyword arguments to be passed to the given function.
|
||||
"""
|
||||
if self._start_method in ("fork", "forkserver"):
|
||||
_check_bad_cuda_fork()
|
||||
|
||||
# The default cluster environment in Lightning chooses a random free port number
|
||||
# This needs to be done in the main process here before starting processes to ensure each rank will connect
|
||||
# through the same port
|
||||
|
@ -166,3 +169,22 @@ class _GlobalStateSnapshot:
|
|||
torch.use_deterministic_algorithms(self.use_deterministic_algorithms)
|
||||
torch.backends.cudnn.benchmark = self.cudnn_benchmark
|
||||
_set_rng_states(self.rng_states)
|
||||
|
||||
|
||||
def _check_bad_cuda_fork() -> None:
|
||||
"""Checks whether it is safe to fork and initialize CUDA in the new processes, and raises an exception if not.
|
||||
|
||||
The error message replaces PyTorch's 'Cannot re-initialize CUDA in forked subprocess' with helpful advice for
|
||||
Lightning users.
|
||||
"""
|
||||
if not torch.cuda.is_initialized():
|
||||
return
|
||||
|
||||
message = (
|
||||
"Lightning can't create new processes if CUDA is already initialized. Did you manually call"
|
||||
" `torch.cuda.*` functions, have moved the model to the device, or allocated memory on the GPU any"
|
||||
" other way? Please remove any such calls, or change the selected strategy."
|
||||
)
|
||||
if _IS_INTERACTIVE:
|
||||
message += " You will have to restart the Python kernel."
|
||||
raise RuntimeError(message)
|
||||
|
|
|
@ -52,6 +52,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
|
||||
|
||||
- Added a more descriptive error message when attempting to fork processes with pre-initialized CUDA context ([#14709](https://github.com/Lightning-AI/lightning/issues/14709))
|
||||
|
||||
|
||||
|
||||
### Changed
|
||||
|
||||
- The `Trainer.{fit,validate,test,predict,tune}` methods now raise a useful error message if the input is not a `LightningModule` ([#13892](https://github.com/Lightning-AI/lightning/pull/13892))
|
||||
|
|
|
@ -27,6 +27,7 @@ from typing_extensions import Literal
|
|||
|
||||
import pytorch_lightning as pl
|
||||
from lightning_lite.strategies.launchers.base import _Launcher
|
||||
from lightning_lite.strategies.launchers.multiprocessing import _check_bad_cuda_fork
|
||||
from lightning_lite.utilities import move_data_to_device
|
||||
from lightning_lite.utilities.seed import _collect_rng_states, _set_rng_states
|
||||
from lightning_lite.utilities.types import _PATH
|
||||
|
@ -90,6 +91,9 @@ class _MultiProcessingLauncher(_Launcher):
|
|||
**kwargs: Optional keyword arguments to be passed to the given function.
|
||||
"""
|
||||
self._check_torchdistx_support()
|
||||
if self._start_method in ("fork", "forkserver"):
|
||||
_check_bad_cuda_fork()
|
||||
|
||||
# The default cluster environment in Lightning chooses a random free port number
|
||||
# This needs to be done in the main process here before starting processes to ensure each rank will connect
|
||||
# through the same port
|
||||
|
|
|
@ -84,3 +84,13 @@ def test_global_state_snapshot():
|
|||
assert torch.are_deterministic_algorithms_enabled()
|
||||
assert not torch.backends.cudnn.benchmark
|
||||
assert torch.initial_seed() == 123
|
||||
|
||||
|
||||
@pytest.mark.parametrize("start_method", ["fork", "forkserver"])
@mock.patch("torch.cuda.is_initialized", return_value=True)
@mock.patch("lightning_lite.strategies.launchers.multiprocessing.mp")
def test_multiprocessing_launcher_check_for_bad_cuda_fork(mp_mock, _, start_method):
    """Launching with a fork-based start method must error out when CUDA was already initialized."""
    # Make the requested start method appear available so the launcher accepts it.
    mp_mock.get_all_start_methods.return_value = [start_method]
    expected_error = "Lightning can't create new processes if CUDA is already initialized"
    launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
    with pytest.raises(RuntimeError, match=expected_error):
        launcher.launch(function=Mock())
|
||||
|
|
Loading…
Reference in New Issue