From 95adb58f155dd3cdaba244db38be35c8ac837c33 Mon Sep 17 00:00:00 2001
From: Matthew Honnibal
Date: Thu, 27 Aug 2020 03:21:03 +0200
Subject: [PATCH] Force tagger to pass batch of docs into model in begin_training

---
 spacy/pipeline/tagger.pyx | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/spacy/pipeline/tagger.pyx b/spacy/pipeline/tagger.pyx
index 2277aaf75..2255a585a 100644
--- a/spacy/pipeline/tagger.pyx
+++ b/spacy/pipeline/tagger.pyx
@@ -275,13 +275,18 @@ class Tagger(Pipe):
             err = Errors.E930.format(name="Tagger", obj=type(get_examples))
             raise ValueError(err)
         tags = set()
+        doc_sample = []
         for example in get_examples():
             for token in example.y:
                 tags.add(token.tag_)
+            if len(doc_sample) < 10:
+                doc_sample.append(example.x)
+        if not doc_sample:
+            doc_sample.append(Doc(self.vocab, words=["hello"]))
         for tag in sorted(tags):
             self.add_label(tag)
         self.set_output(len(self.labels))
-        self.model.initialize()
+        self.model.initialize(X=doc_sample)
         if sgd is None:
             sgd = self.create_optimizer()
         return sgd