debug
parent d04216d02d
commit 9d173e8499
@@ -292,7 +292,8 @@ def _init_dist_connection(
     log.info(f"Initializing distributed: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
     torch.distributed.init_process_group(torch_distributed_backend, rank=global_rank, world_size=world_size, **kwargs)
 
-    # PyTorch >= 2.4 warns about undestroyed process group, so we need to do it at program exit
-    atexit.register(_destroy_dist_connection)
+    if torch_distributed_backend == "nccl":
+        # PyTorch >= 2.4 warns about undestroyed NCCL process group, so we need to do it at program exit
+        atexit.register(_destroy_dist_connection)
 
     # On rank=0 let everyone know training is starting
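The first hunk narrows the atexit cleanup to the NCCL backend: PyTorch >= 2.4 warns when an NCCL process group is left undestroyed, so the teardown helper is registered to run at program exit, but only when that backend is in use. Below is a minimal, self-contained sketch of the pattern; the wrapper signature and placeholder arguments are illustrative, not the surrounding module's actual code.

    import atexit

    import torch.distributed

    def _destroy_dist_connection() -> None:
        # Guard so the atexit hook is a no-op if the group is already gone.
        if torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()

    def _init_dist_connection(backend: str, rank: int, world_size: int) -> None:
        torch.distributed.init_process_group(backend, rank=rank, world_size=world_size)
        if backend == "nccl":
            # PyTorch >= 2.4 warns about undestroyed NCCL process groups,
            # so schedule cleanup for interpreter exit.
            atexit.register(_destroy_dist_connection)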
@@ -306,8 +307,8 @@ def _init_dist_connection(
 
 def _destroy_dist_connection() -> None:
     if _distributed_is_initialized():
-        # ensure at least one collective op ran, otherwise `destroy_process_group()` hangs
-        torch.distributed.barrier()
+        # # ensure at least one collective op ran, otherwise `destroy_process_group()` hangs
+        # torch.distributed.barrier()
+        print("destroying dist")
         torch.distributed.destroy_process_group()
+        print("dist destroyed")
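The second hunk is the actual debugging change: the `torch.distributed.barrier()` that was meant to ensure at least one collective op ran before teardown (without one, `destroy_process_group()` can hang) is commented out, and the destroy call is bracketed with prints to reveal which side of it the hang occurs on. A sketch of that bracketing technique follows; the rank tag and `flush=True` are added here as assumptions for readable multi-process logs, not part of the commit.

    import os

    import torch.distributed

    def _destroy_dist_connection() -> None:
        if torch.distributed.is_initialized():
            # Tag each line with the rank and flush immediately so output
            # is not lost to buffering in a multi-process run.
            rank = os.environ.get("RANK", "?")
            print(f"[rank {rank}] destroying dist", flush=True)
            torch.distributed.destroy_process_group()
            print(f"[rank {rank}] dist destroyed", flush=True)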