fix NCCL error with non-consecutive trainer gpus (#8165)
* device ids in barrier
* same fix for spawn
* fix non-nccl
* add changelog
* get nccl backend
* get backend

Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
parent 2f3c65e57b
commit bf54ac1cad
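Context for the diffs below: with the NCCL backend, `torch.distributed.barrier()` without arguments can end up synchronizing on the wrong GPU when the selected device ids are not consecutive (e.g. `gpus=[1, 3]`), producing NCCL errors. Since torch 1.8 the barrier accepts a `device_ids` argument that pins the collective to the GPU the process actually uses. The following is a minimal sketch of that guarded pattern; `safe_barrier` is an illustrative name, not part of the commit:

    import torch.distributed as torch_distrib

    def safe_barrier(device_ids=None):
        # Sketch of the pattern applied in this commit: only the NCCL backend
        # (torch >= 1.8) accepts ``device_ids``; other backends keep the plain call.
        if not torch_distrib.is_initialized():
            return
        if device_ids is not None and torch_distrib.get_backend() == "nccl":
            torch_distrib.barrier(device_ids=device_ids)
        else:
            torch_distrib.barrier()

    # e.g. the process bound to physical GPU 3 of a gpus=[1, 3] selection:
    # safe_barrier(device_ids=[3])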
@@ -327,6 +327,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed a bug where an infinite recursion would be triggered when using the `BaseFinetuning` callback on a model that contains a `ModuleDict` ([#8170](https://github.com/PyTorchLightning/pytorch-lightning/pull/8170))
+- Fixed NCCL error when selecting non-consecutive device ids ([#8165](https://github.com/PyTorchLightning/pytorch-lightning/pull/8165))
 - Fixed `log_gpu_memory` metrics not being added to `logging` when nothing else is logged ([#8174](https://github.com/PyTorchLightning/pytorch-lightning/pull/8174))
@@ -332,8 +332,12 @@ class DDPPlugin(ParallelPlugin):
     def post_dispatch(self) -> None:
         self.cluster_environment.teardown()

-    def barrier(self, *args, **kwargs):
-        if torch_distrib.is_available() and torch_distrib.is_initialized():
-            torch_distrib.barrier()
+    def barrier(self, *args, **kwargs) -> None:
+        if not torch_distrib.is_initialized():
+            return
+        if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend() == "nccl":
+            torch_distrib.barrier(device_ids=self.determine_ddp_device_ids())
+        else:
+            torch_distrib.barrier()

     def broadcast(self, obj: object, src: int = 0) -> object:
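`determine_ddp_device_ids()` is defined elsewhere in the plugin and is not part of this diff. As a rough sketch of the assumption the barrier relies on, such a helper would return the index of the process's root CUDA device, or `None` when running on CPU:

    # Sketch only (assumed behaviour, not the code from this commit): the NCCL
    # barrier should be pinned to the single GPU owned by the current process.
    def determine_ddp_device_ids(self):
        if self.root_device.type != "cuda":
            return None
        return [self.root_device.index]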
@@ -309,8 +309,12 @@ class DDPSpawnPlugin(ParallelPlugin):
             ckpt = pl_load(last_path, map_location=lambda storage, loc: storage)
             self.lightning_module.load_state_dict(ckpt)

-    def barrier(self, *args, **kwargs):
-        if torch_distrib.is_initialized():
-            torch_distrib.barrier()
+    def barrier(self, *args, **kwargs) -> None:
+        if not torch_distrib.is_initialized():
+            return
+        if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend() == "nccl":
+            torch_distrib.barrier(device_ids=self.determine_ddp_device_ids())
+        else:
+            torch_distrib.barrier()

     def broadcast(self, obj: object, src: int = 0) -> object:
@@ -11,7 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from unittest import mock
+
 import torch
+from torch.nn.parallel import DistributedDataParallel

 from pytorch_lightning import Trainer
 from pytorch_lightning.plugins import DDPPlugin
@@ -46,3 +49,30 @@ def test_ddp_with_2_gpus():
     assert model.device == torch.device("cpu")
     cuda_memory = torch.cuda.memory_allocated()
     assert cuda_memory < model.start_cuda_memory
+
+
+class BarrierModel(BoringModel):
+
+    def setup(self, stage=None):
+        assert not isinstance(self.trainer.accelerator.model, DistributedDataParallel)
+        self.trainer.accelerator.barrier("barrier before model is wrapped")
+
+    def on_train_start(self):
+        assert isinstance(self.trainer.accelerator.model, DistributedDataParallel)
+        self.trainer.accelerator.barrier("barrier after model is wrapped")
+
+
+@RunIf(min_gpus=4, special=True)
+@mock.patch("torch.distributed.barrier")
+def test_ddp_barrier_non_consecutive_device_ids(barrier_mock, tmpdir):
+    """ Test correct usage of barriers when device ids do not start at 0 or are not consecutive. """
+    model = BoringModel()
+    gpus = [1, 3]
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_steps=1,
+        gpus=gpus,
+        accelerator="ddp",
+    )
+    trainer.fit(model)
+    barrier_mock.assert_any_call(device_ids=[gpus[trainer.local_rank]])
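The final assertion in the new test checks that each process passes its own physical GPU index to the barrier. A tiny illustration (not part of the commit) of the rank-to-device mapping it encodes:

    gpus = [1, 3]
    for local_rank, device_id in enumerate(gpus):
        # local rank 0 -> barrier(device_ids=[1]); local rank 1 -> barrier(device_ids=[3])
        print(f"local_rank={local_rank} -> device_ids={[device_id]}")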