lightning/tests/tests_fabric/utilities/test_spike.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

208 lines
6.3 KiB
Python
Raw Normal View History

import contextlib
import sys
import pytest
import torch
from lightning.fabric import Fabric
from lightning.fabric.utilities.spike import _TORCHMETRICS_GREATER_EQUAL_1_0_0, SpikeDetection, TrainingSpikeException
def spike_detection_test(fabric, global_rank_spike, spike_value, should_raise):
loss_vals = [1 / i for i in range(1, 10)]
if fabric.global_rank == global_rank_spike:
if spike_value is None:
loss_vals[4] = 3
else:
loss_vals[4] = spike_value
for i in range(len(loss_vals)):
context = pytest.raises(TrainingSpikeException) if i == 4 and should_raise else contextlib.nullcontext()
with context:
fabric.call(
"on_train_batch_end",
fabric=fabric,
loss=torch.tensor(loss_vals[i], device=fabric.device),
batch=None,
batch_idx=i,
)
@pytest.mark.flaky(max_runs=3)
@pytest.mark.parametrize(
("global_rank_spike", "num_devices", "spike_value", "finite_only"),
[
pytest.param(0, 1, None, True),
pytest.param(0, 1, None, False),
pytest.param(0, 1, float("inf"), True),
pytest.param(0, 1, float("inf"), False),
pytest.param(0, 1, float("-inf"), True),
pytest.param(0, 1, float("-inf"), False),
pytest.param(0, 1, float("NaN"), True),
pytest.param(0, 1, float("NaN"), False),
pytest.param(
0,
2,
None,
True,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
0,
2,
None,
False,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
1,
2,
None,
True,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
1,
2,
None,
False,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
0,
2,
float("inf"),
True,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
0,
2,
float("inf"),
False,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
1,
2,
float("inf"),
True,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
1,
2,
float("inf"),
False,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
0,
2,
float("-inf"),
True,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
0,
2,
float("-inf"),
False,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
1,
2,
float("-inf"),
True,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
1,
2,
float("-inf"),
False,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
0,
2,
float("NaN"),
True,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
0,
2,
float("NaN"),
False,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
1,
2,
float("NaN"),
True,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
1,
2,
float("NaN"),
False,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
],
)
@pytest.mark.skipif(not _TORCHMETRICS_GREATER_EQUAL_1_0_0, reason="requires torchmetrics>=1.0.0")
def test_fabric_spike_detection_integration(tmp_path, global_rank_spike, num_devices, spike_value, finite_only):
fabric = Fabric(
accelerator="cpu",
devices=num_devices,
callbacks=[SpikeDetection(exclude_batches_path=tmp_path, finite_only=finite_only)],
strategy="ddp_spawn",
)
# spike_value == None -> typical spike detection
# finite_only -> typical spike detection and raise with NaN +/- inf
# if inf -> inf >> other values -> typical spike detection
should_raise = spike_value is None or finite_only or spike_value == float("inf")
fabric.launch(
spike_detection_test,
global_rank_spike=global_rank_spike,
spike_value=spike_value,
should_raise=should_raise,
)