[BUG] Check environ before selecting a seed to prevent warning message (#4743)

* Check environment var independently to selecting a seed to prevent unnecessary warning message

* Add if statement to check if PL_GLOBAL_SEED has been set

* Added seed test to ensure that the seed stays the same, in case

* if

* Delete global seed after test has finished

* Fix code, add tests

* Ensure seed does not exist before tests start

* Refactor test based on review, add log call

* Ensure we clear the os environ in patched dict

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: chaton <thomas@grid.ai>
(cherry picked from commit 635df27880)
This commit is contained in:
Sean Naren 2021-01-12 04:30:27 +00:00 committed by Jirka Borovec
parent ee934de824
commit 0c370ade51
2 changed files with 62 additions and 10 deletions

View File

@ -20,8 +20,8 @@ from typing import Optional
import numpy as np
import torch
from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import rank_zero_warn
def seed_everything(seed: Optional[int] = None) -> int:
@ -41,18 +41,17 @@ def seed_everything(seed: Optional[int] = None) -> int:
try:
if seed is None:
seed = os.environ.get("PL_GLOBAL_SEED", _select_seed_randomly(min_seed_value, max_seed_value))
seed = os.environ.get("PL_GLOBAL_SEED")
seed = int(seed)
except (TypeError, ValueError):
seed = _select_seed_randomly(min_seed_value, max_seed_value)
rank_zero_warn(f"No correct seed found, seed set to {seed}")
if (seed > max_seed_value) or (seed < min_seed_value):
log.warning(
f"{seed} is not in bounds, \
numpy accepts from {min_seed_value} to {max_seed_value}"
)
if not (min_seed_value <= seed <= max_seed_value):
rank_zero_warn(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}")
seed = _select_seed_randomly(min_seed_value, max_seed_value)
log.info(f"Global seed set to {seed}")
os.environ["PL_GLOBAL_SEED"] = str(seed)
random.seed(seed)
np.random.seed(seed)
@ -62,6 +61,4 @@ def seed_everything(seed: Optional[int] = None) -> int:
def _select_seed_randomly(min_seed_value: int = 0, max_seed_value: int = 255) -> int:
seed = random.randint(min_seed_value, max_seed_value)
log.warning(f"No correct seed found, seed set to {seed}")
return seed
return random.randint(min_seed_value, max_seed_value)

View File

@ -0,0 +1,55 @@
import os
from unittest import mock
import pytest
import pytorch_lightning.utilities.seed as seed_utils
@mock.patch.dict(os.environ, {}, clear=True)
def test_seed_stays_same_with_multiple_seed_everything_calls():
"""
Ensure that after the initial seed everything,
the seed stays the same for the same run.
"""
with pytest.warns(UserWarning, match="No correct seed found"):
seed_utils.seed_everything()
initial_seed = os.environ.get("PL_GLOBAL_SEED")
with pytest.warns(None) as record:
seed_utils.seed_everything()
assert not record # does not warn
seed = os.environ.get("PL_GLOBAL_SEED")
assert initial_seed == seed
@mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "2020"}, clear=True)
def test_correct_seed_with_environment_variable():
"""
Ensure that the PL_GLOBAL_SEED environment is read
"""
assert seed_utils.seed_everything() == 2020
@mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "invalid"}, clear=True)
@mock.patch.object(seed_utils, attribute='_select_seed_randomly', new=lambda *_: 123)
def test_invalid_seed():
"""
Ensure that we still fix the seed even if an invalid seed is given
"""
with pytest.warns(UserWarning, match="No correct seed found"):
seed = seed_utils.seed_everything()
assert seed == 123
@mock.patch.dict(os.environ, {}, clear=True)
@mock.patch.object(seed_utils, attribute='_select_seed_randomly', new=lambda *_: 123)
@pytest.mark.parametrize("seed", (10e9, -10e9))
def test_out_of_bounds_seed(seed):
"""
Ensure that we still fix the seed even if an out-of-bounds seed is given
"""
with pytest.warns(UserWarning, match="is not in bounds"):
actual = seed_utils.seed_everything(seed)
assert actual == 123