Deprecate LightningDistributed and keep logic in ddp/ddpSpawn directly (#9691)
* Deprecate LightningDistributed and keep logic in ddp/ddpSpawn directly
* Apply suggestions from code review
* Update pytorch_lightning/distributed/dist.py

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
This commit is contained in:
parent 444b21dc3d
commit ddf6967421
CHANGELOG.md

@@ -274,6 +274,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Deprecated passing `progress_bar_refresh_rate` to the `Trainer` constructor in favor of adding the `ProgressBar` callback with `refresh_rate` directly to the list of callbacks, or passing `enable_progress_bar=False` to disable the progress bar ([#9616](https://github.com/PyTorchLightning/pytorch-lightning/pull/9616))

+- Deprecated `LightningDistributed` and moved the broadcast logic to `DDPPlugin` and `DDPSpawnPlugin` directly ([#9691](https://github.com/PyTorchLightning/pytorch-lightning/pull/9691))

 ### Removed

 - Removed deprecated `metrics` ([#8586](https://github.com/PyTorchLightning/pytorch-lightning/pull/8586/))
@@ -402,6 +405,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Fixed `trainer.accumulate_grad_batches` to be an int on init. Default value for it is now `None` inside Trainer ([#9652](https://github.com/PyTorchLightning/pytorch-lightning/pull/9652))

+- Fixed `broadcast` in `DDPPlugin` and `DDPSpawnPlugin` to respect the `src` input ([#9691](https://github.com/PyTorchLightning/pytorch-lightning/pull/9691))

 ## [1.4.8] - 2021-09-22

 - Fixed error reporting in DDP process reconciliation when processes are launched by an external agent ([#9389](https://github.com/PyTorchLightning/pytorch-lightning/pull/9389))
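A rough sketch of the `src` fix recorded in the changelog hunk above: the plugins now wrap the object in a one-element list, blank it out on every rank except `src`, and let the collective overwrite the placeholder. This standalone example uses the public `torch.distributed` API rather than Lightning's backported `broadcast_object_list` helper, and assumes a process group has already been initialized:

```python
import torch.distributed as dist

def broadcast_from(obj, src: int = 0):
    # Only the source rank contributes the real payload; other ranks pass a placeholder.
    buffer = [obj] if dist.get_rank() == src else [None]
    # Collective call: every rank must reach this line.
    dist.broadcast_object_list(buffer, src, group=dist.group.WORLD)
    return buffer[0]
```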
pytorch_lightning/distributed/dist.py

@@ -14,11 +14,22 @@

 from typing import Any

 from pytorch_lightning.overrides.torch_distributed import broadcast_object_list
+from pytorch_lightning.utilities import rank_zero_deprecation
 from pytorch_lightning.utilities.distributed import group as _group


 class LightningDistributed:
+    """
+    .. deprecated:: v1.5
+        This class is deprecated in v1.5 and will be removed in v1.7.
+        The broadcast logic will be moved to the :class:`DDPPlugin` and :class:`DDPSpawnPlugin` classes.
+    """

     def __init__(self, rank=None, device=None):
+        rank_zero_deprecation(
+            "LightningDistributed is deprecated in v1.5 and will be removed in v1.7."
+            " Broadcast logic is implemented directly in the :class:`TrainingTypePlugin` implementations."
+        )
         self.rank = rank
         self.device = device
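For code that still constructs `LightningDistributed` directly, the replacement is the `broadcast` method on the training type plugin itself. A hedged before/after sketch — the `trainer.training_type_plugin` access path is an assumption about the surrounding Trainer API, not something shown in this diff:

```python
# Before (deprecated in v1.5, scheduled for removal in v1.7):
#   dist = LightningDistributed(rank=trainer.global_rank, device=trainer.root_device)
#   value = dist.broadcast(value)
#
# After: call the plugin's own broadcast, which also honours `src`:
#   value = trainer.training_type_plugin.broadcast(value, src=0)
```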
pytorch_lightning/plugins/training_type/ddp.py

@@ -31,9 +31,9 @@ from torch.nn.parallel.distributed import DistributedDataParallel

 import pytorch_lightning as pl
 from pytorch_lightning.core.optimizer import LightningOptimizer
-from pytorch_lightning.distributed import LightningDistributed
 from pytorch_lightning.overrides import LightningDistributedModule
 from pytorch_lightning.overrides.distributed import prepare_for_backward
+from pytorch_lightning.overrides.torch_distributed import broadcast_object_list
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
@@ -48,13 +48,9 @@ from pytorch_lightning.utilities import (

     rank_zero_deprecation,
     rank_zero_warn,
 )
-from pytorch_lightning.utilities.distributed import (
-    distributed_available,
-    init_ddp_connection,
-    rank_zero_only,
-    ReduceOp,
-    sync_ddp_if_available,
-)
+from pytorch_lightning.utilities.distributed import distributed_available
+from pytorch_lightning.utilities.distributed import group as _group
+from pytorch_lightning.utilities.distributed import init_ddp_connection, rank_zero_only, ReduceOp, sync_ddp_if_available
 from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
 from pytorch_lightning.utilities.seed import reset_seed
 from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -116,7 +112,6 @@ class DDPPlugin(ParallelPlugin):

             " Notice that it will be overridden by the trainer setting."
         )
         self._sync_batchnorm = sync_batchnorm or False
-        self.dist = LightningDistributed()
         self.num_processes = len(self.parallel_devices) if self.parallel_devices is not None else 0
         self._ddp_kwargs = kwargs
         self._task_idx = None
@@ -269,10 +264,6 @@ class DDPPlugin(ParallelPlugin):

         # where to store ip_table
         init_ddp_connection(self.cluster_environment, self.torch_distributed_backend)

-        # set the ranks and devices
-        self.dist.rank = self.global_rank
-        self.dist.device = self.root_device

     def _check_can_spawn_children(self):
         if self.local_rank != 0:
             raise RuntimeError(
@@ -403,7 +394,11 @@ class DDPPlugin(ParallelPlugin):

         torch.distributed.barrier()

     def broadcast(self, obj: object, src: int = 0) -> object:
-        return self.dist.broadcast(obj)
+        obj = [obj]
+        if self.global_rank != src:
+            obj = [None]
+        broadcast_object_list(obj, src, group=_group.WORLD)
+        return obj[0]

     def pre_backward(self, closure_loss: torch.Tensor) -> None:
         """Run before precision plugin executes backward."""
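Two details of the new `DDPPlugin.broadcast` above are worth noting: the payload is boxed in a single-element list because `broadcast_object_list` fills a list in place, and every rank must call the method since it is a collective operation. A hedged usage sketch (`plugin` is a stand-in for whichever `DDPPlugin` instance the process holds; it is not a name taken from this diff):

```python
# Every rank calls broadcast(); ranks other than `src` receive the value held on rank `src`,
# and whatever object they passed in locally is discarded.
result = plugin.broadcast(local_value, src=2)
```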
pytorch_lightning/plugins/training_type/ddp_spawn.py

@@ -24,9 +24,9 @@ import torch.multiprocessing as mp

 from torch.nn.parallel.distributed import DistributedDataParallel

 import pytorch_lightning as pl
-from pytorch_lightning.distributed.dist import LightningDistributed
 from pytorch_lightning.overrides import LightningDistributedModule
 from pytorch_lightning.overrides.distributed import prepare_for_backward
+from pytorch_lightning.overrides.torch_distributed import broadcast_object_list
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
@@ -40,13 +40,9 @@ from pytorch_lightning.utilities import (

 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.cloud_io import atomic_save
 from pytorch_lightning.utilities.cloud_io import load as pl_load
-from pytorch_lightning.utilities.distributed import (
-    distributed_available,
-    init_ddp_connection,
-    rank_zero_only,
-    ReduceOp,
-    sync_ddp_if_available,
-)
+from pytorch_lightning.utilities.distributed import distributed_available
+from pytorch_lightning.utilities.distributed import group as _group
+from pytorch_lightning.utilities.distributed import init_ddp_connection, rank_zero_only, ReduceOp, sync_ddp_if_available
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.seed import reset_seed
 from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -93,7 +89,6 @@ class DDPSpawnPlugin(ParallelPlugin):

         )
         self._sync_batchnorm = sync_batchnorm or False
         self._ddp_kwargs = kwargs
-        self.dist = LightningDistributed()
         self.num_processes = len(parallel_devices) if parallel_devices is not None else 0
         self.mp_queue = None
         self._ddp_comm_state = ddp_comm_state
@@ -193,10 +188,6 @@ class DDPSpawnPlugin(ParallelPlugin):

         # ... need to double check that it is the correct place
         # self.trainer.call_setup_hook(self.model)

-        # set the ranks and devices
-        self.dist.rank = self.global_rank
-        self.dist.device = self.root_device

         # move the model to the correct device
         self.model_to_device()
@@ -324,7 +315,11 @@ class DDPSpawnPlugin(ParallelPlugin):

     def broadcast(self, obj: object, src: int = 0) -> object:
         if not distributed_available():
             return obj
-        return self.dist.broadcast(obj)
+        obj = [obj]
+        if self.global_rank != src:
+            obj = [None]
+        broadcast_object_list(obj, src, group=_group.WORLD)
+        return obj[0]

     def model_to_device(self):
         if self.root_device.type == "cuda":
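Unlike `DDPPlugin`, the spawn variant keeps its early-return guard: if no process group is up (for example before the child processes have initialized, or in a single-process debug run), `broadcast` simply hands the object back. A minimal sketch of what such a guard checks, assuming Lightning's `distributed_available()` wraps the standard `torch.distributed` queries:

```python
import torch.distributed as dist

def distributed_available_sketch() -> bool:
    # True only when torch.distributed is compiled in and a process group has been initialized.
    return dist.is_available() and dist.is_initialized()
```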
pytorch_lightning/plugins/training_type/deepspeed.py

@@ -342,9 +342,6 @@ class DeepSpeedPlugin(DDPPlugin):

         self._init_deepspeed_distributed()

-        # set the ranks and devices
-        self.dist.rank = self.global_rank
-        self.dist.device = self.root_device
         if not self._config_initialized:
             self._format_config()
             self._config_initialized = True
setup.cfg

@@ -46,6 +46,7 @@ omit =

     pytorch_lightning/cluster_environments/*.py
     pytorch_lightning/utilities/distributed.py
     pytorch_lightning/tuner/auto_gpu_select.py
+    pytorch_lightning/distributed/dist.py


 [flake8]
tests/deprecated_api/test_remove_1-7.py

@@ -243,3 +243,10 @@ def test_v1_7_0_lightning_logger_base_close(tmpdir):

     ):
         logger = LoggerCollection([logger])
         logger.close()


+def test_v1_7_0_deprecate_lightning_distributed(tmpdir):
+    with pytest.deprecated_call(match="LightningDistributed is deprecated in v1.5 and will be removed in v1.7."):
+        from pytorch_lightning.distributed.dist import LightningDistributed
+
+        _ = LightningDistributed()