diff --git a/CHANGELOG.md b/CHANGELOG.md index 3b5e1974ab..b7178a4b01 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -106,6 +106,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed parsing of multiple training dataloaders ([#7433](https://github.com/PyTorchLightning/pytorch-lightning/pull/7433)) +- Fixed broadcasting in multi-node, multi-gpu DDP using torch 1.7 ([#7592](https://github.com/PyTorchLightning/pytorch-lightning/pull/7592)) + + - Fixed `ProgressBar` pickling after calling `trainer.predict` ([#7608](https://github.com/PyTorchLightning/pytorch-lightning/pull/7608)) diff --git a/pytorch_lightning/overrides/torch_distributed.py b/pytorch_lightning/overrides/torch_distributed.py index 67b64c046d..ebee83828e 100644 --- a/pytorch_lightning/overrides/torch_distributed.py +++ b/pytorch_lightning/overrides/torch_distributed.py @@ -3,7 +3,7 @@ import pickle import torch -from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_7 +from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8 log = logging.getLogger(__name__) @@ -88,7 +88,7 @@ def _broadcast_object_list(object_list, src=0, group=None): object_list[i] = _tensor_to_object(obj_view, obj_size) -if _TORCH_GREATER_EQUAL_1_7 and torch.distributed.is_available(): +if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.is_available(): from torch.distributed.distributed_c10d import broadcast_object_list else: broadcast_object_list = _broadcast_object_list