diff --git a/spacy/pipeline/tagger.pyx b/spacy/pipeline/tagger.pyx index 2277aaf75..2255a585a 100644 --- a/spacy/pipeline/tagger.pyx +++ b/spacy/pipeline/tagger.pyx @@ -275,13 +275,18 @@ class Tagger(Pipe): err = Errors.E930.format(name="Tagger", obj=type(get_examples)) raise ValueError(err) tags = set() + doc_sample = [] for example in get_examples(): for token in example.y: tags.add(token.tag_) + if len(doc_sample) < 10: + doc_sample.append(example.x) + if not doc_sample: + doc_sample.append(Doc(self.vocab, words=["hello"])) for tag in sorted(tags): self.add_label(tag) self.set_output(len(self.labels)) - self.model.initialize() + self.model.initialize(X=doc_sample) if sgd is None: sgd = self.create_optimizer() return sgd