Use labels in tagger

This commit is contained in:
Matthew Honnibal 2020-09-29 16:48:44 +02:00
parent ca72608059
commit 99bff78617
1 changed file with 12 additions and 8 deletions

View File

@@ -266,7 +266,7 @@ class Tagger(Pipe):
raise ValueError("nan value when computing loss") raise ValueError("nan value when computing loss")
return float(loss), d_scores return float(loss), d_scores
def initialize(self, get_examples, *, nlp=None): def initialize(self, get_examples, *, nlp=None, labels=None):
"""Initialize the pipe for training, using a representative set """Initialize the pipe for training, using a representative set
of data examples. of data examples.
@@ -277,15 +277,19 @@ class Tagger(Pipe):
DOCS: https://nightly.spacy.io/api/tagger#initialize DOCS: https://nightly.spacy.io/api/tagger#initialize
""" """
self._ensure_examples(get_examples) self._ensure_examples(get_examples)
if labels is not None:
for tag in labels:
self.add_label(tag)
else:
tags = set()
for example in get_examples():
for token in example.y:
if token.tag_:
tags.add(token.tag_)
for tag in sorted(tags):
self.add_label(tag)
doc_sample = [] doc_sample = []
label_sample = [] label_sample = []
tags = set()
for example in get_examples():
for token in example.y:
if token.tag_:
tags.add(token.tag_)
for tag in sorted(tags):
self.add_label(tag)
for example in islice(get_examples(), 10): for example in islice(get_examples(), 10):
doc_sample.append(example.x) doc_sample.append(example.x)
gold_tags = example.get_aligned("TAG", as_string=True) gold_tags = example.get_aligned("TAG", as_string=True)