From d5a920325f25939721f5895e3367c274cf1ecfe6 Mon Sep 17 00:00:00 2001 From: svlandeg Date: Wed, 11 Nov 2020 21:34:12 +0100 Subject: [PATCH] remove labels from constructor --- spacy/pipeline/multitask.pyx | 19 +++++++++++-------- spacy/pipeline/tagger.pyx | 5 ++--- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/spacy/pipeline/multitask.pyx b/spacy/pipeline/multitask.pyx index e1ea49849..9c7bb5914 100644 --- a/spacy/pipeline/multitask.pyx +++ b/spacy/pipeline/multitask.pyx @@ -47,7 +47,7 @@ class MultitaskObjective(Tagger): side-objective. """ - def __init__(self, vocab, model, name="nn_labeller", *, labels, target): + def __init__(self, vocab, model, name="nn_labeller", *, target): self.vocab = vocab self.model = model self.name = name @@ -67,7 +67,7 @@ class MultitaskObjective(Tagger): self.make_label = target else: raise ValueError(Errors.E016) - cfg = {"labels": labels or {}, "target": target} + cfg = {"labels": {}, "target": target} self.cfg = dict(cfg) @property @@ -81,15 +81,18 @@ class MultitaskObjective(Tagger): def set_annotations(self, docs, dep_ids): pass - def initialize(self, get_examples, nlp=None): + def initialize(self, get_examples, nlp=None, labels=None): if not hasattr(get_examples, "__call__"): err = Errors.E930.format(name="MultitaskObjective", obj=type(get_examples)) raise ValueError(err) - for example in get_examples(): - for token in example.y: - label = self.make_label(token) - if label is not None and label not in self.labels: - self.labels[label] = len(self.labels) + if labels is not None: + self.labels = labels + else: + for example in get_examples(): + for token in example.y: + label = self.make_label(token) + if label is not None and label not in self.labels: + self.labels[label] = len(self.labels) self.model.initialize() # TODO: fix initialization by defining X and Y def predict(self, docs): diff --git a/spacy/pipeline/tagger.pyx b/spacy/pipeline/tagger.pyx index 16633a7b8..08f09b002 100644 --- a/spacy/pipeline/tagger.pyx +++ b/spacy/pipeline/tagger.pyx @@ -61,14 +61,13 @@ class Tagger(TrainablePipe): DOCS: https://nightly.spacy.io/api/tagger """ - def __init__(self, vocab, model, name="tagger", *, labels=None): + def __init__(self, vocab, model, name="tagger"): """Initialize a part-of-speech tagger. vocab (Vocab): The shared vocabulary. model (thinc.api.Model): The Thinc Model powering the pipeline component. name (str): The component instance name, used to add entries to the losses during training. - labels (List): The set of labels. Defaults to None. DOCS: https://nightly.spacy.io/api/tagger#init """ @@ -76,7 +75,7 @@ class Tagger(TrainablePipe): self.model = model self.name = name self._rehearsal_model = None - cfg = {"labels": labels or []} + cfg = {"labels": []} self.cfg = dict(sorted(cfg.items())) @property