From 7767fd36b68b956ac5f81c713b9384e253f983aa Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 8 Nov 2022 13:07:58 +0100
Subject: [PATCH] Fix result transfer in multiprocessing launcher on
 multi-node (#15567)

* Fix result transfer in multiprocessing launcher on multi-node
* add simple test
* add comment
* update test
* changelog
* use tempfile
* fix
* assert None
* unused import
* add comment
---
 src/pytorch_lightning/CHANGELOG.md            |  2 +
 .../strategies/launchers/multiprocessing.py   |  6 ++-
 .../launchers/test_multiprocessing.py         | 48 +++++++++++++++++++
 .../strategies/test_ddp_spawn_strategy.py     | 10 ++--
 4 files changed, 57 insertions(+), 9 deletions(-)

diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md
index 7189b7b91c..3ebfa6c2a5 100644
--- a/src/pytorch_lightning/CHANGELOG.md
+++ b/src/pytorch_lightning/CHANGELOG.md
@@ -61,6 +61,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed an issue with `WandbLogger(log_model=True|'all)` raising an error and not being able to serialize tensors in the metadata ([#15544](https://github.com/Lightning-AI/lightning/pull/15544))
 
+- Fixed model state transfer in multiprocessing launcher when running multi-node ([#15567](https://github.com/Lightning-AI/lightning/pull/15567))
+
 
 ## [1.8.0] - 2022-11-01
 
diff --git a/src/pytorch_lightning/strategies/launchers/multiprocessing.py b/src/pytorch_lightning/strategies/launchers/multiprocessing.py
index fc9723b365..1f225d749d 100644
--- a/src/pytorch_lightning/strategies/launchers/multiprocessing.py
+++ b/src/pytorch_lightning/strategies/launchers/multiprocessing.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+import tempfile
 from collections import UserList
 from dataclasses import dataclass
 from multiprocessing.queues import SimpleQueue
@@ -172,13 +173,14 @@ class _MultiProcessingLauncher(_Launcher):
         # requires to compute the state_dict on all processes in case Metrics are present
         state_dict = trainer.lightning_module.state_dict()
 
-        if self._strategy.global_rank != 0:
+        if self._strategy.local_rank != 0:
             return None
 
         # save the last weights
         weights_path = None
         if trainer.state.fn == TrainerFn.FITTING:
-            weights_path = os.path.join(trainer.default_root_dir, ".temp.ckpt")
+            # use tempdir here to avoid race conditions because the filesystem may be shared between nodes
+            weights_path = os.path.join(tempfile.mkdtemp(), ".temp.ckpt")
             self._strategy.checkpoint_io.save_checkpoint(state_dict, weights_path)
 
         # adds the `callback_metrics` to the queue
diff --git a/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py b/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py
index ef1a5ccce1..142f6c53d2 100644
--- a/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py
+++ b/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py
@@ -11,13 +11,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 from unittest import mock
 from unittest.mock import ANY, Mock
 
 import pytest
 import torch
 
+from lightning_lite.plugins import ClusterEnvironment
+from pytorch_lightning import Trainer
+from pytorch_lightning.demos.boring_classes import BoringModel
+from pytorch_lightning.strategies import DDPSpawnStrategy
 from pytorch_lightning.strategies.launchers.multiprocessing import _GlobalStateSnapshot, _MultiProcessingLauncher
+from pytorch_lightning.trainer.states import TrainerFn
 from tests_pytorch.helpers.runif import RunIf
 
 
@@ -76,3 +82,45 @@ def test_global_state_snapshot():
     assert torch.are_deterministic_algorithms_enabled()
     assert not torch.backends.cudnn.benchmark
     assert torch.initial_seed() == 123
+
+
+@pytest.mark.parametrize("trainer_fn", [TrainerFn.FITTING, "other"])
+@pytest.mark.parametrize("fake_node_rank", [0, 1])
+@pytest.mark.parametrize("fake_local_rank", [0, 1])
+def test_collect_rank_zero_results(trainer_fn, fake_node_rank, fake_local_rank, tmpdir):
+    """Tests that the spawn strategy transfers the new weights to the main process and deletes the temporary
+    file."""
+    model = Mock(wraps=BoringModel(), spec=BoringModel)
+    fake_global_rank = 2 * fake_node_rank + fake_local_rank
+
+    cluster_environment = Mock(spec=ClusterEnvironment)
+    cluster_environment.world_size.return_value = 4
+    cluster_environment.node_rank.return_value = fake_node_rank
+    cluster_environment.local_rank.return_value = fake_local_rank
+    cluster_environment.global_rank.return_value = fake_global_rank
+
+    strategy = DDPSpawnStrategy(cluster_environment=cluster_environment)
+    strategy._local_rank = fake_local_rank
+
+    launcher = _MultiProcessingLauncher(strategy=strategy)
+    trainer = Trainer(default_root_dir=tmpdir, strategy=strategy)
+
+    assert strategy.node_rank == fake_node_rank
+    assert strategy.local_rank == fake_local_rank
+    assert strategy.global_rank == fake_global_rank
+
+    trainer.strategy.connect(model)
+    trainer.state.fn = trainer_fn  # pretend we are in a particular trainer state
+
+    spawn_output = launcher._collect_rank_zero_results(trainer, {})
+
+    model.state_dict.assert_called_once()
+    is_fitting = trainer_fn == TrainerFn.FITTING
+    if strategy.local_rank == 0:
+        # on local rank 0 (each node), we expect a temp checkpoint (when fitting)
+        assert not is_fitting or spawn_output.weights_path.endswith(".temp.ckpt")
+        assert not is_fitting or os.path.isfile(spawn_output.weights_path)
+        assert is_fitting or spawn_output.weights_path is None
+    else:
+        # all other ranks don't have outputs (rank 0 needs to handle the output)
+        assert spawn_output is None
diff --git a/tests/tests_pytorch/strategies/test_ddp_spawn_strategy.py b/tests/tests_pytorch/strategies/test_ddp_spawn_strategy.py
index 7c1d347970..22a30b927b 100644
--- a/tests/tests_pytorch/strategies/test_ddp_spawn_strategy.py
+++ b/tests/tests_pytorch/strategies/test_ddp_spawn_strategy.py
@@ -11,8 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 from datetime import timedelta
-from pathlib import Path
 from unittest import mock
 from unittest.mock import Mock
 
@@ -135,23 +135,19 @@ def test_ddp_spawn_transfer_weights(tmpdir, trainer_fn):
     trainer = Trainer(default_root_dir=tmpdir, strategy=strategy)
     trainer.strategy.connect(model)
     trainer.state.fn = trainer_fn  # pretend we are in a particular trainer state
-    temp_file = Path(tmpdir, ".temp.ckpt")
-    assert not temp_file.exists()
 
     spawn_output = strategy._launcher._collect_rank_zero_results(trainer, {})
 
     model.state_dict.assert_called_once()
     if trainer_fn == TrainerFn.FITTING:
-        assert spawn_output.weights_path == str(temp_file)
-        assert temp_file.exists()
+        assert spawn_output.weights_path.endswith(".temp.ckpt")
+        assert os.path.isfile(spawn_output.weights_path)
     else:
         assert spawn_output.weights_path is None
-        assert not temp_file.exists()
 
     # <-- here would normally be the multiprocessing boundary
     strategy._launcher._recover_results_in_main_process(spawn_output, trainer)
     assert model.load_state_dict.call_count == int(spawn_output.weights_path is not None)
-    assert not temp_file.exists()
 
 
 @mock.patch("torch.distributed.init_process_group")
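
The behavioral change in `_collect_rank_zero_results` is twofold: the early return now keys on the *local* rank, so each node's rank-0 process (not only global rank 0) produces a result, and the temporary weights file is written into a directory from `tempfile.mkdtemp()` rather than under `trainer.default_root_dir`, which may live on a filesystem shared between nodes and can therefore race across concurrent writers. A minimal, self-contained sketch of this save/restore hand-off follows; the helper names (`dump_weights_for_main_process`, `recover_weights_in_main_process`) are illustrative only and not part of the PR.

import os
import tempfile

import torch


def dump_weights_for_main_process(state_dict, local_rank):
    """Sketch of the save side: only each node's local rank 0 persists the weights."""
    if local_rank != 0:
        # Non-zero local ranks return nothing; their node's main process has no file to read.
        return None
    # mkdtemp() creates a unique, process-private directory, so concurrent writers on
    # different nodes can never collide on the same path of a shared filesystem.
    weights_path = os.path.join(tempfile.mkdtemp(), ".temp.ckpt")
    torch.save(state_dict, weights_path)
    return weights_path


def recover_weights_in_main_process(model, weights_path):
    """Sketch of the load side: restore the weights, then clean up the temporary file."""
    if weights_path is not None:
        model.load_state_dict(torch.load(weights_path))
        os.remove(weights_path)


if __name__ == "__main__":
    model = torch.nn.Linear(2, 2)
    path = dump_weights_for_main_process(model.state_dict(), local_rank=0)
    recover_weights_in_main_process(model, path)
    print("weights restored from", path)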