208 lines
6.3 KiB
Python
208 lines
6.3 KiB
Python
import contextlib
|
|
import sys
|
|
|
|
import pytest
|
|
import torch
|
|
from lightning.fabric import Fabric
|
|
from lightning.fabric.utilities.spike import _TORCHMETRICS_GREATER_EQUAL_1_0_0, SpikeDetection, TrainingSpikeException
|
|
|
|
|
|
def spike_detection_test(fabric, global_rank_spike, spike_value, should_raise):
    """Send a decaying loss curve with one injected outlier through ``on_train_batch_end``.

    Nine losses ``1/1, 1/2, ..., 1/9`` are dispatched one per batch via
    ``fabric.call("on_train_batch_end", ...)``. On the process whose
    ``fabric.global_rank`` equals ``global_rank_spike`` the loss at batch
    index 4 is overwritten before dispatching.

    Args:
        fabric: Object exposing ``global_rank``, ``device`` and ``call``.
        global_rank_spike: Rank whose loss sequence receives the outlier.
        spike_value: Value written at the spike position; ``None`` selects
            the finite spike ``3``.
        should_raise: When true, the hook call for batch index 4 is wrapped in
            ``pytest.raises(TrainingSpikeException)`` on every rank (not only
            the spiking one). All other calls run without any exception guard.
    """
    spike_idx = 4  # batch index at which the outlier is injected and the exception is expected

    loss_vals = [1 / i for i in range(1, 10)]
    if fabric.global_rank == global_rank_spike:
        loss_vals[spike_idx] = 3 if spike_value is None else spike_value

    for batch_idx, loss_val in enumerate(loss_vals):
        expect_exception = batch_idx == spike_idx and should_raise
        # ``pytest.raises`` is only constructed for the one batch that must fail;
        # every other batch uses a no-op context so unexpected errors propagate.
        context = pytest.raises(TrainingSpikeException) if expect_exception else contextlib.nullcontext()

        with context:
            fabric.call(
                "on_train_batch_end",
                fabric=fabric,
                loss=torch.tensor(loss_val, device=fabric.device),
                batch=None,
                batch_idx=batch_idx,
            )
|
|
|
|
|
|
# Values injected at the spike position; ``None`` selects the helper's default finite spike.
_SPIKE_VALUES = (None, float("inf"), float("-inf"), float("NaN"))

# Mark shared by every two-device case.
_MULTIPROCESS_LINUX_ONLY = pytest.mark.skipif(
    sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
)


@pytest.mark.flaky(max_runs=3)
@pytest.mark.parametrize(
    ("global_rank_spike", "num_devices", "spike_value", "finite_only"),
    [
        # One device: every spike value, with and without ``finite_only``.
        *(
            pytest.param(0, 1, spike_value, finite_only)
            for spike_value in _SPIKE_VALUES
            for finite_only in (True, False)
        ),
        # Two devices: the same grid, with the spike placed on rank 0 and on rank 1.
        *(
            pytest.param(rank, 2, spike_value, finite_only, marks=_MULTIPROCESS_LINUX_ONLY)
            for spike_value in _SPIKE_VALUES
            for rank in (0, 1)
            for finite_only in (True, False)
        ),
    ],
)
@pytest.mark.skipif(not _TORCHMETRICS_GREATER_EQUAL_1_0_0, reason="requires torchmetrics>=1.0.0")
def test_fabric_spike_detection_integration(tmp_path, global_rank_spike, num_devices, spike_value, finite_only):
    """Launch ``spike_detection_test`` under Fabric with a ``SpikeDetection`` callback.

    Args:
        tmp_path: pytest temporary directory, passed as ``exclude_batches_path``.
        global_rank_spike: Rank on which the outlier loss is injected.
        num_devices: Number of CPU devices given to ``Fabric``.
        spike_value: Injected loss value, or ``None`` for the default finite spike.
        finite_only: Forwarded to ``SpikeDetection(finite_only=...)``.
    """
    fabric = Fabric(
        accelerator="cpu",
        devices=num_devices,
        callbacks=[SpikeDetection(exclude_batches_path=tmp_path, finite_only=finite_only)],
        strategy="ddp_spawn",
    )

    # spike_value == None -> typical spike detection
    # finite_only -> typical spike detection and raise with NaN +/- inf
    # if inf -> inf >> other values -> typical spike detection
    # Consequently only NaN and -inf with ``finite_only=False`` are expected NOT to raise.
    should_raise = spike_value is None or finite_only or spike_value == float("inf")

    fabric.launch(
        spike_detection_test,
        global_rank_spike=global_rank_spike,
        spike_value=spike_value,
        should_raise=should_raise,
    )
|