mirror of https://github.com/explosion/spaCy.git
Use labels in tagger
This commit is contained in:
parent
ca72608059
commit
99bff78617
|
@@ -266,7 +266,7 @@ class Tagger(Pipe):
|
|||
raise ValueError("nan value when computing loss")
|
||||
return float(loss), d_scores
|
||||
|
||||
def initialize(self, get_examples, *, nlp=None):
|
||||
def initialize(self, get_examples, *, nlp=None, labels=None):
|
||||
"""Initialize the pipe for training, using a representative set
|
||||
of data examples.
|
||||
|
||||
|
@@ -277,15 +277,19 @@ class Tagger(Pipe):
|
|||
DOCS: https://nightly.spacy.io/api/tagger#initialize
|
||||
"""
|
||||
self._ensure_examples(get_examples)
|
||||
if labels is not None:
|
||||
for tag in labels:
|
||||
self.add_label(tag)
|
||||
else:
|
||||
tags = set()
|
||||
for example in get_examples():
|
||||
for token in example.y:
|
||||
if token.tag_:
|
||||
tags.add(token.tag_)
|
||||
for tag in sorted(tags):
|
||||
self.add_label(tag)
|
||||
doc_sample = []
|
||||
label_sample = []
|
||||
tags = set()
|
||||
for example in get_examples():
|
||||
for token in example.y:
|
||||
if token.tag_:
|
||||
tags.add(token.tag_)
|
||||
for tag in sorted(tags):
|
||||
self.add_label(tag)
|
||||
for example in islice(get_examples(), 10):
|
||||
doc_sample.append(example.x)
|
||||
gold_tags = example.get_aligned("TAG", as_string=True)
|
||||
|
|
Loading…
Reference in New Issue