diff --git a/spacy/util.py b/spacy/util.py index 0a682fcaa..ea662d3a3 100644 --- a/spacy/util.py +++ b/spacy/util.py @@ -15,6 +15,11 @@ import itertools import numpy.random import srsly +try: + import cupy.random +except ImportError: + cupy = None + from .symbols import ORTH from .compat import cupy, CudaStream, path2str, basestring_, unicode_ from .compat import import_file @@ -598,6 +603,8 @@ def use_gpu(gpu_id): def fix_random_seed(seed=0): random.seed(seed) numpy.random.seed(seed) + if cupy is not None: + cupy.random.seed(seed) class SimpleFrozenDict(dict):