import math
import pytest
import torch
import torch.nn as nn
from pytorch_lightning.utilities.finite_checks import detect_nan_parameters


@pytest.mark.parametrize("value", (math.nan, math.inf, -math.inf))
def test_detect_nan_parameters(value):
    """``detect_nan_parameters`` raises only once a parameter holds a non-finite value.

    A freshly constructed ``nn.Linear`` is checked first and must not raise.
    Its ``bias`` is then filled with ``value`` (NaN, +inf or -inf), after which
    the same check must raise a ``ValueError`` naming ``bias``.
    """
    model = nn.Linear(2, 3)

    # Baseline: the untouched model must pass the check without raising.
    detect_nan_parameters(model)

    # Poison the bias only; constant_ overwrites every element with `value`.
    nn.init.constant_(model.bias, value)
    assert not torch.isfinite(model.bias).any()
    # Keep the weight finite so the error below can only be attributed to `bias`.
    assert torch.isfinite(model.weight).all()

    # pytest.raises(match=...) applies re.search, so no surrounding `.*` is needed.
    with pytest.raises(ValueError, match=r"Detected nan and/or inf values in `bias`"):
        detect_nan_parameters(model)