debug
This commit is contained in:
parent
d04216d02d
commit
9d173e8499
|
@ -292,7 +292,8 @@ def _init_dist_connection(
|
||||||
log.info(f"Initializing distributed: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
|
log.info(f"Initializing distributed: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
|
||||||
torch.distributed.init_process_group(torch_distributed_backend, rank=global_rank, world_size=world_size, **kwargs)
|
torch.distributed.init_process_group(torch_distributed_backend, rank=global_rank, world_size=world_size, **kwargs)
|
||||||
|
|
||||||
# PyTorch >= 2.4 warns about undestroyed process group, so we need to do it at program exit
|
if torch_distributed_backend == "nccl":
|
||||||
|
# PyTorch >= 2.4 warns about undestroyed NCCL process group, so we need to do it at program exit
|
||||||
atexit.register(_destroy_dist_connection)
|
atexit.register(_destroy_dist_connection)
|
||||||
|
|
||||||
# On rank=0 let everyone know training is starting
|
# On rank=0 let everyone know training is starting
|
||||||
|
@ -306,8 +307,8 @@ def _init_dist_connection(
|
||||||
|
|
||||||
def _destroy_dist_connection() -> None:
|
def _destroy_dist_connection() -> None:
|
||||||
if _distributed_is_initialized():
|
if _distributed_is_initialized():
|
||||||
# ensure at least one collective op ran, otherwise `destroy_process_group()` hangs
|
# # ensure at least one collective op ran, otherwise `destroy_process_group()` hangs
|
||||||
torch.distributed.barrier()
|
# torch.distributed.barrier()
|
||||||
print("destroying dist")
|
print("destroying dist")
|
||||||
torch.distributed.destroy_process_group()
|
torch.distributed.destroy_process_group()
|
||||||
print("dist destroyed")
|
print("dist destroyed")
|
||||||
|
|
Loading…
Reference in New Issue