From 1ac507a255b551788b1660c23900000f5f2605a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 9 Aug 2020 23:23:01 +0200 Subject: [PATCH] constant root seed in reset_seed (tests) (#2895) * fix root_seed in reset_seed * seed value --- tests/__init__.py | 4 ---- tests/base/develop_utils.py | 5 ++--- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/tests/__init__.py b/tests/__init__.py index acc27596f9..29c6a4a633 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -9,10 +9,6 @@ TEMP_PATH = os.path.join(PACKAGE_ROOT, 'test_temp') # generate a list of random seeds for each test RANDOM_PORTS = list(np.random.randint(12000, 19000, 1000)) -ROOT_SEED = 1234 -torch.manual_seed(ROOT_SEED) -np.random.seed(ROOT_SEED) -RANDOM_SEEDS = list(np.random.randint(0, 10000, 1000)) if not os.path.isdir(TEMP_PATH): os.mkdir(TEMP_PATH) diff --git a/tests/base/develop_utils.py b/tests/base/develop_utils.py index 37fde1d872..bb39c3887a 100644 --- a/tests/base/develop_utils.py +++ b/tests/base/develop_utils.py @@ -7,7 +7,7 @@ import numpy as np from pytorch_lightning import seed_everything from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.loggers import TensorBoardLogger, TestTubeLogger -from tests import TEMP_PATH, RANDOM_PORTS, RANDOM_SEEDS +from tests import TEMP_PATH, RANDOM_PORTS from tests.base.model_template import EvalModelTemplate @@ -72,8 +72,7 @@ def assert_ok_model_acc(trainer, key='test_acc', thr=0.5): assert acc > thr, f"Model failed to get expected {thr} accuracy. {key} = {acc}" -def reset_seed(): - seed = RANDOM_SEEDS.pop() +def reset_seed(seed=0): seed_everything(seed)