diff --git a/spacy/util.py b/spacy/util.py index f481acb5f..2de90e558 100644 --- a/spacy/util.py +++ b/spacy/util.py @@ -9,6 +9,7 @@ import regex as re from pathlib import Path import sys import textwrap +import random from .symbols import ORTH from .compat import cupy, CudaStream, path2str, basestring_, input_, unicode_ @@ -172,6 +173,31 @@ def get_async(stream, numpy_array): array.set(numpy_array, stream=stream) return array +def itershuffle(iterable, bufsize=1000): + """Shuffle an iterator. This works by holding `bufsize` items back + and yielding them sometime later. Obviously, this is not unbiased -- + but should be good enough for batching. Larger bufsize means less bias. + + From https://gist.github.com/andres-erbsen/1307752 + """ + iterable = iter(iterable) + buf = [] + try: + while True: + for i in range(random.randint(1, bufsize-len(buf))): + buf.append(iterable.next()) + random.shuffle(buf) + for i in range(random.randint(1, bufsize)): + if buf: + yield buf.pop() + else: + break + except StopIteration: + random.shuffle(buf) + while buf: + yield buf.pop() + raise StopIteration + def env_opt(name, default=None): if type(default) is float: