diff --git a/pytorch_lightning/utilities/seed.py b/pytorch_lightning/utilities/seed.py index 1ce782f967..16bc39bd7f 100644 --- a/pytorch_lightning/utilities/seed.py +++ b/pytorch_lightning/utilities/seed.py @@ -20,8 +20,8 @@ from typing import Optional import numpy as np import torch - from pytorch_lightning import _logger as log +from pytorch_lightning.utilities import rank_zero_warn def seed_everything(seed: Optional[int] = None) -> int: @@ -41,18 +41,17 @@ def seed_everything(seed: Optional[int] = None) -> int: try: if seed is None: - seed = os.environ.get("PL_GLOBAL_SEED", _select_seed_randomly(min_seed_value, max_seed_value)) + seed = os.environ.get("PL_GLOBAL_SEED") seed = int(seed) except (TypeError, ValueError): seed = _select_seed_randomly(min_seed_value, max_seed_value) + rank_zero_warn(f"No correct seed found, seed set to {seed}") - if (seed > max_seed_value) or (seed < min_seed_value): - log.warning( - f"{seed} is not in bounds, \ - numpy accepts from {min_seed_value} to {max_seed_value}" - ) + if not (min_seed_value <= seed <= max_seed_value): + rank_zero_warn(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}") seed = _select_seed_randomly(min_seed_value, max_seed_value) + log.info(f"Global seed set to {seed}") os.environ["PL_GLOBAL_SEED"] = str(seed) random.seed(seed) np.random.seed(seed) @@ -62,6 +61,4 @@ def seed_everything(seed: Optional[int] = None) -> int: def _select_seed_randomly(min_seed_value: int = 0, max_seed_value: int = 255) -> int: - seed = random.randint(min_seed_value, max_seed_value) - log.warning(f"No correct seed found, seed set to {seed}") - return seed + return random.randint(min_seed_value, max_seed_value) diff --git a/tests/utilities/test_seed.py b/tests/utilities/test_seed.py new file mode 100644 index 0000000000..7fa6df516c --- /dev/null +++ b/tests/utilities/test_seed.py @@ -0,0 +1,55 @@ +import os + +from unittest import mock +import pytest + +import pytorch_lightning.utilities.seed as seed_utils + + +@mock.patch.dict(os.environ, {}, clear=True) +def test_seed_stays_same_with_multiple_seed_everything_calls(): + """ + Ensure that after the initial seed everything, + the seed stays the same for the same run. + """ + with pytest.warns(UserWarning, match="No correct seed found"): + seed_utils.seed_everything() + initial_seed = os.environ.get("PL_GLOBAL_SEED") + + with pytest.warns(None) as record: + seed_utils.seed_everything() + assert not record # does not warn + seed = os.environ.get("PL_GLOBAL_SEED") + + assert initial_seed == seed + + +@mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "2020"}, clear=True) +def test_correct_seed_with_environment_variable(): + """ + Ensure that the PL_GLOBAL_SEED environment is read + """ + assert seed_utils.seed_everything() == 2020 + + +@mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "invalid"}, clear=True) +@mock.patch.object(seed_utils, attribute='_select_seed_randomly', new=lambda *_: 123) +def test_invalid_seed(): + """ + Ensure that we still fix the seed even if an invalid seed is given + """ + with pytest.warns(UserWarning, match="No correct seed found"): + seed = seed_utils.seed_everything() + assert seed == 123 + + +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch.object(seed_utils, attribute='_select_seed_randomly', new=lambda *_: 123) +@pytest.mark.parametrize("seed", (10e9, -10e9)) +def test_out_of_bounds_seed(seed): + """ + Ensure that we still fix the seed even if an out-of-bounds seed is given + """ + with pytest.warns(UserWarning, match="is not in bounds"): + actual = seed_utils.seed_everything(seed) + assert actual == 123