This commit is contained in:
Adrian Wälchli 2024-06-05 23:33:32 +02:00
parent d289dbcbb6
commit d04216d02d
3 changed files with 10 additions and 6 deletions

View File

@ -308,7 +308,9 @@ def _destroy_dist_connection() -> None:
if _distributed_is_initialized():
# ensure at least one collective op ran, otherwise `destroy_process_group()` hangs
torch.distributed.barrier()
print("destroying dist")
torch.distributed.destroy_process_group()
print("dist destroyed")
def _get_default_process_group_backend_for_device(device: torch.device) -> str:

View File

@ -87,6 +87,7 @@ class SpikeDetection:
# While spike-detection happens on a per-rank level, we need to fail all ranks if any rank detected a spike
is_spike_global = fabric.strategy.reduce_boolean_decision(is_spike, all=False)
print(f"{is_spike_global=}")
if is_spike_global:
self._handle_spike(fabric, batch_idx)

View File

@ -32,12 +32,12 @@ def spike_detection_test(fabric, global_rank_spike, spike_value, should_raise):
@pytest.mark.parametrize(
("global_rank_spike", "num_devices", "spike_value", "finite_only"),
[
# pytest.param(0, 1, None, True),
# pytest.param(0, 1, None, False),
# pytest.param(0, 1, float("inf"), True),
# pytest.param(0, 1, float("inf"), False),
# pytest.param(0, 1, float("-inf"), True),
# pytest.param(0, 1, float("-inf"), False),
pytest.param(0, 1, None, True),
pytest.param(0, 1, None, False),
pytest.param(0, 1, float("inf"), True),
pytest.param(0, 1, float("inf"), False),
pytest.param(0, 1, float("-inf"), True),
pytest.param(0, 1, float("-inf"), False),
pytest.param(0, 1, float("NaN"), True),
pytest.param(0, 1, float("NaN"), False),
# pytest.param(
@ -205,3 +205,4 @@ def test_fabric_spike_detection_integration(tmp_path, global_rank_spike, num_dev
spike_value=spike_value,
should_raise=should_raise,
)
print("test end")