diff --git a/spacy/cli/pretrain.py b/spacy/cli/pretrain.py index ada224b3b..456739bf5 100644 --- a/spacy/cli/pretrain.py +++ b/spacy/cli/pretrain.py @@ -78,7 +78,7 @@ def make_update(model, docs, optimizer, drop=0.): return loss -def make_docs(nlp, batch): +def make_docs(nlp, batch, min_length=1, max_length=500): docs = [] for record in batch: text = record["text"] @@ -91,7 +91,7 @@ def make_docs(nlp, batch): heads = numpy.asarray(heads, dtype="uint64") heads = heads.reshape((len(doc), 1)) doc = doc.from_array([HEAD], heads) - if len(doc) >= 1 and len(doc) < 200: + if len(doc) >= min_length and len(doc) < max_length: docs.append(doc) return docs