diff --git a/spacy/cli/train.py b/spacy/cli/train.py index d9ab8eca5..2bfa5c56e 100644 --- a/spacy/cli/train.py +++ b/spacy/cli/train.py @@ -186,18 +186,12 @@ def train( def create_train_batches(iterator, batcher, max_epochs: int): - epoch = 1 - examples = [] - # Stream the first epoch, so we start training faster and support - # infinite streams. - for batch in batcher(iterator): - yield epoch, batch - if max_epochs != 1: - examples.extend(batch) + epoch = 0 + examples = list(iterator) if not examples: # Raise error if no data raise ValueError(Errors.E986) - while epoch != max_epochs: + while max_epochs < 1 or epoch != max_epochs: random.shuffle(examples) for batch in batcher(examples): yield epoch, batch