debug
This commit is contained in:
parent
d289dbcbb6
commit
d04216d02d
|
@ -308,7 +308,9 @@ def _destroy_dist_connection() -> None:
|
|||
if _distributed_is_initialized():
|
||||
# ensure at least one collective op ran, otherwise `destroy_process_group()` hangs
|
||||
torch.distributed.barrier()
|
||||
print("destroying dist")
|
||||
torch.distributed.destroy_process_group()
|
||||
print("dist destroyed")
|
||||
|
||||
|
||||
def _get_default_process_group_backend_for_device(device: torch.device) -> str:
|
||||
|
|
|
@ -87,6 +87,7 @@ class SpikeDetection:
|
|||
|
||||
# While spike-detection happens on a per-rank level, we need to fail all ranks if any rank detected a spike
|
||||
is_spike_global = fabric.strategy.reduce_boolean_decision(is_spike, all=False)
|
||||
print(f"{is_spike_global=}")
|
||||
|
||||
if is_spike_global:
|
||||
self._handle_spike(fabric, batch_idx)
|
||||
|
|
|
@ -32,12 +32,12 @@ def spike_detection_test(fabric, global_rank_spike, spike_value, should_raise):
|
|||
@pytest.mark.parametrize(
|
||||
("global_rank_spike", "num_devices", "spike_value", "finite_only"),
|
||||
[
|
||||
# pytest.param(0, 1, None, True),
|
||||
# pytest.param(0, 1, None, False),
|
||||
# pytest.param(0, 1, float("inf"), True),
|
||||
# pytest.param(0, 1, float("inf"), False),
|
||||
# pytest.param(0, 1, float("-inf"), True),
|
||||
# pytest.param(0, 1, float("-inf"), False),
|
||||
pytest.param(0, 1, None, True),
|
||||
pytest.param(0, 1, None, False),
|
||||
pytest.param(0, 1, float("inf"), True),
|
||||
pytest.param(0, 1, float("inf"), False),
|
||||
pytest.param(0, 1, float("-inf"), True),
|
||||
pytest.param(0, 1, float("-inf"), False),
|
||||
pytest.param(0, 1, float("NaN"), True),
|
||||
pytest.param(0, 1, float("NaN"), False),
|
||||
# pytest.param(
|
||||
|
@ -205,3 +205,4 @@ def test_fabric_spike_detection_integration(tmp_path, global_rank_spike, num_dev
|
|||
spike_value=spike_value,
|
||||
should_raise=should_raise,
|
||||
)
|
||||
print("test end")
|
||||
|
|
Loading…
Reference in New Issue