Deprecate LightningDistributed and keep logic in ddp/ddpSpawn directly (#9691)
* Deprecate LightningDistributed and keep logic in ddp/ddpSpawn directly
* Apply suggestions from code review
* Update pytorch_lightning/distributed/dist.py

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
parent 444b21dc3d
commit ddf6967421
@@ -274,6 +274,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Deprecated passing `progress_bar_refresh_rate` to the `Trainer` constructor in favor of adding the `ProgressBar` callback with `refresh_rate` directly to the list of callbacks, or passing `enable_progress_bar=False` to disable the progress bar ([#9616](https://github.com/PyTorchLightning/pytorch-lightning/pull/9616))

+- Deprecated `LightningDistributed` and moved the broadcast logic to `DDPPlugin` and `DDPSpawnPlugin` directly ([#9691](https://github.com/PyTorchLightning/pytorch-lightning/pull/9691))
+
 ### Removed

 - Removed deprecated `metrics` ([#8586](https://github.com/PyTorchLightning/pytorch-lightning/pull/8586/))

@@ -402,6 +405,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Fixed `trainer.accumulate_grad_batches` to be an int on init. Default value for it is now `None` inside Trainer ([#9652](https://github.com/PyTorchLightning/pytorch-lightning/pull/9652))

+- Fixed `broadcast` in `DDPPlugin` and `DDPSpawnPlugin` to respect the `src` input ([#9691](https://github.com/PyTorchLightning/pytorch-lightning/pull/9691))
+
 ## [1.4.8] - 2021-09-22

 - Fixed error reporting in DDP process reconciliation when processes are launched by an external agent ([#9389](https://github.com/PyTorchLightning/pytorch-lightning/pull/9389))
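Taken together, the two new changelog entries mean that user code which previously reached for `LightningDistributed` should broadcast through the active training-type plugin instead. A minimal migration sketch, assuming a `LightningModule` hook where `self.trainer` is available and the v1.5-era `trainer.training_type_plugin` accessor; everything not shown in the diff below is an assumption for illustration:

```python
import pytorch_lightning as pl


class BroadcastingModule(pl.LightningModule):  # hypothetical module, for illustration only
    def on_fit_start(self) -> None:
        # Before (now deprecated): LightningDistributed(rank, device).broadcast(payload)
        payload = {"note": "hello from the src rank"} if self.trainer.is_global_zero else None
        # After: the plugin broadcasts from `src` and returns rank 0's object on every rank.
        payload = self.trainer.training_type_plugin.broadcast(payload, src=0)
```

On ranks other than `src` the argument is discarded and replaced by the broadcast result, so passing `None` there is fine.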
@@ -14,11 +14,22 @@
 from typing import Any

 from pytorch_lightning.overrides.torch_distributed import broadcast_object_list
+from pytorch_lightning.utilities import rank_zero_deprecation
 from pytorch_lightning.utilities.distributed import group as _group


 class LightningDistributed:
+    """
+    .. deprecated:: v1.5
+        This class is deprecated in v1.5 and will be removed in v1.7.
+        The broadcast logic will be moved to the :class:`DDPPlugin` and :class:`DDPSpawnPlugin` classes.
+    """
+
     def __init__(self, rank=None, device=None):
+        rank_zero_deprecation(
+            "LightningDistributed is deprecated in v1.5 and will be removed in v1.7."
+            " Broadcast logic is implemented directly in the :class:`TrainingTypePlugin` implementations."
+        )
         self.rank = rank
         self.device = device
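The shim above keeps `pytorch_lightning/distributed/dist.py` importable but warns when the class is constructed. A rough single-process sketch of how that surfaces to callers, assuming `rank_zero_deprecation` goes through Python's standard `warnings` machinery so the message can be captured the usual way (the new test at the end of this diff checks the same thing with `pytest.deprecated_call`):

```python
import warnings

from pytorch_lightning.distributed.dist import LightningDistributed

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # make sure deprecation warnings are not filtered out
    LightningDistributed()  # the warning fires here, not at import time

assert any("deprecated in v1.5" in str(w.message) for w in caught)
```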
@@ -31,9 +31,9 @@ from torch.nn.parallel.distributed import DistributedDataParallel

 import pytorch_lightning as pl
 from pytorch_lightning.core.optimizer import LightningOptimizer
-from pytorch_lightning.distributed import LightningDistributed
 from pytorch_lightning.overrides import LightningDistributedModule
 from pytorch_lightning.overrides.distributed import prepare_for_backward
+from pytorch_lightning.overrides.torch_distributed import broadcast_object_list
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
@@ -48,13 +48,9 @@ from pytorch_lightning.utilities import (
     rank_zero_deprecation,
     rank_zero_warn,
 )
-from pytorch_lightning.utilities.distributed import (
-    distributed_available,
-    init_ddp_connection,
-    rank_zero_only,
-    ReduceOp,
-    sync_ddp_if_available,
-)
+from pytorch_lightning.utilities.distributed import distributed_available
+from pytorch_lightning.utilities.distributed import group as _group
+from pytorch_lightning.utilities.distributed import init_ddp_connection, rank_zero_only, ReduceOp, sync_ddp_if_available
 from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
 from pytorch_lightning.utilities.seed import reset_seed
 from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -116,7 +112,6 @@ class DDPPlugin(ParallelPlugin):
                 " Notice that it will be overriden by the trainer setting."
             )
         self._sync_batchnorm = sync_batchnorm or False
-        self.dist = LightningDistributed()
         self.num_processes = len(self.parallel_devices) if self.parallel_devices is not None else 0
         self._ddp_kwargs = kwargs
         self._task_idx = None
@@ -269,10 +264,6 @@ class DDPPlugin(ParallelPlugin):
         # where to store ip_table
         init_ddp_connection(self.cluster_environment, self.torch_distributed_backend)

-        # set the ranks and devices
-        self.dist.rank = self.global_rank
-        self.dist.device = self.root_device
-
     def _check_can_spawn_children(self):
         if self.local_rank != 0:
             raise RuntimeError(
@@ -403,7 +394,11 @@ class DDPPlugin(ParallelPlugin):
             torch.distributed.barrier()

     def broadcast(self, obj: object, src: int = 0) -> object:
-        return self.dist.broadcast(obj)
+        obj = [obj]
+        if self.global_rank != src:
+            obj = [None]
+        broadcast_object_list(obj, src, group=_group.WORLD)
+        return obj[0]

     def pre_backward(self, closure_loss: torch.Tensor) -> None:
         """Run before precision plugin executes backward."""
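The rewritten `DDPPlugin.broadcast` above is the standard object-broadcast pattern: wrap the payload in a one-element list, blank it on every rank except `src`, broadcast the list in place, and unwrap. A self-contained single-process sketch of that pattern using `torch.distributed.broadcast_object_list` directly; the plugin itself goes through Lightning's `pytorch_lightning.overrides.torch_distributed.broadcast_object_list` wrapper and the `WORLD` group, and the address/port values here are arbitrary:

```python
import os

import torch.distributed as dist

# One-process gloo group, just enough to exercise the call pattern.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

src = 0
payload = {"best_model_path": "example.ckpt"}  # hypothetical picklable object
buffer = [payload] if dist.get_rank() == src else [None]  # non-src ranks contribute nothing
dist.broadcast_object_list(buffer, src=src)  # fills buffer[0] in place on every rank
assert buffer[0] == payload

dist.destroy_process_group()
```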
@@ -24,9 +24,9 @@ import torch.multiprocessing as mp
 from torch.nn.parallel.distributed import DistributedDataParallel

 import pytorch_lightning as pl
-from pytorch_lightning.distributed.dist import LightningDistributed
 from pytorch_lightning.overrides import LightningDistributedModule
 from pytorch_lightning.overrides.distributed import prepare_for_backward
+from pytorch_lightning.overrides.torch_distributed import broadcast_object_list
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
@@ -40,13 +40,9 @@ from pytorch_lightning.utilities import (
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.cloud_io import atomic_save
 from pytorch_lightning.utilities.cloud_io import load as pl_load
-from pytorch_lightning.utilities.distributed import (
-    distributed_available,
-    init_ddp_connection,
-    rank_zero_only,
-    ReduceOp,
-    sync_ddp_if_available,
-)
+from pytorch_lightning.utilities.distributed import distributed_available
+from pytorch_lightning.utilities.distributed import group as _group
+from pytorch_lightning.utilities.distributed import init_ddp_connection, rank_zero_only, ReduceOp, sync_ddp_if_available
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.seed import reset_seed
 from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -93,7 +89,6 @@ class DDPSpawnPlugin(ParallelPlugin):
         )
         self._sync_batchnorm = sync_batchnorm or False
         self._ddp_kwargs = kwargs
-        self.dist = LightningDistributed()
         self.num_processes = len(parallel_devices) if parallel_devices is not None else 0
         self.mp_queue = None
         self._ddp_comm_state = ddp_comm_state
@@ -193,10 +188,6 @@ class DDPSpawnPlugin(ParallelPlugin):
         # ... need to double check that it is the correct place
         # self.trainer.call_setup_hook(self.model)

-        # set the ranks and devices
-        self.dist.rank = self.global_rank
-        self.dist.device = self.root_device
-
         # move the model to the correct device
         self.model_to_device()

@@ -324,7 +315,11 @@ class DDPSpawnPlugin(ParallelPlugin):
     def broadcast(self, obj: object, src: int = 0) -> object:
         if not distributed_available():
             return obj
-        return self.dist.broadcast(obj)
+        obj = [obj]
+        if self.global_rank != src:
+            obj = [None]
+        broadcast_object_list(obj, src, group=_group.WORLD)
+        return obj[0]

     def model_to_device(self):
         if self.root_device.type == "cuda":
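`DDPSpawnPlugin.broadcast` above differs from the `DDPPlugin` version only in the `distributed_available()` guard, which makes the call a no-op before the spawned workers have initialized a process group (or in a single-process run). A hedged sketch of that guard pattern, written against `torch.distributed` rather than Lightning's helpers:

```python
from typing import Any

import torch.distributed as dist


def broadcast_or_passthrough(obj: Any, src: int = 0) -> Any:
    # No initialized process group yet (e.g. before mp.spawn has launched the
    # workers): broadcasting would fail, so hand the object back unchanged.
    if not (dist.is_available() and dist.is_initialized()):
        return obj
    buffer = [obj] if dist.get_rank() == src else [None]
    dist.broadcast_object_list(buffer, src=src)
    return buffer[0]
```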
@@ -342,9 +342,6 @@ class DeepSpeedPlugin(DDPPlugin):

         self._init_deepspeed_distributed()

-        # set the ranks and devices
-        self.dist.rank = self.global_rank
-        self.dist.device = self.root_device
         if not self._config_initialized:
             self._format_config()
             self._config_initialized = True
@@ -46,6 +46,7 @@ omit =
     pytorch_lightning/cluster_environments/*.py
     pytorch_lightning/utilities/distributed.py
     pytorch_lightning/tuner/auto_gpu_select.py
+    pytorch_lightning/distributed/dist.py


 [flake8]
@@ -243,3 +243,10 @@ def test_v1_7_0_lightning_logger_base_close(tmpdir):
     ):
         logger = LoggerCollection([logger])
         logger.close()
+
+
+def test_v1_7_0_deprecate_lightning_distributed(tmpdir):
+    with pytest.deprecated_call(match="LightningDistributed is deprecated in v1.5 and will be removed in v1.7."):
+        from pytorch_lightning.distributed.dist import LightningDistributed
+
+        _ = LightningDistributed()