diff --git a/src/lightning/fabric/utilities/distributed.py b/src/lightning/fabric/utilities/distributed.py index dfa308cdcd..f3cf6c5ff1 100644 --- a/src/lightning/fabric/utilities/distributed.py +++ b/src/lightning/fabric/utilities/distributed.py @@ -292,8 +292,9 @@ def _init_dist_connection( log.info(f"Initializing distributed: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}") torch.distributed.init_process_group(torch_distributed_backend, rank=global_rank, world_size=world_size, **kwargs) - # PyTorch >= 2.4 warns about undestroyed process group, so we need to do it at program exit - atexit.register(_destroy_dist_connection) + if torch_distributed_backend == "nccl": + # PyTorch >= 2.4 warns about undestroyed NCCL process group, so we need to do it at program exit + atexit.register(_destroy_dist_connection) # On rank=0 let everyone know training is starting rank_zero_info( @@ -306,8 +307,8 @@ def _init_dist_connection( def _destroy_dist_connection() -> None: if _distributed_is_initialized(): - # ensure at least one collective op ran, otherwise `destroy_process_group()` hangs - torch.distributed.barrier() + # # ensure at least one collective op ran, otherwise `destroy_process_group()` hangs + # torch.distributed.barrier() print("destroying dist") torch.distributed.destroy_process_group() print("dist destroyed")