diff --git a/spacy/cli/evaluate.py b/spacy/cli/evaluate.py index 551689413..43edd858d 100644 --- a/spacy/cli/evaluate.py +++ b/spacy/cli/evaluate.py @@ -3,8 +3,6 @@ from __future__ import unicode_literals, division, print_function import plac from timeit import default_timer as timer -import random -import numpy.random from ..gold import GoldCorpus from ..util import prints @@ -12,10 +10,6 @@ from .. import util from .. import displacy -random.seed(0) -numpy.random.seed(0) - - @plac.annotations( model=("model name or path", "positional", None, str), data_path=("location of JSON-formatted evaluation data", "positional", @@ -31,6 +25,8 @@ def evaluate(model, data_path, gpu_id=-1, gold_preproc=False, displacy_path=None Evaluate a model. To render a sample of parses in a HTML file, set an output directory as the displacy_path argument. """ + + util.fix_random_seed() if gpu_id >= 0: util.use_gpu(gpu_id) util.set_env_log(False) diff --git a/spacy/cli/train.py b/spacy/cli/train.py index f8363bde1..6c7b95682 100644 --- a/spacy/cli/train.py +++ b/spacy/cli/train.py @@ -6,8 +6,6 @@ from pathlib import Path import tqdm from thinc.neural._classes.model import Model from timeit import default_timer as timer -import random -import numpy.random from ..gold import GoldCorpus, minibatch from ..util import prints @@ -16,9 +14,6 @@ from .. import about from .. import displacy from ..compat import json_dumps -random.seed(0) -numpy.random.seed(0) - @plac.annotations( lang=("model language", "positional", None, str), @@ -45,6 +40,7 @@ def train(lang, output_dir, train_data, dev_data, n_iter=30, n_sents=0, """ Train a model. Expects data in spaCy's JSON format. """ + util.fix_random_seed() util.set_env_log(True) n_sents = n_sents or None output_path = util.ensure_path(output_dir) diff --git a/spacy/util.py b/spacy/util.py index 7676b33b2..b42ca3734 100644 --- a/spacy/util.py +++ b/spacy/util.py @@ -17,6 +17,7 @@ from thinc.neural._classes.model import Model import functools import cytoolz import itertools +import numpy as np from .symbols import ORTH from .compat import cupy, CudaStream, path2str, basestring_, input_, unicode_ @@ -623,3 +624,8 @@ def use_gpu(gpu_id): Model.ops = CupyOps() Model.Ops = CupyOps return device + + +def fix_random_seed(seed=0): + random.seed(0) + np.random.seed(0)