lightning/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from multiprocessing import Process
from unittest import mock
from unittest.mock import ANY, call, Mock, patch

import pytest
import torch

from lightning.fabric.plugins import ClusterEnvironment
from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.strategies import DDPSpawnStrategy
from pytorch_lightning.strategies.launchers.multiprocessing import _GlobalStateSnapshot, _MultiProcessingLauncher
from pytorch_lightning.trainer.states import TrainerFn
from tests_pytorch.helpers.runif import RunIf


@mock.patch("pytorch_lightning.strategies.launchers.multiprocessing.mp.get_all_start_methods", return_value=[])
def test_multiprocessing_launcher_forking_on_unsupported_platform(_):
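    """Test that requesting the 'fork' start method raises an error on platforms that do not support it."""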
    with pytest.raises(ValueError, match="The start method 'fork' is not available on this platform"):
        _MultiProcessingLauncher(strategy=Mock(), start_method="fork")


@pytest.mark.parametrize("start_method", ["spawn", pytest.param("fork", marks=RunIf(standalone=True))])
@mock.patch("pytorch_lightning.strategies.launchers.multiprocessing.mp")
def test_multiprocessing_launcher_start_method(mp_mock, start_method):
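    """Test that the launcher forwards the requested start method to torch.multiprocessing."""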
    mp_mock.get_all_start_methods.return_value = [start_method]
    launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
    launcher.launch(function=Mock())
    mp_mock.get_context.assert_called_with(start_method)
    mp_mock.start_processes.assert_called_with(
        ANY,
        args=ANY,
        nprocs=ANY,
        start_method=start_method,
        join=False,
    )


@pytest.mark.parametrize("start_method", ["spawn", pytest.param("fork", marks=RunIf(standalone=True))])
@mock.patch("pytorch_lightning.strategies.launchers.multiprocessing.mp")
def test_multiprocessing_launcher_restore_globals(mp_mock, start_method):
"""Test that we pass the global state snapshot to the worker function only if we are starting with 'spawn'."""
mp_mock.get_all_start_methods.return_value = [start_method]
launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
launcher.launch(function=Mock())
function_args = mp_mock.start_processes.call_args[1]["args"]
if start_method == "spawn":
assert len(function_args) == 6
assert isinstance(function_args[5], _GlobalStateSnapshot)
else:
assert len(function_args) == 5
def test_global_state_snapshot():
"""Test the capture() and restore() methods for the global state snapshot."""
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.benchmark = False
torch.manual_seed(123)
# capture the state of globals
snapshot = _GlobalStateSnapshot.capture()
# simulate there is a process boundary and flags get reset here
torch.use_deterministic_algorithms(False)
torch.backends.cudnn.benchmark = True
torch.manual_seed(321)
# restore the state of globals
snapshot.restore()
assert torch.are_deterministic_algorithms_enabled()
assert not torch.backends.cudnn.benchmark
assert torch.initial_seed() == 123
@pytest.mark.parametrize("trainer_fn", [TrainerFn.FITTING, "other"])
@pytest.mark.parametrize("fake_node_rank", [0, 1])
@pytest.mark.parametrize("fake_local_rank", [0, 1])
def test_collect_rank_zero_results(trainer_fn, fake_node_rank, fake_local_rank, tmpdir):
"""Tests that the spawn strategy transfers the new weights to the main process and deletes the temporary
file."""
model = Mock(wraps=BoringModel(), spec=BoringModel)
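    # a fake world of 4 processes: 2 nodes with 2 devices each, hence global_rank = 2 * node_rank + local_rank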
    fake_global_rank = 2 * fake_node_rank + fake_local_rank

    cluster_environment = Mock(spec=ClusterEnvironment)
    cluster_environment.world_size.return_value = 4
    cluster_environment.node_rank.return_value = fake_node_rank
    cluster_environment.local_rank.return_value = fake_local_rank
    cluster_environment.global_rank.return_value = fake_global_rank

    strategy = DDPSpawnStrategy(cluster_environment=cluster_environment)
    strategy._local_rank = fake_local_rank
    launcher = _MultiProcessingLauncher(strategy=strategy)
    trainer = Trainer(accelerator="cpu", default_root_dir=tmpdir, strategy=strategy)

    assert strategy.node_rank == fake_node_rank
    assert strategy.local_rank == fake_local_rank
    assert strategy.global_rank == fake_global_rank

    trainer.strategy.connect(model)
    trainer.state.fn = trainer_fn  # pretend we are in a particular trainer state

    spawn_output = launcher._collect_rank_zero_results(trainer, {})

    model.state_dict.assert_called_once()
    is_fitting = trainer_fn == TrainerFn.FITTING
    if strategy.local_rank == 0:
        # on local rank 0 (each node), we expect a temp checkpoint (when fitting)
        assert not is_fitting or spawn_output.weights_path.endswith(".temp.ckpt")
        assert not is_fitting or os.path.isfile(spawn_output.weights_path)
        assert is_fitting or spawn_output.weights_path is None
    else:
        # all other ranks don't have outputs (rank 0 needs to handle the output)
        assert spawn_output is None


@pytest.mark.parametrize("trainer_fn", [TrainerFn.FITTING, "other"])
def test_transfer_weights(tmpdir, trainer_fn):
"""Tests that the multiprocessing launcher transfers the new weights to the main process and deletes the
temporary file."""
model = Mock(wraps=BoringModel(), spec=BoringModel)
strategy = DDPSpawnStrategy()
trainer = Trainer(accelerator="cpu", default_root_dir=tmpdir, strategy=strategy)
trainer.strategy.connect(model)
trainer.state.fn = trainer_fn # pretend we are in a particular trainer state
spawn_output = strategy._launcher._collect_rank_zero_results(trainer, {})
model.state_dict.assert_called_once()
if trainer_fn == TrainerFn.FITTING:
assert spawn_output.weights_path.endswith(".temp.ckpt")
assert os.path.isfile(spawn_output.weights_path)
else:
assert spawn_output.weights_path is None
# <-- here would normally be the multiprocessing boundary
strategy._launcher._recover_results_in_main_process(spawn_output, trainer)
assert model.load_state_dict.call_count == int(spawn_output.weights_path is not None)
def test_non_strict_loading(tmpdir):
"""Tests that the multiprocessing launcher loads the weights back into the main process but with strict loading
disabled, not erroring for missing keys."""
model = Mock(wraps=BoringModel(), spec=BoringModel)
strategy = DDPSpawnStrategy()
trainer = Trainer(accelerator="cpu", default_root_dir=tmpdir, strategy=strategy)
trainer.strategy.connect(model)
trainer.state.fn = TrainerFn.FITTING # state dict loading only relevant for the FITTING case
spawn_output = strategy._launcher._collect_rank_zero_results(trainer, {})
# <-- here would normally be the multiprocessing boundary
strategy._launcher._recover_results_in_main_process(spawn_output, trainer)
model.load_state_dict.assert_called_once_with(ANY, strict=False)
def test_kill():
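    """Test that kill() forwards the given signal to every launched process."""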
    launcher = _MultiProcessingLauncher(Mock())
    proc0 = Mock(autospec=Process)
    proc1 = Mock(autospec=Process)
    launcher.procs = [proc0, proc1]
    with patch("os.kill") as kill_patch:
        launcher.kill(15)  # 15 is SIGTERM
    assert kill_patch.mock_calls == [call(proc0.pid, 15), call(proc1.pid, 15)]