lightning/tests/tests_fabric/utilities/test_spike.py

208 lines
6.3 KiB
Python

import contextlib
import sys
import pytest
import torch
from lightning.fabric import Fabric
from lightning.fabric.utilities.spike import _TORCHMETRICS_GREATER_EQUAL_1_0_0, SpikeDetection, TrainingSpikeException
def spike_detection_test(fabric, global_rank_spike, spike_value, should_raise):
loss_vals = [1 / i for i in range(1, 10)]
if fabric.global_rank == global_rank_spike:
if spike_value is None:
loss_vals[4] = 3
else:
loss_vals[4] = spike_value
for i in range(len(loss_vals)):
context = pytest.raises(TrainingSpikeException) if i == 4 and should_raise else contextlib.nullcontext()
with context:
fabric.call(
"on_train_batch_end",
fabric=fabric,
loss=torch.tensor(loss_vals[i], device=fabric.device),
batch=None,
batch_idx=i,
)
@pytest.mark.flaky(max_runs=3)
@pytest.mark.parametrize(
("global_rank_spike", "num_devices", "spike_value", "finite_only"),
[
pytest.param(0, 1, None, True),
pytest.param(0, 1, None, False),
pytest.param(0, 1, float("inf"), True),
pytest.param(0, 1, float("inf"), False),
pytest.param(0, 1, float("-inf"), True),
pytest.param(0, 1, float("-inf"), False),
pytest.param(0, 1, float("NaN"), True),
pytest.param(0, 1, float("NaN"), False),
pytest.param(
0,
2,
None,
True,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
0,
2,
None,
False,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
1,
2,
None,
True,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
1,
2,
None,
False,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
0,
2,
float("inf"),
True,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
0,
2,
float("inf"),
False,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
1,
2,
float("inf"),
True,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
1,
2,
float("inf"),
False,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
0,
2,
float("-inf"),
True,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
0,
2,
float("-inf"),
False,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
1,
2,
float("-inf"),
True,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
1,
2,
float("-inf"),
False,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
0,
2,
float("NaN"),
True,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
0,
2,
float("NaN"),
False,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
1,
2,
float("NaN"),
True,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
1,
2,
float("NaN"),
False,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
],
)
@pytest.mark.skipif(not _TORCHMETRICS_GREATER_EQUAL_1_0_0, reason="requires torchmetrics>=1.0.0")
def test_fabric_spike_detection_integration(tmp_path, global_rank_spike, num_devices, spike_value, finite_only):
fabric = Fabric(
accelerator="cpu",
devices=num_devices,
callbacks=[SpikeDetection(exclude_batches_path=tmp_path, finite_only=finite_only)],
strategy="ddp_spawn",
)
# spike_value == None -> typical spike detection
# finite_only -> typical spike detection and raise with NaN +/- inf
# if inf -> inf >> other values -> typical spike detection
should_raise = spike_value is None or finite_only or spike_value == float("inf")
fabric.launch(
spike_detection_test,
global_rank_spike=global_rank_spike,
spike_value=spike_value,
should_raise=should_raise,
)