2021-01-12 04:30:27 +00:00
|
|
|
import os
|
2022-03-01 23:27:30 +00:00
|
|
|
import random
|
2022-08-09 13:40:30 +00:00
|
|
|
from typing import Mapping
|
2021-01-12 04:30:27 +00:00
|
|
|
from unittest import mock
|
2022-08-09 13:40:30 +00:00
|
|
|
from unittest.mock import MagicMock
|
2021-01-23 23:52:04 +00:00
|
|
|
|
2022-03-01 23:27:30 +00:00
|
|
|
import numpy as np
|
2021-01-12 04:30:27 +00:00
|
|
|
import pytest
|
2021-04-27 09:51:39 +00:00
|
|
|
import torch
|
2021-01-12 04:30:27 +00:00
|
|
|
|
|
|
|
import pytorch_lightning.utilities.seed as seed_utils
|
2022-03-01 23:27:30 +00:00
|
|
|
from pytorch_lightning.utilities.seed import isolate_rng
|
2021-01-12 04:30:27 +00:00
|
|
|
|
|
|
|
|
|
|
|
@mock.patch.dict(os.environ, {}, clear=True)
def test_seed_stays_same_with_multiple_seed_everything_calls():
    """Ensure that after the initial seed everything, the seed stays the same for the same run.

    The environment is cleared for the duration of the test, so the first ``seed_everything()`` call is expected
    to warn that no seed was found. The second call must not emit any warning and must leave
    ``PL_GLOBAL_SEED`` unchanged.
    """
    # Local import: only this test needs the module.
    import warnings

    with pytest.warns(UserWarning, match="No seed found"):
        seed_utils.seed_everything()
    initial_seed = os.environ.get("PL_GLOBAL_SEED")

    # ``pytest.warns(None)`` was deprecated in pytest 7 and removed in pytest 8. Escalating every warning to an
    # error inside this context expresses the same "does not warn" requirement on all pytest versions.
    with warnings.catch_warnings():
        warnings.simplefilter("error")
        seed_utils.seed_everything()
    seed = os.environ.get("PL_GLOBAL_SEED")

    assert initial_seed == seed
|
|
|
|
|
|
|
|
|
|
|
|
@mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "2020"}, clear=True)
def test_correct_seed_with_environment_variable():
    """Check that ``seed_everything()`` without an argument returns the seed stored in ``PL_GLOBAL_SEED``."""
    returned_seed = seed_utils.seed_everything()
    assert returned_seed == 2020
|
|
|
|
|
|
|
|
|
|
|
|
@mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "invalid"}, clear=True)
@mock.patch.object(seed_utils, attribute="_select_seed_randomly", new=lambda *_: 123)
def test_invalid_seed():
    """Check that a non-numeric ``PL_GLOBAL_SEED`` produces a warning and a seed is still returned.

    ``_select_seed_randomly`` is patched to return 123, so that is the seed expected back.
    """
    with pytest.warns(UserWarning, match="Invalid seed found"):
        fallback_seed = seed_utils.seed_everything()

    assert fallback_seed == 123
|
|
|
|
|
|
|
|
|
|
|
|
@mock.patch.dict(os.environ, {}, clear=True)
@mock.patch.object(seed_utils, attribute="_select_seed_randomly", new=lambda *_: 123)
@pytest.mark.parametrize("seed", (10e9, -10e9))
def test_out_of_bounds_seed(seed):
    """Check that an out-of-bounds seed produces a warning and a seed is still returned.

    ``_select_seed_randomly`` is patched to return 123, so that is the seed expected back.
    """
    with pytest.warns(UserWarning, match="is not in bounds"):
        resolved_seed = seed_utils.seed_everything(seed)

    assert resolved_seed == 123
|
2021-04-27 09:51:39 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_reset_seed_no_op():
    """Check that ``reset_seed()`` changes nothing when ``seed_everything()`` has not been called.

    Neither torch's initial seed nor the ``PL_GLOBAL_SEED`` environment variable may be touched.
    """
    env_key = "PL_GLOBAL_SEED"
    assert env_key not in os.environ

    torch_seed_at_start = torch.initial_seed()
    seed_utils.reset_seed()

    assert torch_seed_at_start == torch.initial_seed()
    assert env_key not in os.environ
|
|
|
|
|
|
|
|
|
2021-10-28 12:57:41 +00:00
|
|
|
@pytest.mark.parametrize("workers", (True, False))
def test_reset_seed_everything(workers):
    """Check that ``reset_seed()`` makes torch reproduce the draw made right after ``seed_everything()``.

    The ``PL_GLOBAL_SEED`` and ``PL_SEED_WORKERS`` environment variables must hold the same values before and
    after the reset.
    """
    # Precondition: neither seeding variable is set when the test starts.
    for env_key in ("PL_GLOBAL_SEED", "PL_SEED_WORKERS"):
        assert env_key not in os.environ

    expected_env = {"PL_GLOBAL_SEED": "123", "PL_SEED_WORKERS": str(int(workers))}

    seed_utils.seed_everything(123, workers)
    first_draw = torch.rand(1)
    for env_key, expected_value in expected_env.items():
        assert os.environ[env_key] == expected_value

    seed_utils.reset_seed()
    second_draw = torch.rand(1)
    for env_key, expected_value in expected_env.items():
        assert os.environ[env_key] == expected_value

    # Same seed restored, so the first draw after the reset matches the first draw after seeding.
    assert torch.allclose(first_draw, second_draw)
|
2022-03-01 23:27:30 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_isolate_rng():
    """Check that draws made inside ``isolate_rng()`` do not advance the random state of the outer scope.

    For torch, numpy and Python's ``random`` in turn: the first draw made after leaving the context must equal
    the first draw that was made inside it.
    """
    # torch
    torch.rand(1)
    with isolate_rng():
        torch_draws = [torch.rand(2) for _ in range(3)]
    torch_next = torch.rand(2)
    assert torch.equal(torch_next, torch_draws[0])

    # numpy
    np.random.rand(1)
    with isolate_rng():
        numpy_draws = [np.random.rand(2) for _ in range(3)]
    numpy_next = np.random.rand(2)
    assert np.equal(numpy_next, numpy_draws[0]).all()

    # python
    random.random()
    with isolate_rng():
        python_draws = [random.random() for _ in range(3)]
    python_next = random.random()
    assert python_next == python_draws[0]
|
2022-08-09 13:40:30 +00:00
|
|
|
|
|
|
|
|
|
|
|
@mock.patch("pytorch_lightning.utilities.seed.log.info")
@pytest.mark.parametrize("env_vars", [{"RANK": "0"}, {"RANK": "1"}, {"RANK": "4"}])
def test_seed_everything_log_info(log_mock: MagicMock, env_vars: Mapping[str, str]):
    """Test that the seed log message is prefixed with the correct rank info.

    Args:
        log_mock: Mock replacing ``pytorch_lightning.utilities.seed.log.info``.
        env_vars: Environment installed (with everything else cleared) while seeding.
    """
    with mock.patch.dict(os.environ, env_vars, clear=True):
        # Imported and evaluated inside the patched environment, as in the original test.
        from pytorch_lightning.utilities.rank_zero import _get_rank

        rank = _get_rank()
        seed_utils.seed_everything(123)

    expected_message = f"[rank: {rank}] Global seed set to 123"
    log_mock.assert_called_once_with(expected_message)
|