Move add_to_queue/get_from_queue to DDPSpawnPlugin (#9118)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: tchaton <thomas@grid.ai>
Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
This commit is contained in:
parent 15434a9c35
commit cc2ac02dd1
@@ -191,6 +191,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Deprecated `DataModule` properties: `train_transforms`, `val_transforms`, `test_transforms`, `size`, `dims` ([#8851](https://github.com/PyTorchLightning/pytorch-lightning/pull/8851))
+
+- Deprecated `add_to_queue`, `get_from_queue` from `LightningModule` in favor of corresponding methods in the `DDPSpawnPlugin` ([#9118](https://github.com/PyTorchLightning/pytorch-lightning/pull/9118))
+
 - Deprecated `LightningModule.get_progress_bar_dict` and `Trainer.progress_bar_dict` in favor of `pytorch_lightning.callbacks.progress.base.get_standard_metrics` and `ProgressBarBase.get_metrics` ([#8985](https://github.com/PyTorchLightning/pytorch-lightning/pull/8985))
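For users affected by this deprecation, a minimal before/after migration sketch, assuming the v1.5-era API shown in this diff (the class names `MyModule` and `MySpawnPlugin` are illustrative, not part of the commit):

import torch
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.plugins.training_type import DDPSpawnPlugin


class MyModule(LightningModule):
    # Deprecated in v1.5, to be removed in v1.7: the hook lives on the LightningModule.
    def add_to_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
        queue.put("extra_state")
        super().add_to_queue(queue)


class MySpawnPlugin(DDPSpawnPlugin):
    # New home: the same hook on the plugin, now receiving the Trainer explicitly.
    def add_to_queue(self, trainer: Trainer, queue: torch.multiprocessing.SimpleQueue) -> None:
        queue.put("extra_state")
        super().add_to_queue(trainer, queue)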
@@ -21,9 +21,8 @@ from torch.optim import Optimizer
 from torch.utils.data import DataLoader

 import pytorch_lightning as pl
-from pytorch_lightning.plugins import DataParallelPlugin
 from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin
-from pytorch_lightning.plugins.training_type import TrainingTypePlugin
+from pytorch_lightning.plugins.training_type import DataParallelPlugin, TrainingTypePlugin
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE
 from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
@@ -116,7 +115,7 @@ class Accelerator:

     def post_dispatch(self, trainer: "pl.Trainer") -> None:
         """Hook to do something after the training/evaluation/prediction starts."""
-        self.training_type_plugin.post_dispatch()
+        self.training_type_plugin.post_dispatch(trainer)
         self.precision_plugin.post_dispatch()

     @property
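Why the hook now takes the trainer: `DDPSpawnPlugin.post_dispatch` (further down) must restore queue state onto the trainer, and the accelerator is the caller that holds it. The call chain, reduced to a runnable sketch with simplified stand-ins (not the real classes):

class TrainingTypePlugin:
    def post_dispatch(self, trainer) -> None:
        """Hook called after training/evaluation/prediction finishes."""


class PrecisionPlugin:
    def post_dispatch(self) -> None:
        """Unchanged: the precision hook does not need the trainer."""


class Accelerator:
    def __init__(self, training_type_plugin, precision_plugin):
        self.training_type_plugin = training_type_plugin
        self.precision_plugin = precision_plugin

    def post_dispatch(self, trainer) -> None:
        # forward the trainer so spawn-based plugins can write results back to it
        self.training_type_plugin.post_dispatch(trainer)
        self.precision_plugin.post_dispatch()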
@@ -24,13 +24,13 @@ from contextlib import contextmanager
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union

 import numpy as np
 import torch
 from torch import ScriptModule, Tensor
 from torch.nn import Module
 from torch.optim.optimizer import Optimizer
 from torchmetrics import Metric

 import pytorch_lightning as pl
 from pytorch_lightning.callbacks.progress import base as progress_base
 from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks
 from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin, HyperparametersMixin
@@ -1905,11 +1905,13 @@ class LightningModule(

         Args:
             queue: the instance of the queue to append the data.
+
+        .. deprecated:: v1.5
+            This method was deprecated in v1.5 in favor of `DDPSpawnPlugin.add_to_queue`
+            and will be removed in v1.7.
         """
-        callback_metrics: dict = apply_to_collection(
-            self.trainer.callback_metrics, torch.Tensor, lambda x: x.cpu().numpy()
-        )  # send as numpy to avoid issues with memory sharing
-        queue.put(callback_metrics)
+        if self.trainer and isinstance(self.trainer.training_type_plugin, pl.plugins.training_type.DDPSpawnPlugin):
+            self.trainer.training_type_plugin.add_to_queue(self.trainer, queue)

     def get_from_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
         """Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency,
@@ -1917,12 +1919,13 @@ class LightningModule(

         Args:
             queue: the instance of the queue from where to get the data.
+
+        .. deprecated:: v1.5
+            This method was deprecated in v1.5 in favor of `DDPSpawnPlugin.get_from_queue`
+            and will be removed in v1.7.
         """
-        # NOTE: `add_to_queue` needs to be called before
-        callback_metrics: dict = queue.get()
-        self.trainer.callback_metrics.update(
-            apply_to_collection(callback_metrics, np.ndarray, lambda x: torch.tensor(x))
-        )
+        if self.trainer and isinstance(self.trainer.training_type_plugin, pl.plugins.training_type.DDPSpawnPlugin):
+            self.trainer.training_type_plugin.get_from_queue(self.trainer, queue)

     @contextmanager
     def _prevent_trainer_and_dataloaders_deepcopy(self) -> None:
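The bodies removed here (and re-added on the plugin below) serialize metrics as numpy; per the inline comment, this avoids memory-sharing issues when tensors cross inter-process queues. A self-contained sketch of that round-trip, using the same `apply_to_collection` helper (the metric names are made up for illustration):

import multiprocessing

import numpy as np
import torch
from pytorch_lightning.utilities.apply_func import apply_to_collection

callback_metrics = {"val_acc": torch.tensor(0.93), "val_loss": torch.tensor(0.12)}
queue = multiprocessing.SimpleQueue()

# spawned-worker side: tensors -> numpy before putting them on the queue
queue.put(apply_to_collection(callback_metrics, torch.Tensor, lambda x: x.cpu().numpy()))

# main-process side: numpy -> tensors after getting them back
restored = apply_to_collection(queue.get(), np.ndarray, lambda x: torch.tensor(x))
assert torch.equal(restored["val_acc"], callback_metrics["val_acc"])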
@@ -29,6 +29,7 @@ import torch
 import torch.distributed
 from torch.nn.parallel.distributed import DistributedDataParallel

+import pytorch_lightning as pl
 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.distributed import LightningDistributed
 from pytorch_lightning.overrides import LightningDistributedModule
@@ -385,7 +386,7 @@ class DDPPlugin(ParallelPlugin):
         if trainer_fn == TrainerFn.FITTING:
             self.configure_ddp()

-    def post_dispatch(self) -> None:
+    def post_dispatch(self, trainer: "pl.Trainer") -> None:
         self.cluster_environment.teardown()

     def barrier(self, *args, **kwargs) -> None:
@@ -17,6 +17,7 @@ import re
 from multiprocessing.queues import SimpleQueue
 from typing import Any, Dict, List, Optional, Union

+import numpy as np
 import torch
 import torch.distributed
 import torch.multiprocessing as mp
@@ -36,6 +37,7 @@ from pytorch_lightning.utilities import (
     rank_zero_deprecation,
     rank_zero_warn,
 )
+from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.cloud_io import atomic_save
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.distributed import (
@@ -45,6 +47,7 @@ from pytorch_lightning.utilities.distributed import (
     ReduceOp,
     sync_ddp_if_available,
 )
+from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.seed import reset_seed
 from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -215,14 +218,18 @@ class DDPSpawnPlugin(ParallelPlugin):
         # ensure that spawned processes go through teardown before joining
         trainer._call_teardown_hook()

-    def post_dispatch(self):
+    def post_dispatch(self, trainer: "pl.Trainer"):
         # restore main state with best weights
         best_path = self.mp_queue.get()
         last_path = self.mp_queue.get()
         self._results = self.mp_queue.get()
         # get the `callback_metrics` and set it to the trainer
-        self.lightning_module.get_from_queue(self.mp_queue)
+        # only in case the user does not override it.
+        # TODO: Remove the if in v1.7
+        if is_overridden("get_from_queue", self.lightning_module):
+            self.lightning_module.get_from_queue(self.mp_queue)
+        else:
+            self.get_from_queue(trainer, self.mp_queue)

         # recover the weights of the processes trained in the children
         self.__recover_child_process_weights(best_path, last_path)
@@ -288,7 +295,12 @@ class DDPSpawnPlugin(ParallelPlugin):
             self.mp_queue.put(best_model_path)
             self.mp_queue.put(last_path)
             self.mp_queue.put(results)
-            self.lightning_module.add_to_queue(self.mp_queue)  # adds the `callback_metrics` to the queue
+            # adds the `callback_metrics` to the queue
+            # TODO: Remove the if in v1.7
+            if is_overridden("add_to_queue", self.lightning_module):
+                self.lightning_module.add_to_queue(self.mp_queue)
+            else:
+                self.add_to_queue(trainer, self.mp_queue)

     def __recover_child_process_weights(self, best_path, last_path):
         # transfer back the best path to the trainer
@@ -362,6 +374,29 @@ class DDPSpawnPlugin(ParallelPlugin):
         if not self.lightning_module.automatic_optimization:
             self.model.require_backward_grad_sync = True

+    def add_to_queue(self, trainer: "pl.Trainer", queue: torch.multiprocessing.SimpleQueue) -> None:
+        """Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory
+        sharing, we cast the data to numpy.
+
+        Args:
+            queue: the instance of the queue to append the data.
+        """
+        callback_metrics: dict = apply_to_collection(
+            trainer.callback_metrics, torch.Tensor, lambda x: x.cpu().numpy()
+        )  # send as numpy to avoid issues with memory sharing
+        queue.put(callback_metrics)
+
+    def get_from_queue(self, trainer: "pl.Trainer", queue: torch.multiprocessing.SimpleQueue) -> None:
+        """Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency,
+        we cast back the data to ``torch.Tensor``.
+
+        Args:
+            queue: the instance of the queue from where to get the data.
+        """
+        # NOTE: `add_to_queue` needs to be called before
+        callback_metrics: dict = queue.get()
+        trainer.callback_metrics.update(apply_to_collection(callback_metrics, np.ndarray, lambda x: torch.tensor(x)))
+
     @classmethod
     def register_plugins(cls, plugin_registry: Dict) -> None:
         plugin_registry.register(
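The two methods above are the new extension point. In use — mirroring the test added at the bottom of this commit; the subclass name is illustrative — a plugin can ship extra state from the spawned workers back to the main process:

import torch
from pytorch_lightning import Trainer
from pytorch_lightning.plugins.training_type import DDPSpawnPlugin


class MySpawnPlugin(DDPSpawnPlugin):
    def add_to_queue(self, trainer: Trainer, queue: torch.multiprocessing.SimpleQueue) -> None:
        queue.put("extra_state")  # runs in the spawned worker, before the metrics
        super().add_to_queue(trainer, queue)

    def get_from_queue(self, trainer: Trainer, queue: torch.multiprocessing.SimpleQueue) -> None:
        self.extra_state = queue.get()  # runs in the main process, in the same order as put
        super().get_from_queue(trainer, queue)


# wiring, as in the tests below (v1.5-era Trainer arguments):
# trainer = Trainer(num_processes=2, accelerator="ddp_cpu", plugins=[MySpawnPlugin()])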
@@ -360,5 +360,5 @@ class TrainingTypePlugin(ABC):
     def dispatch(self, trainer: "pl.Trainer") -> None:
         """Hook to do something at trainer run_stage starts."""

-    def post_dispatch(self) -> None:
+    def post_dispatch(self, trainer: "pl.Trainer") -> None:
         """Hook to do something after the training/evaluation/prediction finishes."""
@@ -43,6 +43,7 @@ class ConfigValidator:
         elif self.trainer.state.fn == TrainerFn.PREDICTING:
             self.__verify_predict_loop_configuration(model)
         self.__verify_dp_batch_transfer_support(model)
+        self._check_add_get_queue(model)
         # TODO(@daniellepintz): Delete _check_progress_bar in v1.7
         self._check_progress_bar(model)
         # TODO: Delete _check_on_keyboard_interrupt in v1.7
@@ -219,6 +220,24 @@ class ConfigValidator:
                 "is incompatible with `truncated_bptt_steps > 0`."
             )

+    def _check_add_get_queue(self, model: "pl.LightningModule") -> None:
+        r"""
+        Checks if add_to_queue or get_from_queue is overridden and sends a deprecation warning.
+
+        Args:
+            model: The lightning module
+        """
+        if is_overridden("add_to_queue", model):
+            rank_zero_deprecation(
+                "The `LightningModule.add_to_queue` method was deprecated in v1.5 and will be removed in v1.7 in "
+                "favor of `DDPSpawnPlugin.add_to_queue`"
+            )
+        if is_overridden("get_from_queue", model):
+            rank_zero_deprecation(
+                "The `LightningModule.get_from_queue` method was deprecated in v1.5 and will be removed in v1.7 in "
+                "favor of `DDPSpawnPlugin.get_from_queue`"
+            )
+
     def _check_on_keyboard_interrupt(self) -> None:
         """Checks if on_keyboard_interrupt is overridden and sends a deprecation warning."""
         for callback in self.trainer.callbacks:
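The check above hinges on detecting a user override. A simplified, runnable stand-in for what `is_overridden` decides here (the real helper lives in `pytorch_lightning.utilities.model_helpers` and handles more cases; `_is_overridden`, `Plain`, and `Custom` are hypothetical names):

from pytorch_lightning import LightningModule


def _is_overridden(method_name: str, model: LightningModule) -> bool:
    # overridden iff the class's attribute is no longer the base implementation
    return getattr(type(model), method_name) is not getattr(LightningModule, method_name)


class Plain(LightningModule):
    pass


class Custom(LightningModule):
    def add_to_queue(self, queue) -> None:  # would trigger the deprecation above
        super().add_to_queue(queue)


assert not _is_overridden("add_to_queue", Plain())
assert _is_overridden("add_to_queue", Custom())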
@@ -15,12 +15,14 @@
 from unittest import mock

 import pytest
+import torch

 from pytorch_lightning import Callback, LightningDataModule, Trainer
 from pytorch_lightning.loggers import TestTubeLogger
 from tests.deprecated_api import _soft_unimport_module
 from tests.helpers import BoringModel
 from tests.helpers.datamodules import MNISTDataModule
 from tests.helpers.runif import RunIf


 def test_v1_7_0_deprecated_lightning_module_summarize(tmpdir):
@@ -192,3 +194,28 @@ def test_v1_7_0_on_interrupt(tmpdir):
 def test_v1_7_0_process_position_trainer_constructor(tmpdir):
     with pytest.deprecated_call(match=r"Setting `Trainer\(process_position=5\)` is deprecated in v1.5"):
         _ = Trainer(process_position=5)
+
+
+class BoringCallbackDDPSpawnModel(BoringModel):
+    def __init__(self):
+        super().__init__()
+
+    def add_to_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
+        queue.put("test_val")
+        return super().add_to_queue(queue)
+
+    def get_from_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
+        self.test_val = queue.get()
+        return super().get_from_queue(queue)
+
+
+@RunIf(skip_windows=True)
+def test_v1_7_0_deprecate_add_get_queue(tmpdir):
+    model = BoringCallbackDDPSpawnModel()
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, num_processes=2, accelerator="ddp_cpu")
+
+    with pytest.deprecated_call(match=r"`LightningModule.add_to_queue` method was deprecated in v1.5"):
+        trainer.fit(model)
+
+    with pytest.deprecated_call(match=r"`LightningModule.get_from_queue` method was deprecated in v1.5"):
+        trainer.fit(model)
@@ -48,7 +48,7 @@ class BoringCallbackDDPSpawnModel(BoringModel):

 @RunIf(skip_windows=True)
 def test_ddp_cpu():
-    """Tests if device is set correctely when training for DDPSpawnPlugin."""
+    """Tests if device is set correctly when training for DDPSpawnPlugin."""
     trainer = Trainer(num_processes=2, fast_dev_run=True)

     # assert training type plugin attributes for device setting
@@ -64,7 +64,8 @@ def test_ddp_cpu():

 @RunIf(min_gpus=2)
 def test_ddp_spawn_extra_parameters(tmpdir):
-    """Tests if device is set correctely when training for DDPSpawnPlugin."""
+    """Tests if device is set correctly when training for DDPSpawnPlugin and tests add_to_queue/get_from_queue with
+    Lightning Module (deprecated way)."""
     trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, gpus=2, accelerator="ddp_spawn")

     assert isinstance(trainer.training_type_plugin, DDPSpawnPlugin)
@@ -75,12 +76,39 @@ def test_ddp_spawn_extra_parameters(tmpdir):
     val_name: str = "val_acc"
     model = BoringCallbackDDPSpawnModel(val_name, val)
     dm = BoringDataModule()

     trainer.fit(model, datamodule=dm)
     assert trainer.callback_metrics[val_name] == torch.tensor(val)
     assert model.test_val == "test_val"


+class TestDDPSpawnPlugin(DDPSpawnPlugin):
+    def add_to_queue(self, trainer: Trainer, queue: torch.multiprocessing.SimpleQueue) -> None:
+        queue.put("new_test_val")
+        return super().add_to_queue(trainer, queue)
+
+    def get_from_queue(self, trainer: Trainer, queue: torch.multiprocessing.SimpleQueue) -> None:
+        self.new_test_val = queue.get()
+        return super().get_from_queue(trainer, queue)
+
+
+@RunIf(skip_windows=True)
+def test_ddp_spawn_add_get_queue(tmpdir):
+    """Tests add_to_queue/get_from_queue with DDPSpawnPlugin."""
+
+    ddp_spawn_plugin = TestDDPSpawnPlugin()
+    trainer = Trainer(
+        default_root_dir=tmpdir, fast_dev_run=True, num_processes=2, accelerator="ddp_cpu", plugins=[ddp_spawn_plugin]
+    )
+
+    val: float = 1.0
+    val_name: str = "val_acc"
+    model = BoringCallbackDDPSpawnModel(val_name, val)
+    dm = BoringDataModule()
+    trainer.fit(model, datamodule=dm)
+    assert trainer.callback_metrics[val_name] == torch.tensor(val)
+    assert ddp_spawn_plugin.new_test_val == "new_test_val"
+
+
 class BoringModelDDP(BoringModel):
     def on_train_start(self) -> None:
         """Check if trainer module is wrapped as DistributedDataParallel during training stage."""