[BUG] Check environ before selecting a seed to prevent warning message (#4743)
* Check environment var independently to selecting a seed to prevent unnecessary warning message
* Add if statement to check if PL_GLOBAL_SEED has been set
* Added seed test to ensure that the seed stays the same, in case
* if
* Delete global seed after test has finished
* Fix code, add tests
* Ensure seed does not exist before tests start
* Refactor test based on review, add log call
* Ensure we clear the os environ in patched dict
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: chaton <thomas@grid.ai>
(cherry picked from commit 635df27880
)
This commit is contained in:
parent
ee934de824
commit
0c370ade51
|
@ -20,8 +20,8 @@ from typing import Optional
|
|||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from pytorch_lightning import _logger as log
|
||||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
|
||||
|
||||
def seed_everything(seed: Optional[int] = None) -> int:
|
||||
|
@ -41,18 +41,17 @@ def seed_everything(seed: Optional[int] = None) -> int:
|
|||
|
||||
try:
|
||||
if seed is None:
|
||||
seed = os.environ.get("PL_GLOBAL_SEED", _select_seed_randomly(min_seed_value, max_seed_value))
|
||||
seed = os.environ.get("PL_GLOBAL_SEED")
|
||||
seed = int(seed)
|
||||
except (TypeError, ValueError):
|
||||
seed = _select_seed_randomly(min_seed_value, max_seed_value)
|
||||
rank_zero_warn(f"No correct seed found, seed set to {seed}")
|
||||
|
||||
if (seed > max_seed_value) or (seed < min_seed_value):
|
||||
log.warning(
|
||||
f"{seed} is not in bounds, \
|
||||
numpy accepts from {min_seed_value} to {max_seed_value}"
|
||||
)
|
||||
if not (min_seed_value <= seed <= max_seed_value):
|
||||
rank_zero_warn(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}")
|
||||
seed = _select_seed_randomly(min_seed_value, max_seed_value)
|
||||
|
||||
log.info(f"Global seed set to {seed}")
|
||||
os.environ["PL_GLOBAL_SEED"] = str(seed)
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
|
@ -62,6 +61,4 @@ def seed_everything(seed: Optional[int] = None) -> int:
|
|||
|
||||
|
||||
def _select_seed_randomly(min_seed_value: int = 0, max_seed_value: int = 255) -> int:
|
||||
seed = random.randint(min_seed_value, max_seed_value)
|
||||
log.warning(f"No correct seed found, seed set to {seed}")
|
||||
return seed
|
||||
return random.randint(min_seed_value, max_seed_value)
|
||||
|
|
|
@ -0,0 +1,55 @@
|
|||
import os
|
||||
|
||||
from unittest import mock
|
||||
import pytest
|
||||
|
||||
import pytorch_lightning.utilities.seed as seed_utils
|
||||
|
||||
|
||||
@mock.patch.dict(os.environ, {}, clear=True)
|
||||
def test_seed_stays_same_with_multiple_seed_everything_calls():
|
||||
"""
|
||||
Ensure that after the initial seed everything,
|
||||
the seed stays the same for the same run.
|
||||
"""
|
||||
with pytest.warns(UserWarning, match="No correct seed found"):
|
||||
seed_utils.seed_everything()
|
||||
initial_seed = os.environ.get("PL_GLOBAL_SEED")
|
||||
|
||||
with pytest.warns(None) as record:
|
||||
seed_utils.seed_everything()
|
||||
assert not record # does not warn
|
||||
seed = os.environ.get("PL_GLOBAL_SEED")
|
||||
|
||||
assert initial_seed == seed
|
||||
|
||||
|
||||
@mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "2020"}, clear=True)
|
||||
def test_correct_seed_with_environment_variable():
|
||||
"""
|
||||
Ensure that the PL_GLOBAL_SEED environment is read
|
||||
"""
|
||||
assert seed_utils.seed_everything() == 2020
|
||||
|
||||
|
||||
@mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "invalid"}, clear=True)
|
||||
@mock.patch.object(seed_utils, attribute='_select_seed_randomly', new=lambda *_: 123)
|
||||
def test_invalid_seed():
|
||||
"""
|
||||
Ensure that we still fix the seed even if an invalid seed is given
|
||||
"""
|
||||
with pytest.warns(UserWarning, match="No correct seed found"):
|
||||
seed = seed_utils.seed_everything()
|
||||
assert seed == 123
|
||||
|
||||
|
||||
@mock.patch.dict(os.environ, {}, clear=True)
|
||||
@mock.patch.object(seed_utils, attribute='_select_seed_randomly', new=lambda *_: 123)
|
||||
@pytest.mark.parametrize("seed", (10e9, -10e9))
|
||||
def test_out_of_bounds_seed(seed):
|
||||
"""
|
||||
Ensure that we still fix the seed even if an out-of-bounds seed is given
|
||||
"""
|
||||
with pytest.warns(UserWarning, match="is not in bounds"):
|
||||
actual = seed_utils.seed_everything(seed)
|
||||
assert actual == 123
|
Loading…
Reference in New Issue