mirror of https://github.com/explosion/spaCy.git
Add itershuffle utility function. Maybe belongs in thinc
This commit is contained in:
parent
3b7c108246
commit
0731971bfc
|
@ -9,6 +9,7 @@ import regex as re
|
|||
from pathlib import Path
|
||||
import sys
|
||||
import textwrap
|
||||
import random
|
||||
|
||||
from .symbols import ORTH
|
||||
from .compat import cupy, CudaStream, path2str, basestring_, input_, unicode_
|
||||
|
@ -172,6 +173,31 @@ def get_async(stream, numpy_array):
|
|||
array.set(numpy_array, stream=stream)
|
||||
return array
|
||||
|
||||
def itershuffle(iterable, bufsize=1000):
|
||||
"""Shuffle an iterator. This works by holding `bufsize` items back
|
||||
and yielding them sometime later. Obviously, this is not unbiased --
|
||||
but should be good enough for batching. Larger bufsize means less bias.
|
||||
|
||||
From https://gist.github.com/andres-erbsen/1307752
|
||||
"""
|
||||
iterable = iter(iterable)
|
||||
buf = []
|
||||
try:
|
||||
while True:
|
||||
for i in range(random.randint(1, bufsize-len(buf))):
|
||||
buf.append(iterable.next())
|
||||
random.shuffle(buf)
|
||||
for i in range(random.randint(1, bufsize)):
|
||||
if buf:
|
||||
yield buf.pop()
|
||||
else:
|
||||
break
|
||||
except StopIteration:
|
||||
random.shuffle(buf)
|
||||
while buf:
|
||||
yield buf.pop()
|
||||
raise StopIteration
|
||||
|
||||
|
||||
def env_opt(name, default=None):
|
||||
if type(default) is float:
|
||||
|
|
Loading…
Reference in New Issue