From d04216d02dc9b97b48311c4a7de8491f8ef48e60 Mon Sep 17 00:00:00 2001
From: Adrian Wälchli
Date: Wed, 5 Jun 2024 23:33:32 +0200
Subject: [PATCH] debug

---
 src/lightning/fabric/utilities/distributed.py |  2 ++
 src/lightning/fabric/utilities/spike.py       |  1 +
 tests/tests_fabric/utilities/test_spike.py    | 13 +++++++------
 3 files changed, 10 insertions(+), 6 deletions(-)

diff --git a/src/lightning/fabric/utilities/distributed.py b/src/lightning/fabric/utilities/distributed.py
index 2673d1f089..dfa308cdcd 100644
--- a/src/lightning/fabric/utilities/distributed.py
+++ b/src/lightning/fabric/utilities/distributed.py
@@ -308,7 +308,9 @@ def _destroy_dist_connection() -> None:
     if _distributed_is_initialized():
         # ensure at least one collective op ran, otherwise `destroy_process_group()` hangs
         torch.distributed.barrier()
+        print("destroying dist")
         torch.distributed.destroy_process_group()
+        print("dist destroyed")


 def _get_default_process_group_backend_for_device(device: torch.device) -> str:
diff --git a/src/lightning/fabric/utilities/spike.py b/src/lightning/fabric/utilities/spike.py
index d9c41dab0a..33a832685f 100644
--- a/src/lightning/fabric/utilities/spike.py
+++ b/src/lightning/fabric/utilities/spike.py
@@ -87,6 +87,7 @@ class SpikeDetection:

         # While spike-detection happens on a per-rank level, we need to fail all ranks if any rank detected a spike
         is_spike_global = fabric.strategy.reduce_boolean_decision(is_spike, all=False)
+        print(f"{is_spike_global=}")

         if is_spike_global:
             self._handle_spike(fabric, batch_idx)
diff --git a/tests/tests_fabric/utilities/test_spike.py b/tests/tests_fabric/utilities/test_spike.py
index f97920fc83..2f9b131688 100644
--- a/tests/tests_fabric/utilities/test_spike.py
+++ b/tests/tests_fabric/utilities/test_spike.py
@@ -32,12 +32,12 @@ def spike_detection_test(fabric, global_rank_spike, spike_value, should_raise):
 @pytest.mark.parametrize(
     ("global_rank_spike", "num_devices", "spike_value", "finite_only"),
     [
-        # pytest.param(0, 1, None, True),
-        # pytest.param(0, 1, None, False),
-        # pytest.param(0, 1, float("inf"), True),
-        # pytest.param(0, 1, float("inf"), False),
-        # pytest.param(0, 1, float("-inf"), True),
-        # pytest.param(0, 1, float("-inf"), False),
+        pytest.param(0, 1, None, True),
+        pytest.param(0, 1, None, False),
+        pytest.param(0, 1, float("inf"), True),
+        pytest.param(0, 1, float("inf"), False),
+        pytest.param(0, 1, float("-inf"), True),
+        pytest.param(0, 1, float("-inf"), False),
         pytest.param(0, 1, float("NaN"), True),
         pytest.param(0, 1, float("NaN"), False),
         # pytest.param(
@@ -205,3 +205,4 @@ def test_fabric_spike_detection_integration(tmp_path, global_rank_spike, num_dev
         spike_value=spike_value,
         should_raise=should_raise,
     )
+    print("test end")
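For context on the `is_spike_global` print added above: `reduce_boolean_decision(is_spike, all=False)` answers "did any rank detect a spike?", which is why a single spiking rank fails all ranks. Below is a minimal standalone sketch of such an any-rank boolean reduction, assuming an already-initialized `torch.distributed` process group; the helper name `reduce_boolean_any` is illustrative, not Lightning's implementation:

    import torch
    import torch.distributed as dist

    def reduce_boolean_any(decision: bool) -> bool:
        # Encode the local decision as a 0/1 vote and sum the votes across all ranks;
        # a nonzero total means at least one rank voted True.
        votes = torch.tensor(int(decision))
        dist.all_reduce(votes, op=dist.ReduceOp.SUM)
        return bool(votes.item() > 0)

To actually see the prints from this debug patch, run the affected test module with pytest's output capture disabled, e.g. `python -m pytest tests/tests_fabric/utilities/test_spike.py -s`.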