remove labels from constructor

This commit is contained in:
svlandeg 2020-11-11 21:34:12 +01:00
parent fcd79e0655
commit d5a920325f
2 changed files with 13 additions and 11 deletions

View File

@ -47,7 +47,7 @@ class MultitaskObjective(Tagger):
side-objective.
"""
def __init__(self, vocab, model, name="nn_labeller", *, labels, target):
def __init__(self, vocab, model, name="nn_labeller", *, target):
self.vocab = vocab
self.model = model
self.name = name
@ -67,7 +67,7 @@ class MultitaskObjective(Tagger):
self.make_label = target
else:
raise ValueError(Errors.E016)
cfg = {"labels": labels or {}, "target": target}
cfg = {"labels": {}, "target": target}
self.cfg = dict(cfg)
@property
@ -81,15 +81,18 @@ class MultitaskObjective(Tagger):
def set_annotations(self, docs, dep_ids):
pass
def initialize(self, get_examples, nlp=None):
def initialize(self, get_examples, nlp=None, labels=None):
if not hasattr(get_examples, "__call__"):
err = Errors.E930.format(name="MultitaskObjective", obj=type(get_examples))
raise ValueError(err)
for example in get_examples():
for token in example.y:
label = self.make_label(token)
if label is not None and label not in self.labels:
self.labels[label] = len(self.labels)
if labels is not None:
self.labels = labels
else:
for example in get_examples():
for token in example.y:
label = self.make_label(token)
if label is not None and label not in self.labels:
self.labels[label] = len(self.labels)
self.model.initialize() # TODO: fix initialization by defining X and Y
def predict(self, docs):

View File

@ -61,14 +61,13 @@ class Tagger(TrainablePipe):
DOCS: https://nightly.spacy.io/api/tagger
"""
def __init__(self, vocab, model, name="tagger", *, labels=None):
def __init__(self, vocab, model, name="tagger"):
"""Initialize a part-of-speech tagger.
vocab (Vocab): The shared vocabulary.
model (thinc.api.Model): The Thinc Model powering the pipeline component.
name (str): The component instance name, used to add entries to the
losses during training.
labels (List): The set of labels. Defaults to None.
DOCS: https://nightly.spacy.io/api/tagger#init
"""
@ -76,7 +75,7 @@ class Tagger(TrainablePipe):
self.model = model
self.name = name
self._rehearsal_model = None
cfg = {"labels": labels or []}
cfg = {"labels": []}
self.cfg = dict(sorted(cfg.items()))
@property