constant root seed in reset_seed (tests) (#2895)
* fix root_seed in reset_seed * seed value
This commit is contained in:
parent
4d3dfd43e4
commit
1ac507a255
|
@ -9,10 +9,6 @@ TEMP_PATH = os.path.join(PACKAGE_ROOT, 'test_temp')
|
|||
|
||||
# generate a list of random seeds for each test
|
||||
RANDOM_PORTS = list(np.random.randint(12000, 19000, 1000))
|
||||
ROOT_SEED = 1234
|
||||
torch.manual_seed(ROOT_SEED)
|
||||
np.random.seed(ROOT_SEED)
|
||||
RANDOM_SEEDS = list(np.random.randint(0, 10000, 1000))
|
||||
|
||||
if not os.path.isdir(TEMP_PATH):
|
||||
os.mkdir(TEMP_PATH)
|
||||
|
|
|
@ -7,7 +7,7 @@ import numpy as np
|
|||
from pytorch_lightning import seed_everything
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
from pytorch_lightning.loggers import TensorBoardLogger, TestTubeLogger
|
||||
from tests import TEMP_PATH, RANDOM_PORTS, RANDOM_SEEDS
|
||||
from tests import TEMP_PATH, RANDOM_PORTS
|
||||
from tests.base.model_template import EvalModelTemplate
|
||||
|
||||
|
||||
|
@ -72,8 +72,7 @@ def assert_ok_model_acc(trainer, key='test_acc', thr=0.5):
|
|||
assert acc > thr, f"Model failed to get expected {thr} accuracy. {key} = {acc}"
|
||||
|
||||
|
||||
def reset_seed():
|
||||
seed = RANDOM_SEEDS.pop()
|
||||
def reset_seed(seed=0):
|
||||
seed_everything(seed)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue