# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from multiprocessing import Process
from unittest import mock
from unittest.mock import ANY, call, Mock, patch

import pytest
import torch

from lightning.fabric.plugins import ClusterEnvironment
from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.strategies import DDPSpawnStrategy
from pytorch_lightning.strategies.launchers.multiprocessing import _GlobalStateSnapshot, _MultiProcessingLauncher
from pytorch_lightning.trainer.states import TrainerFn
from tests_pytorch.helpers.runif import RunIf


@mock.patch("pytorch_lightning.strategies.launchers.multiprocessing.mp.get_all_start_methods", return_value=[])
def test_multiprocessing_launcher_forking_on_unsupported_platform(_):
    with pytest.raises(ValueError, match="The start method 'fork' is not available on this platform"):
        _MultiProcessingLauncher(strategy=Mock(), start_method="fork")


@pytest.mark.parametrize("start_method", ["spawn", pytest.param("fork", marks=RunIf(standalone=True))])
@mock.patch("pytorch_lightning.strategies.launchers.multiprocessing.mp")
def test_multiprocessing_launcher_start_method(mp_mock, start_method):
    mp_mock.get_all_start_methods.return_value = [start_method]
    launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
    launcher.launch(function=Mock())
    mp_mock.get_context.assert_called_with(start_method)
    mp_mock.start_processes.assert_called_with(
        ANY,
        args=ANY,
        nprocs=ANY,
        start_method=start_method,
        join=False,
    )


@pytest.mark.parametrize("start_method", ["spawn", pytest.param("fork", marks=RunIf(standalone=True))])
@mock.patch("pytorch_lightning.strategies.launchers.multiprocessing.mp")
def test_multiprocessing_launcher_restore_globals(mp_mock, start_method):
    """Test that we pass the global state snapshot to the worker function only if we are starting with 'spawn'."""
    mp_mock.get_all_start_methods.return_value = [start_method]
    launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
    launcher.launch(function=Mock())
    function_args = mp_mock.start_processes.call_args[1]["args"]
    if start_method == "spawn":
        assert len(function_args) == 6
        assert isinstance(function_args[5], _GlobalStateSnapshot)
    else:
        assert len(function_args) == 5


def test_global_state_snapshot():
    """Test the capture() and restore() methods for the global state snapshot."""
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(123)

    # capture the state of globals
    snapshot = _GlobalStateSnapshot.capture()

    # simulate there is a process boundary and flags get reset here
    torch.use_deterministic_algorithms(False)
    torch.backends.cudnn.benchmark = True
    torch.manual_seed(321)

    # restore the state of globals
    snapshot.restore()
    assert torch.are_deterministic_algorithms_enabled()
    assert not torch.backends.cudnn.benchmark
    assert torch.initial_seed() == 123


@pytest.mark.parametrize("trainer_fn", [TrainerFn.FITTING, "other"])
@pytest.mark.parametrize("fake_node_rank", [0, 1])
@pytest.mark.parametrize("fake_local_rank", [0, 1])
def test_collect_rank_zero_results(trainer_fn, fake_node_rank, fake_local_rank, tmpdir):
    """Tests that the spawn strategy transfers the new weights to the main process and deletes the temporary
    file."""
    model = Mock(wraps=BoringModel(), spec=BoringModel)
    fake_global_rank = 2 * fake_node_rank + fake_local_rank

    cluster_environment = Mock(spec=ClusterEnvironment)
    cluster_environment.world_size.return_value = 4
    cluster_environment.node_rank.return_value = fake_node_rank
    cluster_environment.local_rank.return_value = fake_local_rank
    cluster_environment.global_rank.return_value = fake_global_rank

    strategy = DDPSpawnStrategy(cluster_environment=cluster_environment)
    strategy._local_rank = fake_local_rank

    launcher = _MultiProcessingLauncher(strategy=strategy)
    trainer = Trainer(accelerator="cpu", default_root_dir=tmpdir, strategy=strategy)

    assert strategy.node_rank == fake_node_rank
    assert strategy.local_rank == fake_local_rank
    assert strategy.global_rank == fake_global_rank

    trainer.strategy.connect(model)
    trainer.state.fn = trainer_fn  # pretend we are in a particular trainer state

    spawn_output = launcher._collect_rank_zero_results(trainer, {})

    model.state_dict.assert_called_once()
    is_fitting = trainer_fn == TrainerFn.FITTING
    if strategy.local_rank == 0:
        # on local rank 0 (each node), we expect a temp checkpoint (when fitting)
        assert not is_fitting or spawn_output.weights_path.endswith(".temp.ckpt")
        assert not is_fitting or os.path.isfile(spawn_output.weights_path)
        assert is_fitting or spawn_output.weights_path is None
    else:
        # all other ranks don't have outputs (rank 0 needs to handle the output)
        assert spawn_output is None


@pytest.mark.parametrize("trainer_fn", [TrainerFn.FITTING, "other"])
def test_transfer_weights(tmpdir, trainer_fn):
    """Tests that the multiprocessing launcher transfers the new weights to the main process and deletes the
    temporary file."""
    model = Mock(wraps=BoringModel(), spec=BoringModel)
    strategy = DDPSpawnStrategy()
    trainer = Trainer(accelerator="cpu", default_root_dir=tmpdir, strategy=strategy)
    trainer.strategy.connect(model)
    trainer.state.fn = trainer_fn  # pretend we are in a particular trainer state

    spawn_output = strategy._launcher._collect_rank_zero_results(trainer, {})

    model.state_dict.assert_called_once()
    if trainer_fn == TrainerFn.FITTING:
        assert spawn_output.weights_path.endswith(".temp.ckpt")
        assert os.path.isfile(spawn_output.weights_path)
    else:
        assert spawn_output.weights_path is None

    # <-- here would normally be the multiprocessing boundary
    strategy._launcher._recover_results_in_main_process(spawn_output, trainer)
    assert model.load_state_dict.call_count == int(spawn_output.weights_path is not None)


def test_non_strict_loading(tmpdir):
    """Tests that the multiprocessing launcher loads the weights back into the main process with strict loading
    disabled, so that missing keys do not raise an error."""
    model = Mock(wraps=BoringModel(), spec=BoringModel)
    strategy = DDPSpawnStrategy()
    trainer = Trainer(accelerator="cpu", default_root_dir=tmpdir, strategy=strategy)
    trainer.strategy.connect(model)
    trainer.state.fn = TrainerFn.FITTING  # state dict loading only relevant for the FITTING case

    spawn_output = strategy._launcher._collect_rank_zero_results(trainer, {})
    # <-- here would normally be the multiprocessing boundary
    strategy._launcher._recover_results_in_main_process(spawn_output, trainer)
    model.load_state_dict.assert_called_once_with(ANY, strict=False)


def test_kill():
    """Test that the launcher forwards the given signal to all of its launched processes."""
    launcher = _MultiProcessingLauncher(Mock())
    proc0 = Mock(autospec=Process)
    proc1 = Mock(autospec=Process)
    launcher.procs = [proc0, proc1]

    with patch("os.kill") as kill_patch:
        launcher.kill(15)
    assert kill_patch.mock_calls == [call(proc0.pid, 15), call(proc1.pid, 15)]