From ddf69674211f144fefa530301edf2fc17be5c70f Mon Sep 17 00:00:00 2001 From: four4fish <88516121+four4fish@users.noreply.github.com> Date: Sat, 25 Sep 2021 08:39:15 -0700 Subject: [PATCH] Deprecate LightningDistributed and keep logic in ddp/ddpSpawn directly (#9691) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Deprecate LightningDistributed and keep logic in ddp/ddpSpawn directly * Deprecate LightningDistributed and keep logic in ddp/ddpSpawn directly * Deprecate LightningDistributed and keep logic in ddp/ddpSpawn directly * Deprecate LightningDistributed and keep logic in ddp/ddpSpawn directly * Apply suggestions from code review Co-authored-by: Adrian Wälchli * Apply suggestions from code review Co-authored-by: ananthsub * Deprecate LightningDistributed and keep logic in ddp/ddpSpawn directly * Update pytorch_lightning/distributed/dist.py Co-authored-by: ananthsub * Deprecate LightningDistributed and keep logic in ddp/ddpSpawn directly * Apply suggestions from code review Co-authored-by: ananthsub * Apply suggestions from code review * Apply suggestions from code review * Deprecate LightningDistributed and keep logic in ddp/ddpSpawn directly Co-authored-by: Adrian Wälchli Co-authored-by: ananthsub --- CHANGELOG.md | 6 +++++ pytorch_lightning/distributed/dist.py | 11 +++++++++ .../plugins/training_type/ddp.py | 23 ++++++++----------- .../plugins/training_type/ddp_spawn.py | 23 ++++++++----------- .../plugins/training_type/deepspeed.py | 3 --- setup.cfg | 1 + tests/deprecated_api/test_remove_1-7.py | 7 ++++++ 7 files changed, 43 insertions(+), 31 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f66a3f69b0..05e148770b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -274,6 +274,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated passing `progress_bar_refresh_rate` to the `Trainer` constructor in favor of adding the `ProgressBar` callback with `refresh_rate` directly to the list of callbacks, or passing `enable_progress_bar=False` to disable the progress bar ([#9616](https://github.com/PyTorchLightning/pytorch-lightning/pull/9616)) +- Deprecate `LightningDistributed` and move the broadcast logic to `DDPPlugin` and `DDPSpawnPlugin` directly ([#9691](https://github.com/PyTorchLightning/pytorch-lightning/pull/9691)) + + ### Removed - Removed deprecated `metrics` ([#8586](https://github.com/PyTorchLightning/pytorch-lightning/pull/8586/)) @@ -402,6 +405,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed `trainer.accumulate_grad_batches` to be an int on init. Default value for it is now `None` inside Trainer ([#9652](https://github.com/PyTorchLightning/pytorch-lightning/pull/9652)) +- Fixed `broadcast` in `DDPPlugin` and ``DDPSpawnPlugin` to respect the `src` input ([#9691](https://github.com/PyTorchLightning/pytorch-lightning/pull/9691)) + + ## [1.4.8] - 2021-09-22 - Fixed error reporting in DDP process reconciliation when processes are launched by an external agent ([#9389](https://github.com/PyTorchLightning/pytorch-lightning/pull/9389)) diff --git a/pytorch_lightning/distributed/dist.py b/pytorch_lightning/distributed/dist.py index d4e41f6e7c..082e0c617a 100644 --- a/pytorch_lightning/distributed/dist.py +++ b/pytorch_lightning/distributed/dist.py @@ -14,11 +14,22 @@ from typing import Any from pytorch_lightning.overrides.torch_distributed import broadcast_object_list +from pytorch_lightning.utilities import rank_zero_deprecation from pytorch_lightning.utilities.distributed import group as _group class LightningDistributed: + """ + .. deprecated:: v1.5 + This class is deprecated in v1.5 and will be removed in v1.7. + The broadcast logic will be moved to the :class:`DDPPlugin` and :class`DDPSpawnPlugin` classes. + """ + def __init__(self, rank=None, device=None): + rank_zero_deprecation( + "LightningDistributed is deprecated in v1.5 and will be removed in v1.7." + "Broadcast logic is implemented directly in the :class:`TrainingTypePlugin` implementations." + ) self.rank = rank self.device = device diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index df0f658bf7..a26b63151f 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -31,9 +31,9 @@ from torch.nn.parallel.distributed import DistributedDataParallel import pytorch_lightning as pl from pytorch_lightning.core.optimizer import LightningOptimizer -from pytorch_lightning.distributed import LightningDistributed from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.overrides.distributed import prepare_for_backward +from pytorch_lightning.overrides.torch_distributed import broadcast_object_list from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin @@ -48,13 +48,9 @@ from pytorch_lightning.utilities import ( rank_zero_deprecation, rank_zero_warn, ) -from pytorch_lightning.utilities.distributed import ( - distributed_available, - init_ddp_connection, - rank_zero_only, - ReduceOp, - sync_ddp_if_available, -) +from pytorch_lightning.utilities.distributed import distributed_available +from pytorch_lightning.utilities.distributed import group as _group +from pytorch_lightning.utilities.distributed import init_ddp_connection, rank_zero_only, ReduceOp, sync_ddp_if_available from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException from pytorch_lightning.utilities.seed import reset_seed from pytorch_lightning.utilities.types import STEP_OUTPUT @@ -116,7 +112,6 @@ class DDPPlugin(ParallelPlugin): " Notice that it will be overriden by the trainer setting." ) self._sync_batchnorm = sync_batchnorm or False - self.dist = LightningDistributed() self.num_processes = len(self.parallel_devices) if self.parallel_devices is not None else 0 self._ddp_kwargs = kwargs self._task_idx = None @@ -269,10 +264,6 @@ class DDPPlugin(ParallelPlugin): # where to store ip_table init_ddp_connection(self.cluster_environment, self.torch_distributed_backend) - # set the ranks and devices - self.dist.rank = self.global_rank - self.dist.device = self.root_device - def _check_can_spawn_children(self): if self.local_rank != 0: raise RuntimeError( @@ -403,7 +394,11 @@ class DDPPlugin(ParallelPlugin): torch.distributed.barrier() def broadcast(self, obj: object, src: int = 0) -> object: - return self.dist.broadcast(obj) + obj = [obj] + if self.global_rank != src: + obj = [None] + broadcast_object_list(obj, src, group=_group.WORLD) + return obj[0] def pre_backward(self, closure_loss: torch.Tensor) -> None: """Run before precision plugin executes backward.""" diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 5f49300134..eb1acaec41 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -24,9 +24,9 @@ import torch.multiprocessing as mp from torch.nn.parallel.distributed import DistributedDataParallel import pytorch_lightning as pl -from pytorch_lightning.distributed.dist import LightningDistributed from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.overrides.distributed import prepare_for_backward +from pytorch_lightning.overrides.torch_distributed import broadcast_object_list from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin @@ -40,13 +40,9 @@ from pytorch_lightning.utilities import ( from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.cloud_io import atomic_save from pytorch_lightning.utilities.cloud_io import load as pl_load -from pytorch_lightning.utilities.distributed import ( - distributed_available, - init_ddp_connection, - rank_zero_only, - ReduceOp, - sync_ddp_if_available, -) +from pytorch_lightning.utilities.distributed import distributed_available +from pytorch_lightning.utilities.distributed import group as _group +from pytorch_lightning.utilities.distributed import init_ddp_connection, rank_zero_only, ReduceOp, sync_ddp_if_available from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.seed import reset_seed from pytorch_lightning.utilities.types import STEP_OUTPUT @@ -93,7 +89,6 @@ class DDPSpawnPlugin(ParallelPlugin): ) self._sync_batchnorm = sync_batchnorm or False self._ddp_kwargs = kwargs - self.dist = LightningDistributed() self.num_processes = len(parallel_devices) if parallel_devices is not None else 0 self.mp_queue = None self._ddp_comm_state = ddp_comm_state @@ -193,10 +188,6 @@ class DDPSpawnPlugin(ParallelPlugin): # ... need to double check that it is the correct place # self.trainer.call_setup_hook(self.model) - # set the ranks and devices - self.dist.rank = self.global_rank - self.dist.device = self.root_device - # move the model to the correct device self.model_to_device() @@ -324,7 +315,11 @@ class DDPSpawnPlugin(ParallelPlugin): def broadcast(self, obj: object, src: int = 0) -> object: if not distributed_available(): return obj - return self.dist.broadcast(obj) + obj = [obj] + if self.global_rank != src: + obj = [None] + broadcast_object_list(obj, src, group=_group.WORLD) + return obj[0] def model_to_device(self): if self.root_device.type == "cuda": diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index cb3b007b71..978152506d 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -342,9 +342,6 @@ class DeepSpeedPlugin(DDPPlugin): self._init_deepspeed_distributed() - # set the ranks and devices - self.dist.rank = self.global_rank - self.dist.device = self.root_device if not self._config_initialized: self._format_config() self._config_initialized = True diff --git a/setup.cfg b/setup.cfg index 86890f08e2..99f3a513b0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -46,6 +46,7 @@ omit = pytorch_lightning/cluster_environments/*.py pytorch_lightning/utilities/distributed.py pytorch_lightning/tuner/auto_gpu_select.py + pytorch_lightning/distributed/dist.py [flake8] diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index de4dba76b9..44c6df1441 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -243,3 +243,10 @@ def test_v1_7_0_lightning_logger_base_close(tmpdir): ): logger = LoggerCollection([logger]) logger.close() + + +def test_v1_7_0_deprecate_lightning_distributed(tmpdir): + with pytest.deprecated_call(match="LightningDistributed is deprecated in v1.5 and will be removed in v1.7."): + from pytorch_lightning.distributed.dist import LightningDistributed + + _ = LightningDistributed()