diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py
index c3c05f49b4..574a251e25 100644
--- a/pytorch_lightning/plugins/training_type/ddp.py
+++ b/pytorch_lightning/plugins/training_type/ddp.py
@@ -37,6 +37,7 @@ from pytorch_lightning.utilities import (
     _HYDRA_AVAILABLE,
     _TORCH_GREATER_EQUAL_1_7,
     _TORCH_GREATER_EQUAL_1_8,
+    _TORCH_GREATER_EQUAL_1_9,
     rank_zero_deprecation,
     rank_zero_warn,
 )
@@ -289,9 +290,11 @@ class DDPPlugin(ParallelPlugin):
             self._ddp_kwargs["find_unused_parameters"] = True
 
     def _register_ddp_hooks(self) -> None:
-        # currently, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode
-        # https://github.com/pytorch/pytorch/blob/v1.8.0/torch/nn/parallel/distributed.py#L1080-L1084
-        if _TORCH_GREATER_EQUAL_1_8 and self.on_gpu and self._is_single_process_single_device:
+        # In 1.8, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode
+        # Since 1.9, DDP communication hooks can work on all backends.
+        if _TORCH_GREATER_EQUAL_1_9 or (
+            _TORCH_GREATER_EQUAL_1_8 and self.on_gpu and self._is_single_process_single_device
+        ):
            register_ddp_comm_hook(
                 model=self._model,
                 ddp_comm_state=self._ddp_comm_state,
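
For context, a minimal sketch (not part of this patch) of what the gated register_ddp_comm_hook call builds on in plain PyTorch. The helper name attach_fp16_compress_hook is hypothetical; register_comm_hook and fp16_compress_hook are the stock PyTorch APIs available since 1.8.

from torch.nn.parallel import DistributedDataParallel
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks


def attach_fp16_compress_hook(ddp_model: DistributedDataParallel) -> None:
    # Compress gradients to FP16 before all-reduce. On torch 1.8 this is only
    # supported with the NCCL backend in single-process single-device mode;
    # from torch 1.9 onward communication hooks work with any backend, which is
    # what the _TORCH_GREATER_EQUAL_1_9 branch in the diff accounts for.
    ddp_model.register_comm_hook(state=None, hook=default_hooks.fp16_compress_hook)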