constant root seed in reset_seed (tests) (#2895)

* fix root_seed in reset_seed

* seed value
This commit is contained in:
Adrian Wälchli 2020-08-09 23:23:01 +02:00 committed by GitHub
parent 4d3dfd43e4
commit 1ac507a255
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 2 additions and 7 deletions

View File

@ -9,10 +9,6 @@ TEMP_PATH = os.path.join(PACKAGE_ROOT, 'test_temp')
# generate a list of random seeds for each test
RANDOM_PORTS = list(np.random.randint(12000, 19000, 1000))
ROOT_SEED = 1234
torch.manual_seed(ROOT_SEED)
np.random.seed(ROOT_SEED)
RANDOM_SEEDS = list(np.random.randint(0, 10000, 1000))
if not os.path.isdir(TEMP_PATH):
os.mkdir(TEMP_PATH)

View File

@ -7,7 +7,7 @@ import numpy as np
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger, TestTubeLogger
from tests import TEMP_PATH, RANDOM_PORTS, RANDOM_SEEDS
from tests import TEMP_PATH, RANDOM_PORTS
from tests.base.model_template import EvalModelTemplate
@ -72,8 +72,7 @@ def assert_ok_model_acc(trainer, key='test_acc', thr=0.5):
assert acc > thr, f"Model failed to get expected {thr} accuracy. {key} = {acc}"
def reset_seed():
seed = RANDOM_SEEDS.pop()
def reset_seed(seed=0):
seed_everything(seed)