diff --git a/CHANGELOG.md b/CHANGELOG.md index ed9b2d1586..d9cab7a63e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -81,6 +81,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for `precision=64`, enabling training with double precision ([#6595](https://github.com/PyTorchLightning/pytorch-lightning/pull/6595)) +- Added support for DDP communication hooks ([#6736](https://github.com/PyTorchLightning/pytorch-lightning/issues/6736)) - Added `artifact_location` argument to `MLFlowLogger` which will be passed to the `MlflowClient.create_experiment` call ([#6677](https://github.com/PyTorchLightning/pytorch-lightning/pull/6677)) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index b8437b0d41..5f411b65ae 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -29,7 +29,12 @@ from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.overrides.distributed import prepare_for_backward from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin -from pytorch_lightning.utilities import _HYDRA_AVAILABLE, _TORCH_GREATER_EQUAL_1_7, rank_zero_warn +from pytorch_lightning.utilities import ( + _HYDRA_AVAILABLE, + _TORCH_GREATER_EQUAL_1_7, + _TORCH_GREATER_EQUAL_1_8, + rank_zero_warn, +) from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp, sync_ddp_if_available from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.seed import seed_everything @@ -37,6 +42,8 @@ from pytorch_lightning.utilities.seed import seed_everything if _HYDRA_AVAILABLE: from hydra.core.hydra_config import HydraConfig from hydra.utils import get_original_cwd, to_absolute_path +if _TORCH_GREATER_EQUAL_1_8: + from pytorch_lightning.utilities.distributed import register_ddp_comm_hook log = logging.getLogger(__name__) @@ -58,6 +65,9 @@ class DDPPlugin(ParallelPlugin): num_nodes: int = 1, cluster_environment: ClusterEnvironment = None, sync_batchnorm: bool = False, + ddp_comm_state: Optional[object] = None, + ddp_comm_hook: Optional[callable] = None, + ddp_comm_wrapper: Optional[callable] = None, **kwargs: Union[Any, Dict[str, Any]], ) -> None: super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment) @@ -70,6 +80,9 @@ class DDPPlugin(ParallelPlugin): self.task_idx = None self.node_rank = 0 self.num_processes = len(parallel_devices) if parallel_devices is not None else parallel_devices + self._ddp_comm_state = ddp_comm_state + self._ddp_comm_hook = ddp_comm_hook + self._ddp_comm_wrapper = ddp_comm_wrapper @property def root_device(self): @@ -80,6 +93,10 @@ class DDPPlugin(ParallelPlugin): distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank) return distributed_sampler_kwargs + @property + def _is_single_process_single_device(self) -> bool: + return True + def setup_environment(self): # start the other scripts if not self.cluster_environment.creates_children() and os.environ.get("PL_IN_DDP_SUBPROCESS", "0") != "1": @@ -218,6 +235,17 @@ class DDPPlugin(ParallelPlugin): ) self._ddp_kwargs["find_unused_parameters"] = True + def _register_ddp_hooks(self) -> None: + # currently, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode + # 
https://github.com/pytorch/pytorch/blob/v1.8.0/torch/nn/parallel/distributed.py#L1080-L1084 + if (_TORCH_GREATER_EQUAL_1_8 and self.on_gpu and self._is_single_process_single_device): + register_ddp_comm_hook( + model=self._model, + ddp_comm_state=self._ddp_comm_state, + ddp_comm_hook=self._ddp_comm_hook, + ddp_comm_wrapper=self._ddp_comm_wrapper, + ) + def configure_ddp(self): self.pre_configure_ddp() self._model = DistributedDataParallel( @@ -225,6 +253,7 @@ class DDPPlugin(ParallelPlugin): device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs, ) + self._register_ddp_hooks() def determine_ddp_device_ids(self): if self.root_device.type == "cpu": diff --git a/pytorch_lightning/plugins/training_type/ddp2.py b/pytorch_lightning/plugins/training_type/ddp2.py index a94bb5459b..f19fb05a16 100644 --- a/pytorch_lightning/plugins/training_type/ddp2.py +++ b/pytorch_lightning/plugins/training_type/ddp2.py @@ -59,6 +59,10 @@ class DDP2Plugin(DDPPlugin): distributed_sampler_kwargs = dict(num_replicas=self.num_nodes, rank=self.global_rank) return distributed_sampler_kwargs + @property + def _is_single_process_single_device(self) -> bool: + return False + def set_world_ranks(self): self.local_rank = self.task_idx self.node_rank = self.cluster_environment.node_rank() diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 126afc9be6..e902872934 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -28,12 +28,15 @@ from pytorch_lightning.overrides.distributed import prepare_for_backward from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.trainer.states import TrainerState -from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7 +from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7, _TORCH_GREATER_EQUAL_1_8 from pytorch_lightning.utilities.cloud_io import atomic_save from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, ReduceOp, sync_ddp_if_available from pytorch_lightning.utilities.seed import seed_everything +if _TORCH_GREATER_EQUAL_1_8: + from pytorch_lightning.utilities.distributed import register_ddp_comm_hook + log = logging.getLogger(__name__) @@ -47,6 +50,9 @@ class DDPSpawnPlugin(ParallelPlugin): num_nodes: int = 1, cluster_environment: ClusterEnvironment = None, sync_batchnorm: bool = False, + ddp_comm_state: Optional[object] = None, + ddp_comm_hook: Optional[callable] = None, + ddp_comm_wrapper: Optional[callable] = None, **kwargs: Union[Any, Dict[str, Any]], ): super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment) @@ -54,9 +60,12 @@ class DDPSpawnPlugin(ParallelPlugin): self.sync_batchnorm = sync_batchnorm self._ddp_kwargs = kwargs self.dist = LightningDistributed() - self.num_processes = len(parallel_devices) + self.num_processes = len(parallel_devices) if parallel_devices is not None else 0 self.node_rank = 0 self.mp_queue = None + self._ddp_comm_state = ddp_comm_state + self._ddp_comm_hook = ddp_comm_hook + self._ddp_comm_wrapper = ddp_comm_wrapper def __getstate__(self): """ Makes this plugin pickleable without destroying the queue in the current process. 
""" @@ -76,9 +85,12 @@ class DDPSpawnPlugin(ParallelPlugin): distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank) return distributed_sampler_kwargs + @property + def _is_single_process_single_device(self): + return True + def setup(self, model): os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) - # pass in a state q smp = mp.get_context("spawn") self.mp_queue = smp.SimpleQueue() @@ -181,6 +193,17 @@ class DDPSpawnPlugin(ParallelPlugin): ) self._ddp_kwargs["find_unused_parameters"] = True + def _register_ddp_hooks(self) -> None: + # currently, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode + # https://github.com/pytorch/pytorch/blob/v1.8.0/torch/nn/parallel/distributed.py#L1080-L1084 + if (_TORCH_GREATER_EQUAL_1_8 and self.on_gpu and self._is_single_process_single_device): + register_ddp_comm_hook( + model=self._model, + ddp_comm_state=self._ddp_comm_state, + ddp_comm_hook=self._ddp_comm_hook, + ddp_comm_wrapper=self._ddp_comm_wrapper, + ) + def configure_ddp(self): self.pre_configure_ddp() self._model = DistributedDataParallel( @@ -188,6 +211,7 @@ class DDPSpawnPlugin(ParallelPlugin): device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs, ) + self._register_ddp_hooks() def init_ddp_connection(self, global_rank: int, world_size: int) -> None: # TODO: this code is duplicated in DDP and DDPSpawn, make this a function diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index a6d549e382..398e3782be 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -47,6 +47,8 @@ from pytorch_lightning.utilities.imports import ( # noqa: F401 _RPC_AVAILABLE, _TORCH_GREATER_EQUAL_1_6, _TORCH_GREATER_EQUAL_1_7, + _TORCH_GREATER_EQUAL_1_8, + _TORCH_GREATER_EQUAL_1_9, _TORCH_LOWER_EQUAL_1_4, _TORCH_QUANTIZE_AVAILABLE, _TORCHTEXT_AVAILABLE, diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index b4793889f1..018d83a93a 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -17,9 +17,15 @@ import os import warnings from functools import partial, wraps from typing import Any, Optional, Union +from pytorch_lightning.utilities.imports import ( + _TORCH_GREATER_EQUAL_1_8, + _TORCH_GREATER_EQUAL_1_9, +) import torch +from torch.nn.parallel.distributed import DistributedDataParallel + log = logging.getLogger(__name__) if torch.distributed.is_available(): @@ -208,3 +214,107 @@ def all_gather_ddp_if_available( with torch.no_grad(): return AllGatherGrad.apply(tensor, group) return tensor + + +def register_ddp_comm_hook( + model: DistributedDataParallel, + ddp_comm_state: Optional[object] = None, + ddp_comm_hook: Optional[callable] = None, + ddp_comm_wrapper: Optional[callable] = None, +) -> None: + """ + Function to register communication hook for DDP model + https://pytorch.org/docs/master/ddp_comm_hooks.html + + Args: + model: + DDP model + ddp_comm_state: + state is passed to the hook and can be used to maintain + and update any state information that users would like to + maintain as part of the training process. Examples: error + feedback in gradient compression, peers to communicate with + next in GossipGrad etc. + ddp_comm_hook: + hook(state: object, bucket: dist._GradBucket) -> torch.futures.Future + + This callable function is called once the bucket is ready. 
The
+            hook can perform whatever processing is needed and return
+            a Future indicating completion of any async work (ex: allreduce).
+            If the hook doesn't perform any communication, it can also
+            just return a completed Future. The Future should hold the
+            new value of grad bucket's tensors. Once a bucket is ready,
+            c10d reducer would call this hook and use the tensors returned
+            by the Future and copy grads to individual parameters.
+        ddp_comm_wrapper:
+            communication hook wrapper to support a communication hook such
+            as FP16 compression as wrapper, which can be combined with
+            ddp_comm_hook
+
+    .. warning::
+        DDP communication hooks require PyTorch 1.8.0 or higher.
+
+    .. warning::
+        DDP communication wrappers require PyTorch 1.9.0 or higher.
+
+    Example:
+
+        from torch.distributed.algorithms.ddp_comm_hooks import (
+            default_hooks as default,
+            powerSGD_hook as powerSGD,
+        )
+
+        # fp16_compress_hook to compress gradients
+        register_ddp_comm_hook(
+            model=ddp_model,
+            ddp_comm_hook=default.fp16_compress_hook,
+        )
+
+        # powerSGD_hook
+        register_ddp_comm_hook(
+            model=ddp_model,
+            ddp_comm_state=powerSGD.PowerSGDState(
+                process_group=None,
+                matrix_approximation_rank=1,
+                start_powerSGD_iter=5000,
+            ),
+            ddp_comm_hook=powerSGD.powerSGD_hook,
+        )
+
+        # fp16_compress_wrapper combined with another communication hook
+        register_ddp_comm_hook(
+            model=ddp_model,
+            ddp_comm_state=powerSGD.PowerSGDState(
+                process_group=None,
+                matrix_approximation_rank=1,
+                start_powerSGD_iter=5000,
+            ),
+            ddp_comm_hook=powerSGD.powerSGD_hook,
+            ddp_comm_wrapper=default.fp16_compress_wrapper,
+        )
+    """
+    if not _TORCH_GREATER_EQUAL_1_8:
+        rank_zero_warn(
+            "Not registering DDP comm hook. "
+            "To use communication hooks, please use pytorch>=1.8.0."
+        )
+        return
+    if ddp_comm_hook is None:
+        return
+    if ddp_comm_wrapper is not None:
+        if not _TORCH_GREATER_EQUAL_1_9:
+            rank_zero_warn(
+                "Not applying DDP comm wrapper. "
+                "To use communication wrapper, please use pytorch>=1.9.0."
+            )
+        else:
+            rank_zero_info(
+                f"DDP comm wrapper is provided, apply {ddp_comm_wrapper.__qualname__}({ddp_comm_hook.__qualname__})."
+            )
+            ddp_comm_hook = ddp_comm_wrapper(ddp_comm_hook)
+
+    rank_zero_debug(f"Registering DDP comm hook: {ddp_comm_hook.__qualname__}.")
+    model.register_comm_hook(
+        state=ddp_comm_state,
+        hook=ddp_comm_hook,
+    )
diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py
index 001b9a67c5..621e0d17d2 100644
--- a/pytorch_lightning/utilities/imports.py
+++ b/pytorch_lightning/utilities/imports.py
@@ -69,6 +69,7 @@ _TORCH_LOWER_EQUAL_1_4 = _compare_version("torch", operator.le, "1.5.0")
 _TORCH_GREATER_EQUAL_1_6 = _compare_version("torch", operator.ge, "1.6.0")
 _TORCH_GREATER_EQUAL_1_7 = _compare_version("torch", operator.ge, "1.7.0")
 _TORCH_GREATER_EQUAL_1_8 = _compare_version("torch", operator.ge, "1.8.0")
+_TORCH_GREATER_EQUAL_1_9 = _compare_version("torch", operator.ge, "1.9.0")
 
 _APEX_AVAILABLE = _module_available("apex.amp")
 _BOLTS_AVAILABLE = _module_available('pl_bolts')
diff --git a/tests/plugins/test_custom_plugin.py b/tests/plugins/test_custom_plugin.py
index 872b49ef48..6b04ac7b38 100644
--- a/tests/plugins/test_custom_plugin.py
+++ b/tests/plugins/test_custom_plugin.py
@@ -18,7 +18,6 @@ from tests.helpers.runif import RunIf
 
 
 class CustomParallelPlugin(DDPPlugin):
-
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
         # Set to None so it will be overwritten by the accelerator connector.
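For reference, a minimal sketch (not part of the patch) of calling the new register_ddp_comm_hook utility directly on a hand-wrapped DistributedDataParallel model; it assumes torch>=1.8, the NCCL backend, a job launched with torchrun (which sets the usual rank environment variables), and a toy nn.Linear as a stand-in for a real model:

import os

import torch
import torch.distributed as dist
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default
from torch.nn.parallel import DistributedDataParallel

from pytorch_lightning.utilities.distributed import register_ddp_comm_hook

# rank/device setup comes from the launcher (torchrun sets LOCAL_RANK, RANK, etc.)
local_rank = int(os.environ["LOCAL_RANK"])
dist.init_process_group(backend="nccl")
torch.cuda.set_device(local_rank)

# toy stand-in model wrapped in DDP, one process per GPU (SPSD mode)
model = torch.nn.Linear(32, 2).cuda()
ddp_model = DistributedDataParallel(model, device_ids=[local_rank])

# gradients are cast to fp16 for the all-reduce and cast back afterwards
register_ddp_comm_hook(model=ddp_model, ddp_comm_hook=default.fp16_compress_hook)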
diff --git a/tests/plugins/test_ddp_plugin_with_comm_hook.py b/tests/plugins/test_ddp_plugin_with_comm_hook.py
new file mode 100644
index 0000000000..25845de1ae
--- /dev/null
+++ b/tests/plugins/test_ddp_plugin_with_comm_hook.py
@@ -0,0 +1,134 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+from pytorch_lightning import Trainer
+from pytorch_lightning.plugins import DDPPlugin, DDPSpawnPlugin
+from pytorch_lightning.trainer.states import TrainerState
+from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8
+from tests.helpers import BoringModel
+from tests.helpers.runif import RunIf
+
+if torch.distributed.is_available() and _TORCH_GREATER_EQUAL_1_8:
+    from torch.distributed.algorithms.ddp_comm_hooks import (
+        default_hooks as default,
+        powerSGD_hook as powerSGD,
+    )
+
+
+@RunIf(skip_windows=True, min_torch="1.8.0", min_gpus=2, special=True)
+def test_ddp_fp16_compress_comm_hook(tmpdir):
+    """Test for DDP FP16 compress hook."""
+    model = BoringModel()
+    training_type_plugin = DDPPlugin(
+        ddp_comm_hook=default.fp16_compress_hook,
+        sync_batchnorm=True,
+    )
+    trainer = Trainer(
+        max_epochs=1,
+        gpus=2,
+        plugins=[training_type_plugin],
+        default_root_dir=tmpdir,
+        sync_batchnorm=True,
+        fast_dev_run=True,
+    )
+    trainer.fit(model)
+    trainer_comm_hook = (
+        trainer.accelerator.training_type_plugin._model.get_ddp_logging_data().comm_hook
+    )
+    expected_comm_hook = default.fp16_compress_hook.__qualname__
+    assert trainer_comm_hook == expected_comm_hook
+    assert (
+        trainer.state == TrainerState.FINISHED
+    ), f"Training failed with {trainer.state}"
+
+
+@RunIf(skip_windows=True, min_torch="1.8.0", min_gpus=2, special=True)
+def test_ddp_sgd_comm_hook(tmpdir):
+    """Test for DDP PowerSGD hook."""
+    model = BoringModel()
+    training_type_plugin = DDPPlugin(
+        ddp_comm_state=powerSGD.PowerSGDState(process_group=None),
+        ddp_comm_hook=powerSGD.powerSGD_hook,
+        sync_batchnorm=True,
+    )
+    trainer = Trainer(
+        max_epochs=1,
+        gpus=2,
+        plugins=[training_type_plugin],
+        default_root_dir=tmpdir,
+        sync_batchnorm=True,
+        fast_dev_run=True,
+    )
+    trainer.fit(model)
+    trainer_comm_hook = (
+        trainer.accelerator.training_type_plugin._model.get_ddp_logging_data().comm_hook
+    )
+    expected_comm_hook = powerSGD.powerSGD_hook.__qualname__
+    assert trainer_comm_hook == expected_comm_hook
+    assert (
+        trainer.state == TrainerState.FINISHED
+    ), f"Training failed with {trainer.state}"
+
+
+@RunIf(skip_windows=True, min_torch="1.9.0", min_gpus=2, special=True)
+def test_ddp_fp16_compress_wrap_sgd_comm_hook(tmpdir):
+    """Test for DDP FP16 compress wrapper for PowerSGD hook."""
+    model = BoringModel()
+    training_type_plugin = DDPPlugin(
+        ddp_comm_state=powerSGD.PowerSGDState(process_group=None),
+        ddp_comm_hook=powerSGD.powerSGD_hook,
+        ddp_comm_wrapper=default.fp16_compress_wrapper,
+        sync_batchnorm=True,
+    )
+    trainer = Trainer(
+        max_epochs=1,
+        gpus=2,
+        plugins=[training_type_plugin],
+        default_root_dir=tmpdir,
+        sync_batchnorm=True,
+        fast_dev_run=True,
+    )
+    trainer.fit(model)
+    trainer_comm_hook = (
+        trainer.accelerator.training_type_plugin._model.get_ddp_logging_data().comm_hook
+    )
+    expected_comm_hook = default.fp16_compress_wrapper(
+        powerSGD.powerSGD_hook
+    ).__qualname__
+    assert trainer_comm_hook == expected_comm_hook
+    assert (
+        trainer.state == TrainerState.FINISHED
+    ), f"Training failed with {trainer.state}"
+
+
+@RunIf(skip_windows=True, min_torch="1.8.0", min_gpus=2, special=True)
+def test_ddp_spawn_fp16_compress_comm_hook(tmpdir):
+    """Test for DDP Spawn FP16 compress hook."""
+    model = BoringModel()
+    training_type_plugin = DDPSpawnPlugin(
+        ddp_comm_hook=default.fp16_compress_hook,
+        sync_batchnorm=True,
+    )
+    trainer = Trainer(
+        max_epochs=1,
+        gpus=2,
+        plugins=[training_type_plugin],
+        default_root_dir=tmpdir,
+        sync_batchnorm=True,
+        fast_dev_run=True,
+    )
+    trainer.fit(model)
+    assert (
+        trainer.state == TrainerState.FINISHED
+    ), f"Training failed with {trainer.state}"
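End-to-end, the new plugin arguments are used from the Trainer exactly as in the tests above; a minimal usage sketch (assumes 2 GPUs, torch>=1.8 for the hooks and torch>=1.9 for the fp16 wrapper; MyLightningModule is a placeholder for any LightningModule):

from torch.distributed.algorithms.ddp_comm_hooks import (
    default_hooks as default,
    powerSGD_hook as powerSGD,
)

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin

# PowerSGD gradient compression with the optional fp16 wrapper on top
training_type_plugin = DDPPlugin(
    ddp_comm_state=powerSGD.PowerSGDState(process_group=None),
    ddp_comm_hook=powerSGD.powerSGD_hook,
    ddp_comm_wrapper=default.fp16_compress_wrapper,  # requires torch>=1.9
)

trainer = Trainer(gpus=2, plugins=[training_type_plugin])
trainer.fit(MyLightningModule())  # placeholder LightningModule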