From cc2ac02dd169c780d63d01e83a1d51149175e7a6 Mon Sep 17 00:00:00 2001
From: Danielle Pintz <38207072+daniellepintz@users.noreply.github.com>
Date: Fri, 10 Sep 2021 13:58:02 -0700
Subject: [PATCH] Move add_to_queue/get_from_queue to DDPSpawnPlugin (#9118)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: tchaton
Co-authored-by: ananthsub
Co-authored-by: Adrian Wälchli
Co-authored-by: Jirka Borovec
---
 CHANGELOG.md                                  |  3 ++
 pytorch_lightning/accelerators/accelerator.py |  5 +--
 pytorch_lightning/core/lightning.py           | 23 ++++++-----
 .../plugins/training_type/ddp.py              |  3 +-
 .../plugins/training_type/ddp_spawn.py        | 41 +++++++++++++++++--
 .../training_type/training_type_plugin.py     |  2 +-
 .../trainer/configuration_validator.py        | 19 +++++++++
 tests/deprecated_api/test_remove_1-7.py       | 27 ++++++++++++
 tests/plugins/test_ddp_spawn_plugin.py        | 34 +++++++++++++--
 9 files changed, 136 insertions(+), 21 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index d2988105ae..a902f874a9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -191,6 +191,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated `DataModule` properties: `train_transforms`, `val_transforms`, `test_transforms`, `size`, `dims` ([#8851](https://github.com/PyTorchLightning/pytorch-lightning/pull/8851))
 
+- Deprecated `add_to_queue`, `get_from_queue` from `LightningModule` in favor of corresponding methods in the `DDPSpawnPlugin` ([#9118](https://github.com/PyTorchLightning/pytorch-lightning/pull/9118))
+
+
 - Deprecated `LightningModule.get_progress_bar_dict` and `Trainer.progress_bar_dict` in favor of `pytorch_lightning.callbacks.progress.base.get_standard_metrics` and `ProgressBarBase.get_metrics` ([#8985](https://github.com/PyTorchLightning/pytorch-lightning/pull/8985))
diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index 93915ac946..c82514aaa7 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -21,9 +21,8 @@ from torch.optim import Optimizer
 from torch.utils.data import DataLoader
 
 import pytorch_lightning as pl
-from pytorch_lightning.plugins import DataParallelPlugin
 from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin
-from pytorch_lightning.plugins.training_type import TrainingTypePlugin
+from pytorch_lightning.plugins.training_type import DataParallelPlugin, TrainingTypePlugin
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE
 from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
@@ -116,7 +115,7 @@ class Accelerator:
 
     def post_dispatch(self, trainer: "pl.Trainer") -> None:
         """Hook to do something after the training/evaluation/prediction starts."""
-        self.training_type_plugin.post_dispatch()
+        self.training_type_plugin.post_dispatch(trainer)
         self.precision_plugin.post_dispatch()
 
     @property
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index b6a00e3168..6eec45465b 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -24,13 +24,13 @@ from contextlib import contextmanager
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
 
-import numpy as np
import torch from torch import ScriptModule, Tensor from torch.nn import Module from torch.optim.optimizer import Optimizer from torchmetrics import Metric +import pytorch_lightning as pl from pytorch_lightning.callbacks.progress import base as progress_base from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin, HyperparametersMixin @@ -1905,11 +1905,13 @@ class LightningModule( Args: queue: the instance of the queue to append the data. + + .. deprecated:: v1.5 + This method was deprecated in v1.5 in favor of `DDPSpawnPlugin.add_to_queue` + and will be removed in v1.7. """ - callback_metrics: dict = apply_to_collection( - self.trainer.callback_metrics, torch.Tensor, lambda x: x.cpu().numpy() - ) # send as numpy to avoid issues with memory sharing - queue.put(callback_metrics) + if self.trainer and isinstance(self.trainer.training_type_plugin, pl.plugins.training_type.DDPSpawnPlugin): + self.trainer.training_type_plugin.add_to_queue(self.trainer, queue) def get_from_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None: """Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency, @@ -1917,12 +1919,13 @@ class LightningModule( Args: queue: the instance of the queue from where to get the data. + + .. deprecated:: v1.5 + This method was deprecated in v1.5 in favor of `DDPSpawnPlugin.get_from_queue` + and will be removed in v1.7. """ - # NOTE: `add_to_queue` needs to be called before - callback_metrics: dict = queue.get() - self.trainer.callback_metrics.update( - apply_to_collection(callback_metrics, np.ndarray, lambda x: torch.tensor(x)) - ) + if self.trainer and isinstance(self.trainer.training_type_plugin, pl.plugins.training_type.DDPSpawnPlugin): + self.trainer.training_type_plugin.get_from_queue(self.trainer, queue) @contextmanager def _prevent_trainer_and_dataloaders_deepcopy(self) -> None: diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index d249e4dc76..647ff764f7 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -29,6 +29,7 @@ import torch import torch.distributed from torch.nn.parallel.distributed import DistributedDataParallel +import pytorch_lightning as pl from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.distributed import LightningDistributed from pytorch_lightning.overrides import LightningDistributedModule @@ -385,7 +386,7 @@ class DDPPlugin(ParallelPlugin): if trainer_fn == TrainerFn.FITTING: self.configure_ddp() - def post_dispatch(self) -> None: + def post_dispatch(self, trainer: "pl.Trainer") -> None: self.cluster_environment.teardown() def barrier(self, *args, **kwargs) -> None: diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index f4ae970982..5f49300134 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -17,6 +17,7 @@ import re from multiprocessing.queues import SimpleQueue from typing import Any, Dict, List, Optional, Union +import numpy as np import torch import torch.distributed import torch.multiprocessing as mp @@ -36,6 +37,7 @@ from pytorch_lightning.utilities import ( rank_zero_deprecation, rank_zero_warn, ) +from pytorch_lightning.utilities.apply_func import apply_to_collection from 
pytorch_lightning.utilities.cloud_io import atomic_save from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.distributed import ( @@ -45,6 +47,7 @@ from pytorch_lightning.utilities.distributed import ( ReduceOp, sync_ddp_if_available, ) +from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.seed import reset_seed from pytorch_lightning.utilities.types import STEP_OUTPUT @@ -215,14 +218,18 @@ class DDPSpawnPlugin(ParallelPlugin): # ensure that spawned processes go through teardown before joining trainer._call_teardown_hook() - def post_dispatch(self): + def post_dispatch(self, trainer: "pl.Trainer"): # restore main state with best weights best_path = self.mp_queue.get() last_path = self.mp_queue.get() self._results = self.mp_queue.get() # get the `callback_metrics` and set it to the trainer # only in case the user does not override it. - self.lightning_module.get_from_queue(self.mp_queue) + # TODO: Remove the if in v1.7 + if is_overridden("get_from_queue", self.lightning_module): + self.lightning_module.get_from_queue(self.mp_queue) + else: + self.get_from_queue(trainer, self.mp_queue) # recover the weights of the processes trained in the children self.__recover_child_process_weights(best_path, last_path) @@ -288,7 +295,12 @@ class DDPSpawnPlugin(ParallelPlugin): self.mp_queue.put(best_model_path) self.mp_queue.put(last_path) self.mp_queue.put(results) - self.lightning_module.add_to_queue(self.mp_queue) # adds the `callback_metrics` to the queue + # adds the `callback_metrics` to the queue + # TODO: Remove the if in v1.7 + if is_overridden("add_to_queue", self.lightning_module): + self.lightning_module.add_to_queue(self.mp_queue) + else: + self.add_to_queue(trainer, self.mp_queue) def __recover_child_process_weights(self, best_path, last_path): # transfer back the best path to the trainer @@ -362,6 +374,29 @@ class DDPSpawnPlugin(ParallelPlugin): if not self.lightning_module.automatic_optimization: self.model.require_backward_grad_sync = True + def add_to_queue(self, trainer: "pl.Trainer", queue: torch.multiprocessing.SimpleQueue) -> None: + """Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory + sharing, we cast the data to numpy. + + Args: + queue: the instance of the queue to append the data. + """ + callback_metrics: dict = apply_to_collection( + trainer.callback_metrics, torch.Tensor, lambda x: x.cpu().numpy() + ) # send as numpy to avoid issues with memory sharing + queue.put(callback_metrics) + + def get_from_queue(self, trainer: "pl.Trainer", queue: torch.multiprocessing.SimpleQueue) -> None: + """Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency, + we cast back the data to ``torch.Tensor``. + + Args: + queue: the instance of the queue from where to get the data. 
+        """
+        # NOTE: `add_to_queue` needs to be called before
+        callback_metrics: dict = queue.get()
+        trainer.callback_metrics.update(apply_to_collection(callback_metrics, np.ndarray, lambda x: torch.tensor(x)))
+
     @classmethod
     def register_plugins(cls, plugin_registry: Dict) -> None:
         plugin_registry.register(
diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py
index 29d9944a3f..2d18e5b51f 100644
--- a/pytorch_lightning/plugins/training_type/training_type_plugin.py
+++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -360,5 +360,5 @@ class TrainingTypePlugin(ABC):
     def dispatch(self, trainer: "pl.Trainer") -> None:
         """Hook to do something at trainer run_stage starts."""
 
-    def post_dispatch(self) -> None:
+    def post_dispatch(self, trainer: "pl.Trainer") -> None:
         """Hook to do something after the training/evaluation/prediction finishes."""
diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py
index 603ec88ef4..ee5be467b8 100644
--- a/pytorch_lightning/trainer/configuration_validator.py
+++ b/pytorch_lightning/trainer/configuration_validator.py
@@ -43,6 +43,7 @@ class ConfigValidator:
         elif self.trainer.state.fn == TrainerFn.PREDICTING:
             self.__verify_predict_loop_configuration(model)
         self.__verify_dp_batch_transfer_support(model)
+        self._check_add_get_queue(model)
         # TODO(@daniellepintz): Delete _check_progress_bar in v1.7
         self._check_progress_bar(model)
         # TODO: Delete _check_on_keyboard_interrupt in v1.7
@@ -219,6 +220,24 @@ class ConfigValidator:
                 "is incompatible with `truncated_bptt_steps > 0`."
             )
 
+    def _check_add_get_queue(self, model: "pl.LightningModule") -> None:
+        r"""
+        Checks if add_to_queue or get_from_queue is overridden and sends a deprecation warning.
+ + Args: + model: The lightning module + """ + if is_overridden("add_to_queue", model): + rank_zero_deprecation( + "The `LightningModule.add_to_queue` method was deprecated in v1.5 and will be removed in v1.7 in " + "favor of `DDPSpawnPlugin.add_to_queue`" + ) + if is_overridden("get_from_queue", model): + rank_zero_deprecation( + "The `LightningModule.get_from_queue` method was deprecated in v1.5 and will be removed in v1.7 in " + "favor of `DDPSpawnPlugin.get_from_queue`" + ) + def _check_on_keyboard_interrupt(self) -> None: """Checks if on_keyboard_interrupt is overriden and sends a deprecation warning.""" for callback in self.trainer.callbacks: diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index 822e65bd2e..5dffd8501f 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -15,12 +15,14 @@ from unittest import mock import pytest +import torch from pytorch_lightning import Callback, LightningDataModule, Trainer from pytorch_lightning.loggers import TestTubeLogger from tests.deprecated_api import _soft_unimport_module from tests.helpers import BoringModel from tests.helpers.datamodules import MNISTDataModule +from tests.helpers.runif import RunIf def test_v1_7_0_deprecated_lightning_module_summarize(tmpdir): @@ -192,3 +194,28 @@ def test_v1_7_0_on_interrupt(tmpdir): def test_v1_7_0_process_position_trainer_constructor(tmpdir): with pytest.deprecated_call(match=r"Setting `Trainer\(process_position=5\)` is deprecated in v1.5"): _ = Trainer(process_position=5) + + +class BoringCallbackDDPSpawnModel(BoringModel): + def __init__(self): + super().__init__() + + def add_to_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None: + queue.put("test_val") + return super().add_to_queue(queue) + + def get_from_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None: + self.test_val = queue.get() + return super().get_from_queue(queue) + + +@RunIf(skip_windows=True) +def test_v1_7_0_deprecate_add_get_queue(tmpdir): + model = BoringCallbackDDPSpawnModel() + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, num_processes=2, accelerator="ddp_cpu") + + with pytest.deprecated_call(match=r"`LightningModule.add_to_queue` method was deprecated in v1.5"): + trainer.fit(model) + + with pytest.deprecated_call(match=r"`LightningModule.get_from_queue` method was deprecated in v1.5"): + trainer.fit(model) diff --git a/tests/plugins/test_ddp_spawn_plugin.py b/tests/plugins/test_ddp_spawn_plugin.py index 2d987b0788..a89ddd3aaa 100644 --- a/tests/plugins/test_ddp_spawn_plugin.py +++ b/tests/plugins/test_ddp_spawn_plugin.py @@ -48,7 +48,7 @@ class BoringCallbackDDPSpawnModel(BoringModel): @RunIf(skip_windows=True) def test_ddp_cpu(): - """Tests if device is set correctely when training for DDPSpawnPlugin.""" + """Tests if device is set correctly when training for DDPSpawnPlugin.""" trainer = Trainer(num_processes=2, fast_dev_run=True) # assert training type plugin attributes for device setting @@ -64,7 +64,8 @@ def test_ddp_cpu(): @RunIf(min_gpus=2) def test_ddp_spawn_extra_parameters(tmpdir): - """Tests if device is set correctely when training for DDPSpawnPlugin.""" + """Tests if device is set correctly when training for DDPSpawnPlugin and tests add_to_queue/get_from_queue with + Lightning Module (deprecated way).""" trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, gpus=2, accelerator="ddp_spawn") assert isinstance(trainer.training_type_plugin, DDPSpawnPlugin) @@ -75,12 +76,39 @@ def 
test_ddp_spawn_extra_parameters(tmpdir): val_name: str = "val_acc" model = BoringCallbackDDPSpawnModel(val_name, val) dm = BoringDataModule() - trainer.fit(model, datamodule=dm) assert trainer.callback_metrics[val_name] == torch.tensor(val) assert model.test_val == "test_val" +class TestDDPSpawnPlugin(DDPSpawnPlugin): + def add_to_queue(self, trainer: Trainer, queue: torch.multiprocessing.SimpleQueue) -> None: + queue.put("new_test_val") + return super().add_to_queue(trainer, queue) + + def get_from_queue(self, trainer: Trainer, queue: torch.multiprocessing.SimpleQueue) -> None: + self.new_test_val = queue.get() + return super().get_from_queue(trainer, queue) + + +@RunIf(skip_windows=True) +def test_ddp_spawn_add_get_queue(tmpdir): + """Tests add_to_queue/get_from_queue with DDPSpawnPlugin.""" + + ddp_spawn_plugin = TestDDPSpawnPlugin() + trainer = Trainer( + default_root_dir=tmpdir, fast_dev_run=True, num_processes=2, accelerator="ddp_cpu", plugins=[ddp_spawn_plugin] + ) + + val: float = 1.0 + val_name: str = "val_acc" + model = BoringCallbackDDPSpawnModel(val_name, val) + dm = BoringDataModule() + trainer.fit(model, datamodule=dm) + assert trainer.callback_metrics[val_name] == torch.tensor(val) + assert ddp_spawn_plugin.new_test_val == "new_test_val" + + class BoringModelDDP(BoringModel): def on_train_start(self) -> None: """Check if trainer module is wrapped as DistributedDataParallel during training stage."""
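
Migration note: the sketch below shows how code that previously overrode `LightningModule.add_to_queue` / `LightningModule.get_from_queue` might move to the new `DDPSpawnPlugin` hooks introduced in this patch. It is a minimal sketch mirroring the `TestDDPSpawnPlugin` pattern from `tests/plugins/test_ddp_spawn_plugin.py` above; the class name `MyDDPSpawnPlugin` and the `"extra_state"` payload are illustrative, not part of the patch.

import torch

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPSpawnPlugin


class MyDDPSpawnPlugin(DDPSpawnPlugin):
    """Illustrative subclass: the plugin, not the LightningModule, now owns the queue hooks."""

    def add_to_queue(self, trainer: Trainer, queue: torch.multiprocessing.SimpleQueue) -> None:
        # Enqueue custom (picklable) state, then let the base class append
        # `trainer.callback_metrics` (cast to numpy) behind it.
        queue.put("extra_state")
        super().add_to_queue(trainer, queue)

    def get_from_queue(self, trainer: Trainer, queue: torch.multiprocessing.SimpleQueue) -> None:
        # SimpleQueue is FIFO, so reads mirror the write order: custom state
        # first, then the base class restores `trainer.callback_metrics`.
        self.extra_state = queue.get()
        super().get_from_queue(trainer, queue)


# The plugin is passed explicitly; keeping the overrides on the LightningModule
# still works in v1.5 but triggers the deprecation warning added in
# `ConfigValidator._check_add_get_queue` above.
trainer = Trainer(num_processes=2, accelerator="ddp_cpu", plugins=[MyDDPSpawnPlugin()])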