mirror of https://github.com/explosion/spaCy.git
remove labels from constructor
This commit is contained in:
parent
fcd79e0655
commit
d5a920325f
|
@ -47,7 +47,7 @@ class MultitaskObjective(Tagger):
|
|||
side-objective.
|
||||
"""
|
||||
|
||||
def __init__(self, vocab, model, name="nn_labeller", *, labels, target):
|
||||
def __init__(self, vocab, model, name="nn_labeller", *, target):
|
||||
self.vocab = vocab
|
||||
self.model = model
|
||||
self.name = name
|
||||
|
@ -67,7 +67,7 @@ class MultitaskObjective(Tagger):
|
|||
self.make_label = target
|
||||
else:
|
||||
raise ValueError(Errors.E016)
|
||||
cfg = {"labels": labels or {}, "target": target}
|
||||
cfg = {"labels": {}, "target": target}
|
||||
self.cfg = dict(cfg)
|
||||
|
||||
@property
|
||||
|
@ -81,15 +81,18 @@ class MultitaskObjective(Tagger):
|
|||
def set_annotations(self, docs, dep_ids):
|
||||
pass
|
||||
|
||||
def initialize(self, get_examples, nlp=None):
|
||||
def initialize(self, get_examples, nlp=None, labels=None):
|
||||
if not hasattr(get_examples, "__call__"):
|
||||
err = Errors.E930.format(name="MultitaskObjective", obj=type(get_examples))
|
||||
raise ValueError(err)
|
||||
for example in get_examples():
|
||||
for token in example.y:
|
||||
label = self.make_label(token)
|
||||
if label is not None and label not in self.labels:
|
||||
self.labels[label] = len(self.labels)
|
||||
if labels is not None:
|
||||
self.labels = labels
|
||||
else:
|
||||
for example in get_examples():
|
||||
for token in example.y:
|
||||
label = self.make_label(token)
|
||||
if label is not None and label not in self.labels:
|
||||
self.labels[label] = len(self.labels)
|
||||
self.model.initialize() # TODO: fix initialization by defining X and Y
|
||||
|
||||
def predict(self, docs):
|
||||
|
|
|
@ -61,14 +61,13 @@ class Tagger(TrainablePipe):
|
|||
|
||||
DOCS: https://nightly.spacy.io/api/tagger
|
||||
"""
|
||||
def __init__(self, vocab, model, name="tagger", *, labels=None):
|
||||
def __init__(self, vocab, model, name="tagger"):
|
||||
"""Initialize a part-of-speech tagger.
|
||||
|
||||
vocab (Vocab): The shared vocabulary.
|
||||
model (thinc.api.Model): The Thinc Model powering the pipeline component.
|
||||
name (str): The component instance name, used to add entries to the
|
||||
losses during training.
|
||||
labels (List): The set of labels. Defaults to None.
|
||||
|
||||
DOCS: https://nightly.spacy.io/api/tagger#init
|
||||
"""
|
||||
|
@ -76,7 +75,7 @@ class Tagger(TrainablePipe):
|
|||
self.model = model
|
||||
self.name = name
|
||||
self._rehearsal_model = None
|
||||
cfg = {"labels": labels or []}
|
||||
cfg = {"labels": []}
|
||||
self.cfg = dict(sorted(cfg.items()))
|
||||
|
||||
@property
|
||||
|
|
Loading…
Reference in New Issue