Override `broadcast_object_list` for `torch<1.8` (#7592)

Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
This commit is contained in:
Andrew Tritt 2021-05-20 01:29:55 -07:00 committed by GitHub
parent ed271905cf
commit 92cf396de2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 5 additions and 2 deletions

View File

@@ -106,6 +106,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed parsing of multiple training dataloaders ([#7433](https://github.com/PyTorchLightning/pytorch-lightning/pull/7433))
- Fixed broadcasting in multi-node, multi-gpu DDP using torch 1.7 ([#7592](https://github.com/PyTorchLightning/pytorch-lightning/pull/7592))
- Fixed `ProgressBar` pickling after calling `trainer.predict` ([#7608](https://github.com/PyTorchLightning/pytorch-lightning/pull/7608))

View File

@@ -3,7 +3,7 @@ import pickle
 import torch
-from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_7
+from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8
 log = logging.getLogger(__name__)
@@ -88,7 +88,7 @@ def _broadcast_object_list(object_list, src=0, group=None):
         object_list[i] = _tensor_to_object(obj_view, obj_size)
-if _TORCH_GREATER_EQUAL_1_7 and torch.distributed.is_available():
+if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.is_available():
     from torch.distributed.distributed_c10d import broadcast_object_list
 else:
     broadcast_object_list = _broadcast_object_list