diff --git a/spacy/pipeline/tagger.pyx b/spacy/pipeline/tagger.pyx index 253b6f08c..f4e8ecebd 100644 --- a/spacy/pipeline/tagger.pyx +++ b/spacy/pipeline/tagger.pyx @@ -266,7 +266,7 @@ class Tagger(Pipe): raise ValueError("nan value when computing loss") return float(loss), d_scores - def initialize(self, get_examples, *, nlp=None): + def initialize(self, get_examples, *, nlp=None, labels=None): """Initialize the pipe for training, using a representative set of data examples. @@ -277,15 +277,19 @@ class Tagger(Pipe): DOCS: https://nightly.spacy.io/api/tagger#initialize """ self._ensure_examples(get_examples) + if labels is not None: + for tag in labels: + self.add_label(tag) + else: + tags = set() + for example in get_examples(): + for token in example.y: + if token.tag_: + tags.add(token.tag_) + for tag in sorted(tags): + self.add_label(tag) doc_sample = [] label_sample = [] - tags = set() - for example in get_examples(): - for token in example.y: - if token.tag_: - tags.add(token.tag_) - for tag in sorted(tags): - self.add_label(tag) for example in islice(get_examples(), 10): doc_sample.append(example.x) gold_tags = example.get_aligned("TAG", as_string=True)