This commit is contained in:
Adrian Wälchli 2024-06-05 23:42:24 +02:00
parent d04216d02d
commit 9d173e8499
1 changed files with 5 additions and 4 deletions

View File

@ -292,7 +292,8 @@ def _init_dist_connection(
log.info(f"Initializing distributed: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}") log.info(f"Initializing distributed: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
torch.distributed.init_process_group(torch_distributed_backend, rank=global_rank, world_size=world_size, **kwargs) torch.distributed.init_process_group(torch_distributed_backend, rank=global_rank, world_size=world_size, **kwargs)
# PyTorch >= 2.4 warns about undestroyed process group, so we need to do it at program exit if torch_distributed_backend == "nccl":
# PyTorch >= 2.4 warns about undestroyed NCCL process group, so we need to do it at program exit
atexit.register(_destroy_dist_connection) atexit.register(_destroy_dist_connection)
# On rank=0 let everyone know training is starting # On rank=0 let everyone know training is starting
@ -306,8 +307,8 @@ def _init_dist_connection(
def _destroy_dist_connection() -> None: def _destroy_dist_connection() -> None:
if _distributed_is_initialized(): if _distributed_is_initialized():
# ensure at least one collective op ran, otherwise `destroy_process_group()` hangs # # ensure at least one collective op ran, otherwise `destroy_process_group()` hangs
torch.distributed.barrier() # torch.distributed.barrier()
print("destroying dist") print("destroying dist")
torch.distributed.destroy_process_group() torch.distributed.destroy_process_group()
print("dist destroyed") print("dist destroyed")