Move add_to_queue/get_from_queue to DDPSpawnPlugin (#9118)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: tchaton <thomas@grid.ai>
Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Danielle Pintz 2021-09-10 13:58:02 -07:00 committed by GitHub
parent 15434a9c35
commit cc2ac02dd1
9 changed files with 136 additions and 21 deletions

CHANGELOG.md

@@ -191,6 +191,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `DataModule` properties: `train_transforms`, `val_transforms`, `test_transforms`, `size`, `dims` ([#8851](https://github.com/PyTorchLightning/pytorch-lightning/pull/8851))
- Deprecated `add_to_queue`, `get_from_queue` from `LightningModule` in favor of corresponding methods in the `DDPSpawnPlugin` ([#9118](https://github.com/PyTorchLightning/pytorch-lightning/pull/9118))
- Deprecated `LightningModule.get_progress_bar_dict` and `Trainer.progress_bar_dict` in favor of `pytorch_lightning.callbacks.progress.base.get_standard_metrics` and `ProgressBarBase.get_metrics` ([#8985](https://github.com/PyTorchLightning/pytorch-lightning/pull/8985))
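
For users migrating, a minimal before/after sketch based on this deprecation (the class names and the queued value are illustrative, not part of the change):

import torch
import pytorch_lightning as pl
from pytorch_lightning.plugins import DDPSpawnPlugin

# Deprecated since v1.5, to be removed in v1.7: overriding on the LightningModule.
class MyModule(pl.LightningModule):
    def add_to_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
        queue.put("extra_state")
        super().add_to_queue(queue)

# Preferred: override on the plugin instead; note the additional `trainer` argument.
class MyDDPSpawnPlugin(DDPSpawnPlugin):
    def add_to_queue(self, trainer: "pl.Trainer", queue: torch.multiprocessing.SimpleQueue) -> None:
        queue.put("extra_state")
        super().add_to_queue(trainer, queue)

# The custom plugin is passed to the Trainer explicitly, e.g.:
# trainer = pl.Trainer(num_processes=2, accelerator="ddp_cpu", plugins=[MyDDPSpawnPlugin()])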

pytorch_lightning/accelerators/accelerator.py

@@ -21,9 +21,8 @@ from torch.optim import Optimizer
from torch.utils.data import DataLoader
import pytorch_lightning as pl
-from pytorch_lightning.plugins import DataParallelPlugin
from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin
-from pytorch_lightning.plugins.training_type import TrainingTypePlugin
+from pytorch_lightning.plugins.training_type import DataParallelPlugin, TrainingTypePlugin
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
@@ -116,7 +115,7 @@ class Accelerator:
def post_dispatch(self, trainer: "pl.Trainer") -> None:
"""Hook to do something after the training/evaluation/prediction starts."""
-self.training_type_plugin.post_dispatch()
+self.training_type_plugin.post_dispatch(trainer)
self.precision_plugin.post_dispatch()
@property

pytorch_lightning/core/lightning.py

@@ -24,13 +24,13 @@ from contextlib import contextmanager
from pathlib import Path
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
-import numpy as np
import torch
from torch import ScriptModule, Tensor
from torch.nn import Module
from torch.optim.optimizer import Optimizer
from torchmetrics import Metric
import pytorch_lightning as pl
from pytorch_lightning.callbacks.progress import base as progress_base
from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks
from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin, HyperparametersMixin
@@ -1905,11 +1905,13 @@ class LightningModule(
Args:
queue: the instance of the queue to append the data.
+.. deprecated:: v1.5
+This method was deprecated in v1.5 in favor of `DDPSpawnPlugin.add_to_queue`
+and will be removed in v1.7.
"""
-callback_metrics: dict = apply_to_collection(
-self.trainer.callback_metrics, torch.Tensor, lambda x: x.cpu().numpy()
-) # send as numpy to avoid issues with memory sharing
-queue.put(callback_metrics)
+if self.trainer and isinstance(self.trainer.training_type_plugin, pl.plugins.training_type.DDPSpawnPlugin):
+self.trainer.training_type_plugin.add_to_queue(self.trainer, queue)
def get_from_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
"""Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency,
@@ -1917,12 +1919,13 @@ class LightningModule(
Args:
queue: the instance of the queue from where to get the data.
+.. deprecated:: v1.5
+This method was deprecated in v1.5 in favor of `DDPSpawnPlugin.get_from_queue`
+and will be removed in v1.7.
"""
-# NOTE: `add_to_queue` needs to be called before
-callback_metrics: dict = queue.get()
-self.trainer.callback_metrics.update(
-apply_to_collection(callback_metrics, np.ndarray, lambda x: torch.tensor(x))
-)
+if self.trainer and isinstance(self.trainer.training_type_plugin, pl.plugins.training_type.DDPSpawnPlugin):
+self.trainer.training_type_plugin.get_from_queue(self.trainer, queue)
@contextmanager
def _prevent_trainer_and_dataloaders_deepcopy(self) -> None:
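
The guarded delegation above keeps the deprecated module methods callable; code that previously invoked them directly can reach the plugin instead. A sketch of that direct path (the helper name `put_callback_metrics` is hypothetical):

from pytorch_lightning.plugins import DDPSpawnPlugin

def put_callback_metrics(trainer, queue) -> None:
    # No-op unless the trainer is actually running a DDPSpawnPlugin,
    # mirroring the isinstance guard in LightningModule.add_to_queue above.
    plugin = trainer.training_type_plugin
    if isinstance(plugin, DDPSpawnPlugin):
        plugin.add_to_queue(trainer, queue)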

pytorch_lightning/plugins/training_type/ddp.py

@@ -29,6 +29,7 @@ import torch
import torch.distributed
from torch.nn.parallel.distributed import DistributedDataParallel
+import pytorch_lightning as pl
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.distributed import LightningDistributed
from pytorch_lightning.overrides import LightningDistributedModule
@@ -385,7 +386,7 @@ class DDPPlugin(ParallelPlugin):
if trainer_fn == TrainerFn.FITTING:
self.configure_ddp()
-def post_dispatch(self) -> None:
+def post_dispatch(self, trainer: "pl.Trainer") -> None:
self.cluster_environment.teardown()
def barrier(self, *args, **kwargs) -> None:

pytorch_lightning/plugins/training_type/ddp_spawn.py

@@ -17,6 +17,7 @@ import re
from multiprocessing.queues import SimpleQueue
from typing import Any, Dict, List, Optional, Union
+import numpy as np
import torch
import torch.distributed
import torch.multiprocessing as mp
@@ -36,6 +37,7 @@ from pytorch_lightning.utilities import (
rank_zero_deprecation,
rank_zero_warn,
)
+from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.distributed import (
@@ -45,6 +47,7 @@ from pytorch_lightning.utilities.distributed import (
ReduceOp,
sync_ddp_if_available,
)
+from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -215,14 +218,18 @@ class DDPSpawnPlugin(ParallelPlugin):
# ensure that spawned processes go through teardown before joining
trainer._call_teardown_hook()
-def post_dispatch(self):
+def post_dispatch(self, trainer: "pl.Trainer"):
# restore main state with best weights
best_path = self.mp_queue.get()
last_path = self.mp_queue.get()
self._results = self.mp_queue.get()
# get the `callback_metrics` and set it to the trainer
# only in case the user does not override it.
-self.lightning_module.get_from_queue(self.mp_queue)
+# TODO: Remove the if in v1.7
+if is_overridden("get_from_queue", self.lightning_module):
+self.lightning_module.get_from_queue(self.mp_queue)
+else:
+self.get_from_queue(trainer, self.mp_queue)
# recover the weights of the processes trained in the children
self.__recover_child_process_weights(best_path, last_path)
@@ -288,7 +295,12 @@
self.mp_queue.put(best_model_path)
self.mp_queue.put(last_path)
self.mp_queue.put(results)
-self.lightning_module.add_to_queue(self.mp_queue) # adds the `callback_metrics` to the queue
+# adds the `callback_metrics` to the queue
+# TODO: Remove the if in v1.7
+if is_overridden("add_to_queue", self.lightning_module):
+self.lightning_module.add_to_queue(self.mp_queue)
+else:
+self.add_to_queue(trainer, self.mp_queue)
def __recover_child_process_weights(self, best_path, last_path):
# transfer back the best path to the trainer
@@ -362,6 +374,29 @@
if not self.lightning_module.automatic_optimization:
self.model.require_backward_grad_sync = True
+def add_to_queue(self, trainer: "pl.Trainer", queue: torch.multiprocessing.SimpleQueue) -> None:
+"""Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory
+sharing, we cast the data to numpy.
+Args:
+queue: the instance of the queue to append the data.
+"""
+callback_metrics: dict = apply_to_collection(
+trainer.callback_metrics, torch.Tensor, lambda x: x.cpu().numpy()
+) # send as numpy to avoid issues with memory sharing
+queue.put(callback_metrics)
+def get_from_queue(self, trainer: "pl.Trainer", queue: torch.multiprocessing.SimpleQueue) -> None:
+"""Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency,
+we cast back the data to ``torch.Tensor``.
+Args:
+queue: the instance of the queue from where to get the data.
+"""
+# NOTE: `add_to_queue` needs to be called before
+callback_metrics: dict = queue.get()
+trainer.callback_metrics.update(apply_to_collection(callback_metrics, np.ndarray, lambda x: torch.tensor(x)))
@classmethod
def register_plugins(cls, plugin_registry: Dict) -> None:
plugin_registry.register(

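Both new plugin methods cross the spawn boundary by casting between `torch.Tensor` and numpy. A self-contained sketch of that round-trip, using only the utilities imported in this file (the metric name is illustrative):

import numpy as np
import torch
from pytorch_lightning.utilities.apply_func import apply_to_collection

metrics = {"val_acc": torch.tensor(0.9)}

# add_to_queue side: cast tensors to numpy before the queue hand-off,
# avoiding memory-sharing issues across spawned processes.
as_numpy = apply_to_collection(metrics, torch.Tensor, lambda x: x.cpu().numpy())

# get_from_queue side: cast back to tensors after the hand-off.
restored = apply_to_collection(as_numpy, np.ndarray, lambda x: torch.tensor(x))
assert torch.equal(restored["val_acc"], metrics["val_acc"])
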
pytorch_lightning/plugins/training_type/training_type_plugin.py

@@ -360,5 +360,5 @@ class TrainingTypePlugin(ABC):
def dispatch(self, trainer: "pl.Trainer") -> None:
"""Hook to do something at trainer run_stage starts."""
-def post_dispatch(self) -> None:
+def post_dispatch(self, trainer: "pl.Trainer") -> None:
"""Hook to do something after the training/evaluation/prediction finishes."""

pytorch_lightning/trainer/configuration_validator.py

@@ -43,6 +43,7 @@ class ConfigValidator:
elif self.trainer.state.fn == TrainerFn.PREDICTING:
self.__verify_predict_loop_configuration(model)
self.__verify_dp_batch_transfer_support(model)
+self._check_add_get_queue(model)
# TODO(@daniellepintz): Delete _check_progress_bar in v1.7
self._check_progress_bar(model)
# TODO: Delete _check_on_keyboard_interrupt in v1.7
@@ -219,6 +220,24 @@
"is incompatible with `truncated_bptt_steps > 0`."
)
+def _check_add_get_queue(self, model: "pl.LightningModule") -> None:
+r"""
+Checks if add_to_queue or get_from_queue is overridden and sends a deprecation warning.
+Args:
+model: The lightning module
+"""
+if is_overridden("add_to_queue", model):
+rank_zero_deprecation(
+"The `LightningModule.add_to_queue` method was deprecated in v1.5 and will be removed in v1.7 in "
+"favor of `DDPSpawnPlugin.add_to_queue`"
+)
+if is_overridden("get_from_queue", model):
+rank_zero_deprecation(
+"The `LightningModule.get_from_queue` method was deprecated in v1.5 and will be removed in v1.7 in "
+"favor of `DDPSpawnPlugin.get_from_queue`"
+)
def _check_on_keyboard_interrupt(self) -> None:
"""Checks if on_keyboard_interrupt is overriden and sends a deprecation warning."""
for callback in self.trainer.callbacks:

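The new check leans on `is_overridden`, which compares an instance's method against the `LightningModule` default. A small sketch of how the detection behaves (class names are hypothetical):

import pytorch_lightning as pl
from pytorch_lightning.utilities.model_helpers import is_overridden

class PlainModel(pl.LightningModule):
    pass

class QueueModel(pl.LightningModule):
    def add_to_queue(self, queue) -> None:  # this override triggers the deprecation warning
        queue.put("extra")

assert not is_overridden("add_to_queue", PlainModel())
assert is_overridden("add_to_queue", QueueModel())
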
tests/deprecated_api/test_remove_1-7.py

@@ -15,12 +15,14 @@
from unittest import mock
import pytest
+import torch
from pytorch_lightning import Callback, LightningDataModule, Trainer
from pytorch_lightning.loggers import TestTubeLogger
from tests.deprecated_api import _soft_unimport_module
from tests.helpers import BoringModel
from tests.helpers.datamodules import MNISTDataModule
+from tests.helpers.runif import RunIf
def test_v1_7_0_deprecated_lightning_module_summarize(tmpdir):
@@ -192,3 +194,28 @@ def test_v1_7_0_on_interrupt(tmpdir):
def test_v1_7_0_process_position_trainer_constructor(tmpdir):
with pytest.deprecated_call(match=r"Setting `Trainer\(process_position=5\)` is deprecated in v1.5"):
_ = Trainer(process_position=5)
+class BoringCallbackDDPSpawnModel(BoringModel):
+def __init__(self):
+super().__init__()
+def add_to_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
+queue.put("test_val")
+return super().add_to_queue(queue)
+def get_from_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
+self.test_val = queue.get()
+return super().get_from_queue(queue)
+@RunIf(skip_windows=True)
+def test_v1_7_0_deprecate_add_get_queue(tmpdir):
+model = BoringCallbackDDPSpawnModel()
+trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, num_processes=2, accelerator="ddp_cpu")
+with pytest.deprecated_call(match=r"`LightningModule.add_to_queue` method was deprecated in v1.5"):
+trainer.fit(model)
+with pytest.deprecated_call(match=r"`LightningModule.get_from_queue` method was deprecated in v1.5"):
+trainer.fit(model)

tests/plugins/test_ddp_spawn_plugin.py

@@ -48,7 +48,7 @@ class BoringCallbackDDPSpawnModel(BoringModel):
@RunIf(skip_windows=True)
def test_ddp_cpu():
"""Tests if device is set correctely when training for DDPSpawnPlugin."""
"""Tests if device is set correctly when training for DDPSpawnPlugin."""
trainer = Trainer(num_processes=2, fast_dev_run=True)
# assert training type plugin attributes for device setting
@@ -64,7 +64,8 @@ def test_ddp_cpu():
@RunIf(min_gpus=2)
def test_ddp_spawn_extra_parameters(tmpdir):
"""Tests if device is set correctely when training for DDPSpawnPlugin."""
"""Tests if device is set correctly when training for DDPSpawnPlugin and tests add_to_queue/get_from_queue with
Lightning Module (deprecated way)."""
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, gpus=2, accelerator="ddp_spawn")
assert isinstance(trainer.training_type_plugin, DDPSpawnPlugin)
@@ -75,12 +76,39 @@
val_name: str = "val_acc"
model = BoringCallbackDDPSpawnModel(val_name, val)
dm = BoringDataModule()
trainer.fit(model, datamodule=dm)
assert trainer.callback_metrics[val_name] == torch.tensor(val)
assert model.test_val == "test_val"
+class TestDDPSpawnPlugin(DDPSpawnPlugin):
+def add_to_queue(self, trainer: Trainer, queue: torch.multiprocessing.SimpleQueue) -> None:
+queue.put("new_test_val")
+return super().add_to_queue(trainer, queue)
+def get_from_queue(self, trainer: Trainer, queue: torch.multiprocessing.SimpleQueue) -> None:
+self.new_test_val = queue.get()
+return super().get_from_queue(trainer, queue)
+@RunIf(skip_windows=True)
+def test_ddp_spawn_add_get_queue(tmpdir):
+"""Tests add_to_queue/get_from_queue with DDPSpawnPlugin."""
+ddp_spawn_plugin = TestDDPSpawnPlugin()
+trainer = Trainer(
+default_root_dir=tmpdir, fast_dev_run=True, num_processes=2, accelerator="ddp_cpu", plugins=[ddp_spawn_plugin]
+)
+val: float = 1.0
+val_name: str = "val_acc"
+model = BoringCallbackDDPSpawnModel(val_name, val)
+dm = BoringDataModule()
+trainer.fit(model, datamodule=dm)
+assert trainer.callback_metrics[val_name] == torch.tensor(val)
+assert ddp_spawn_plugin.new_test_val == "new_test_val"
class BoringModelDDP(BoringModel):
def on_train_start(self) -> None:
"""Check if trainer module is wrapped as DistributedDataParallel during training stage."""