Deprecate LightningDistributed and keep logic in ddp/ddpSpawn directly (#9691)

* Deprecate LightningDistributed and keep logic in ddp/ddpSpawn directly

* Deprecate LightningDistributed and keep logic in ddp/ddpSpawn directly

* Deprecate LightningDistributed and keep logic in ddp/ddpSpawn directly

* Deprecate LightningDistributed and keep logic in ddp/ddpSpawn directly

* Apply suggestions from code review

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* Apply suggestions from code review

Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>

* Deprecate LightningDistributed and keep logic in ddp/ddpSpawn directly

* Update pytorch_lightning/distributed/dist.py

Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>

* Deprecate LightningDistributed and keep logic in ddp/ddpSpawn directly

* Apply suggestions from code review

Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>

* Apply suggestions from code review

* Apply suggestions from code review

* Deprecate LightningDistributed and keep logic in ddp/ddpSpawn directly

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
Author: four4fish
Date: 2021-09-25 08:39:15 -07:00 (committed by GitHub)
Commit: ddf6967421 (parent: 444b21dc3d)
7 changed files with 43 additions and 31 deletions

CHANGELOG.md

@@ -274,6 +274,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated passing `progress_bar_refresh_rate` to the `Trainer` constructor in favor of adding the `ProgressBar` callback with `refresh_rate` directly to the list of callbacks, or passing `enable_progress_bar=False` to disable the progress bar ([#9616](https://github.com/PyTorchLightning/pytorch-lightning/pull/9616))
+- Deprecate `LightningDistributed` and move the broadcast logic to `DDPPlugin` and `DDPSpawnPlugin` directly ([#9691](https://github.com/PyTorchLightning/pytorch-lightning/pull/9691))

 ### Removed
 - Removed deprecated `metrics` ([#8586](https://github.com/PyTorchLightning/pytorch-lightning/pull/8586/))

@@ -402,6 +405,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `trainer.accumulate_grad_batches` to be an int on init. Default value for it is now `None` inside Trainer ([#9652](https://github.com/PyTorchLightning/pytorch-lightning/pull/9652))
+- Fixed `broadcast` in `DDPPlugin` and `DDPSpawnPlugin` to respect the `src` input ([#9691](https://github.com/PyTorchLightning/pytorch-lightning/pull/9691))

 ## [1.4.8] - 2021-09-22
 - Fixed error reporting in DDP process reconciliation when processes are launched by an external agent ([#9389](https://github.com/PyTorchLightning/pytorch-lightning/pull/9389))
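Taken together, these two entries mean that broadcasting now lives on the training-type plugin and that the `src` argument selects the source rank. A minimal usage sketch under DDP (the hook, the `local_value` name, and broadcasting from rank 1 are illustrative choices, not part of this change):

import pytorch_lightning as pl


class BroadcastExample(pl.LightningModule):
    def validation_epoch_end(self, outputs):
        # Each rank builds its own object; after the call every rank holds rank 1's copy.
        local_value = {"rank": self.global_rank}
        synced = self.trainer.training_type_plugin.broadcast(local_value, src=1)
        # Before #9691, `src` was ignored and rank 0's object was always returned.
        self.print(synced)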

pytorch_lightning/distributed/dist.py

@@ -14,11 +14,22 @@
 from typing import Any

 from pytorch_lightning.overrides.torch_distributed import broadcast_object_list
+from pytorch_lightning.utilities import rank_zero_deprecation
 from pytorch_lightning.utilities.distributed import group as _group


 class LightningDistributed:
+    """
+    .. deprecated:: v1.5
+        This class is deprecated in v1.5 and will be removed in v1.7.
+        The broadcast logic will be moved to the :class:`DDPPlugin` and :class:`DDPSpawnPlugin` classes.
+    """
+
     def __init__(self, rank=None, device=None):
+        rank_zero_deprecation(
+            "LightningDistributed is deprecated in v1.5 and will be removed in v1.7."
+            " Broadcast logic is implemented directly in the :class:`TrainingTypePlugin` implementations."
+        )
         self.rank = rank
         self.device = device
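A quick way to observe the new behaviour (a sketch, not part of the change set; in a single process the global rank is 0, so the rank-zero deprecation fires):

import warnings

from pytorch_lightning.distributed.dist import LightningDistributed

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    LightningDistributed()  # __init__ now calls rank_zero_deprecation(...)

assert any("deprecated in v1.5" in str(w.message) for w in caught)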

pytorch_lightning/plugins/training_type/ddp.py

@@ -31,9 +31,9 @@ from torch.nn.parallel.distributed import DistributedDataParallel
 import pytorch_lightning as pl
 from pytorch_lightning.core.optimizer import LightningOptimizer
-from pytorch_lightning.distributed import LightningDistributed
 from pytorch_lightning.overrides import LightningDistributedModule
 from pytorch_lightning.overrides.distributed import prepare_for_backward
+from pytorch_lightning.overrides.torch_distributed import broadcast_object_list
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
@@ -48,13 +48,9 @@ from pytorch_lightning.utilities import (
     rank_zero_deprecation,
     rank_zero_warn,
 )
-from pytorch_lightning.utilities.distributed import (
-    distributed_available,
-    init_ddp_connection,
-    rank_zero_only,
-    ReduceOp,
-    sync_ddp_if_available,
-)
+from pytorch_lightning.utilities.distributed import distributed_available
+from pytorch_lightning.utilities.distributed import group as _group
+from pytorch_lightning.utilities.distributed import init_ddp_connection, rank_zero_only, ReduceOp, sync_ddp_if_available
 from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
 from pytorch_lightning.utilities.seed import reset_seed
 from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -116,7 +112,6 @@ class DDPPlugin(ParallelPlugin):
                 " Notice that it will be overriden by the trainer setting."
             )
         self._sync_batchnorm = sync_batchnorm or False
-        self.dist = LightningDistributed()
         self.num_processes = len(self.parallel_devices) if self.parallel_devices is not None else 0
         self._ddp_kwargs = kwargs
         self._task_idx = None
@@ -269,10 +264,6 @@ class DDPPlugin(ParallelPlugin):
         # where to store ip_table
         init_ddp_connection(self.cluster_environment, self.torch_distributed_backend)

-        # set the ranks and devices
-        self.dist.rank = self.global_rank
-        self.dist.device = self.root_device
-
     def _check_can_spawn_children(self):
         if self.local_rank != 0:
             raise RuntimeError(
@@ -403,7 +394,11 @@ class DDPPlugin(ParallelPlugin):
             torch.distributed.barrier()

     def broadcast(self, obj: object, src: int = 0) -> object:
-        return self.dist.broadcast(obj)
+        obj = [obj]
+        if self.global_rank != src:
+            obj = [None]
+        broadcast_object_list(obj, src, group=_group.WORLD)
+        return obj[0]

     def pre_backward(self, closure_loss: torch.Tensor) -> None:
         """Run before precision plugin executes backward."""

pytorch_lightning/plugins/training_type/ddp_spawn.py

@@ -24,9 +24,9 @@ import torch.multiprocessing as mp
 from torch.nn.parallel.distributed import DistributedDataParallel

 import pytorch_lightning as pl
-from pytorch_lightning.distributed.dist import LightningDistributed
 from pytorch_lightning.overrides import LightningDistributedModule
 from pytorch_lightning.overrides.distributed import prepare_for_backward
+from pytorch_lightning.overrides.torch_distributed import broadcast_object_list
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
@@ -40,13 +40,9 @@ from pytorch_lightning.utilities import (
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.cloud_io import atomic_save
 from pytorch_lightning.utilities.cloud_io import load as pl_load
-from pytorch_lightning.utilities.distributed import (
-    distributed_available,
-    init_ddp_connection,
-    rank_zero_only,
-    ReduceOp,
-    sync_ddp_if_available,
-)
+from pytorch_lightning.utilities.distributed import distributed_available
+from pytorch_lightning.utilities.distributed import group as _group
+from pytorch_lightning.utilities.distributed import init_ddp_connection, rank_zero_only, ReduceOp, sync_ddp_if_available
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.seed import reset_seed
 from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -93,7 +89,6 @@ class DDPSpawnPlugin(ParallelPlugin):
             )
         self._sync_batchnorm = sync_batchnorm or False
         self._ddp_kwargs = kwargs
-        self.dist = LightningDistributed()
         self.num_processes = len(parallel_devices) if parallel_devices is not None else 0
         self.mp_queue = None
         self._ddp_comm_state = ddp_comm_state
@@ -193,10 +188,6 @@ class DDPSpawnPlugin(ParallelPlugin):
         # ... need to double check that it is the correct place
         # self.trainer.call_setup_hook(self.model)

-        # set the ranks and devices
-        self.dist.rank = self.global_rank
-        self.dist.device = self.root_device
-
         # move the model to the correct device
         self.model_to_device()
@@ -324,7 +315,11 @@ class DDPSpawnPlugin(ParallelPlugin):
     def broadcast(self, obj: object, src: int = 0) -> object:
         if not distributed_available():
             return obj
-        return self.dist.broadcast(obj)
+        obj = [obj]
+        if self.global_rank != src:
+            obj = [None]
+        broadcast_object_list(obj, src, group=_group.WORLD)
+        return obj[0]

     def model_to_device(self):
         if self.root_device.type == "cuda":
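The `DDPSpawnPlugin.broadcast` version above differs from the `DDPPlugin` one only in the `distributed_available()` guard, which returns the object untouched when no process group exists yet (for example before the spawned workers have initialized). A sketch of that guard, with illustrative names:

import torch.distributed as dist


def maybe_broadcast(obj, src: int = 0):
    if not (dist.is_available() and dist.is_initialized()):
        return obj  # nothing to synchronize with yet
    buffer = [obj] if dist.get_rank() == src else [None]
    dist.broadcast_object_list(buffer, src=src)
    return buffer[0]


print(maybe_broadcast("no process group yet"))  # prints the input unchanged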

pytorch_lightning/plugins/training_type/deepspeed.py

@@ -342,9 +342,6 @@ class DeepSpeedPlugin(DDPPlugin):
         self._init_deepspeed_distributed()

-        # set the ranks and devices
-        self.dist.rank = self.global_rank
-        self.dist.device = self.root_device
-
         if not self._config_initialized:
             self._format_config()
             self._config_initialized = True

setup.cfg

@@ -46,6 +46,7 @@ omit =
     pytorch_lightning/cluster_environments/*.py
     pytorch_lightning/utilities/distributed.py
     pytorch_lightning/tuner/auto_gpu_select.py
+    pytorch_lightning/distributed/dist.py

 [flake8]

tests/deprecated_api/test_remove_1-7.py

@@ -243,3 +243,10 @@ def test_v1_7_0_lightning_logger_base_close(tmpdir):
     ):
         logger = LoggerCollection([logger])
         logger.close()
+
+
+def test_v1_7_0_deprecate_lightning_distributed(tmpdir):
+    with pytest.deprecated_call(match="LightningDistributed is deprecated in v1.5 and will be removed in v1.7."):
+        from pytorch_lightning.distributed.dist import LightningDistributed
+
+        _ = LightningDistributed()