lightning/tests/tests_pytorch/utilities/test_finite_checks.py

21 lines
528 B
Python

import math
import pytest
import torch
import torch.nn as nn
from pytorch_lightning.utilities.finite_checks import detect_nan_parameters
@pytest.mark.parametrize("value", (math.nan, math.inf, -math.inf))
def test_detect_nan_parameters(value):
    """Check that ``detect_nan_parameters`` rejects non-finite parameters.

    A freshly constructed ``nn.Linear`` is expected to pass the check. Its bias
    is then filled with ``value`` (NaN, +inf or -inf). After that, the check
    must raise a ``ValueError`` whose message names the offending parameter,
    ``bias``.

    Args:
        value: The non-finite float written into every element of the bias.
    """
    model = nn.Linear(2, 3)

    # Baseline: the untouched model is expected to pass without raising.
    detect_nan_parameters(model)

    nn.init.constant_(model.bias, value)
    # Guard against a vacuous test: confirm the bias really is non-finite now.
    assert not torch.isfinite(model.bias).all()

    with pytest.raises(ValueError, match=r".*Detected nan and/or inf values in `bias`.*"):
        detect_nan_parameters(model)