# lightning/tests/tests_lite/utilities/test_seed.py
# (85 lines, 3.2 KiB, Python)

import os
from unittest import mock
import pytest
import torch
import lightning_lite.utilities
from lightning_lite.utilities import seed as seed_utils
from lightning_lite.utilities.seed import _collect_rng_states, _set_rng_states
@mock.patch.dict(os.environ, {}, clear=True)
def test_seed_stays_same_with_multiple_seed_everything_calls():
    """Ensure that after the initial seed everything, the seed stays the same for the same run.

    The test starts from an empty environment. The first ``seed_everything()`` call is expected to emit the
    "No seed found" warning; the value of ``PL_GLOBAL_SEED`` is then read back. A second call must emit no
    warning at all and must leave ``PL_GLOBAL_SEED`` at the same value.
    """
    # Local import: only needed for the "no warning emitted" check below.
    import warnings

    with pytest.warns(UserWarning, match="No seed found"):
        lightning_lite.utilities.seed.seed_everything()
    initial_seed = os.environ.get("PL_GLOBAL_SEED")

    # `pytest.warns(None)` is deprecated since pytest 7 and raises in pytest 8, so the warnings are recorded
    # with the standard library instead. "always" makes sure no warning is hidden by an existing filter.
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        lightning_lite.utilities.seed.seed_everything()
    assert not record  # does not warn
    seed = os.environ.get("PL_GLOBAL_SEED")

    assert initial_seed == seed
@mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "2020"}, clear=True)
def test_correct_seed_with_environment_variable():
    """Check that ``seed_everything()`` called without arguments returns the seed from ``PL_GLOBAL_SEED``."""
    returned_seed = lightning_lite.utilities.seed.seed_everything()
    assert returned_seed == 2020
@mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "invalid"}, clear=True)
@mock.patch.object(seed_utils, attribute="_select_seed_randomly", new=lambda *_: 123)
def test_invalid_seed():
    """Check that a non-numeric ``PL_GLOBAL_SEED`` warns and that a seed is still returned.

    ``_select_seed_randomly`` is patched to always return 123, so the returned seed is deterministic.
    """
    seed_everything = lightning_lite.utilities.seed.seed_everything
    with pytest.warns(UserWarning, match="Invalid seed found"):
        chosen_seed = seed_everything()
    assert chosen_seed == 123
@mock.patch.dict(os.environ, {}, clear=True)
@mock.patch.object(seed_utils, attribute="_select_seed_randomly", new=lambda *_: 123)
@pytest.mark.parametrize("seed", (10e9, -10e9))
def test_out_of_bounds_seed(seed):
    """Check that a seed outside the accepted range warns and that a seed is still returned.

    ``_select_seed_randomly`` is patched to always return 123, which is the value expected back.
    """
    expect_bounds_warning = pytest.warns(UserWarning, match="is not in bounds")
    with expect_bounds_warning:
        resulting_seed = lightning_lite.utilities.seed.seed_everything(seed)
    assert resulting_seed == 123
def test_reset_seed_no_op():
    """Check that ``reset_seed()`` does nothing when ``seed_everything()`` was not used.

    Neither torch's initial seed nor the ``PL_GLOBAL_SEED`` environment variable may change.
    """
    env_key = "PL_GLOBAL_SEED"
    assert env_key not in os.environ
    torch_seed_at_start = torch.initial_seed()

    lightning_lite.utilities.seed.reset_seed()

    assert torch.initial_seed() == torch_seed_at_start
    assert env_key not in os.environ
@pytest.mark.parametrize("workers", (True, False))
def test_reset_seed_everything(workers):
    """Check that ``reset_seed()`` restores the seed that ``seed_everything()`` set.

    A random tensor drawn right after seeding must match one drawn right after the reset, and the
    ``PL_GLOBAL_SEED`` / ``PL_SEED_WORKERS`` environment variables must hold the same values at both points.
    """
    for env_key in ("PL_GLOBAL_SEED", "PL_SEED_WORKERS"):
        assert env_key not in os.environ

    expected_env = {"PL_GLOBAL_SEED": "123", "PL_SEED_WORKERS": str(int(workers))}

    def assert_env_matches():
        # Both variables are checked after seeding and again after the reset.
        for key, value in expected_env.items():
            assert os.environ[key] == value

    lightning_lite.utilities.seed.seed_everything(123, workers)
    sample_after_seeding = torch.rand(1)
    assert_env_matches()

    lightning_lite.utilities.seed.reset_seed()
    sample_after_reset = torch.rand(1)
    assert_env_matches()

    assert torch.allclose(sample_after_seeding, sample_after_reset)
def test_backward_compatibility_rng_states_dict():
    """Check that ``_set_rng_states`` accepts a states dict that lacks the "torch.cuda" entry.

    This mimics an rng states dict produced by an older version, where that key did not exist.
    """
    rng_states = _collect_rng_states()
    assert "torch.cuda" in rng_states
    # Drop the key to mimic the older format, then make sure restoring does not crash.
    del rng_states["torch.cuda"]
    _set_rng_states(rng_states)