Force tagger to pass batch of docs into model in begin_training

This commit is contained in:
Matthew Honnibal 2020-08-27 03:21:03 +02:00
parent 9b22714a4e
commit 95adb58f15
1 changed file with 6 additions and 1 deletion

View File

@@ -275,13 +275,18 @@ class Tagger(Pipe):
err = Errors.E930.format(name="Tagger", obj=type(get_examples)) err = Errors.E930.format(name="Tagger", obj=type(get_examples))
raise ValueError(err) raise ValueError(err)
tags = set() tags = set()
doc_sample = []
for example in get_examples(): for example in get_examples():
for token in example.y: for token in example.y:
tags.add(token.tag_) tags.add(token.tag_)
if len(doc_sample) < 10:
doc_sample.append(example.x)
if not doc_sample:
doc_sample.append(Doc(self.vocab, words=["hello"]))
for tag in sorted(tags): for tag in sorted(tags):
self.add_label(tag) self.add_label(tag)
self.set_output(len(self.labels)) self.set_output(len(self.labels))
self.model.initialize() self.model.initialize(X=doc_sample)
if sgd is None: if sgd is None:
sgd = self.create_optimizer() sgd = self.create_optimizer()
return sgd return sgd