mirror of https://github.com/explosion/spaCy.git
more friendly textcat errors (#3946)
* more friendly textcat errors with require_model and require_labels * update thinc version with recent bugfix
This commit is contained in:
parent
b94c5443d9
commit
c4c21cb428
|
@ -1,7 +1,7 @@
|
||||||
# Our libraries
|
# Our libraries
|
||||||
cymem>=2.0.2,<2.1.0
|
cymem>=2.0.2,<2.1.0
|
||||||
preshed>=2.0.1,<2.1.0
|
preshed>=2.0.1,<2.1.0
|
||||||
thinc>=7.0.2,<7.1.0
|
thinc>=7.0.5,<7.1.0
|
||||||
blis>=0.2.2,<0.3.0
|
blis>=0.2.2,<0.3.0
|
||||||
murmurhash>=0.28.0,<1.1.0
|
murmurhash>=0.28.0,<1.1.0
|
||||||
wasabi>=0.2.0,<1.1.0
|
wasabi>=0.2.0,<1.1.0
|
||||||
|
|
|
@ -403,6 +403,7 @@ class Errors(object):
|
||||||
E140 = ("The list of entities, prior probabilities and entity vectors should be of equal length.")
|
E140 = ("The list of entities, prior probabilities and entity vectors should be of equal length.")
|
||||||
E141 = ("Entity vectors should be of length {required} instead of the provided {found}.")
|
E141 = ("Entity vectors should be of length {required} instead of the provided {found}.")
|
||||||
E142 = ("Unsupported loss_function '{loss_func}'. Use either 'L2' or 'cosine'")
|
E142 = ("Unsupported loss_function '{loss_func}'. Use either 'L2' or 'cosine'")
|
||||||
|
E143 = ("Labels for component '{name}' not initialized. Did you forget to call add_label()?")
|
||||||
|
|
||||||
|
|
||||||
@add_codes
|
@add_codes
|
||||||
|
|
|
@ -902,6 +902,11 @@ class TextCategorizer(Pipe):
|
||||||
def labels(self):
|
def labels(self):
|
||||||
return tuple(self.cfg.setdefault("labels", []))
|
return tuple(self.cfg.setdefault("labels", []))
|
||||||
|
|
||||||
|
def require_labels(self):
|
||||||
|
"""Raise an error if the component's model has no labels defined."""
|
||||||
|
if not self.labels:
|
||||||
|
raise ValueError(Errors.E143.format(name=self.name))
|
||||||
|
|
||||||
@labels.setter
|
@labels.setter
|
||||||
def labels(self, value):
|
def labels(self, value):
|
||||||
self.cfg["labels"] = tuple(value)
|
self.cfg["labels"] = tuple(value)
|
||||||
|
@ -931,6 +936,7 @@ class TextCategorizer(Pipe):
|
||||||
doc.cats[label] = float(scores[i, j])
|
doc.cats[label] = float(scores[i, j])
|
||||||
|
|
||||||
def update(self, docs, golds, state=None, drop=0., sgd=None, losses=None):
|
def update(self, docs, golds, state=None, drop=0., sgd=None, losses=None):
|
||||||
|
self.require_model()
|
||||||
scores, bp_scores = self.model.begin_update(docs, drop=drop)
|
scores, bp_scores = self.model.begin_update(docs, drop=drop)
|
||||||
loss, d_scores = self.get_loss(docs, golds, scores)
|
loss, d_scores = self.get_loss(docs, golds, scores)
|
||||||
bp_scores(d_scores, sgd=sgd)
|
bp_scores(d_scores, sgd=sgd)
|
||||||
|
@ -985,6 +991,7 @@ class TextCategorizer(Pipe):
|
||||||
def begin_training(self, get_gold_tuples=lambda: [], pipeline=None, sgd=None, **kwargs):
|
def begin_training(self, get_gold_tuples=lambda: [], pipeline=None, sgd=None, **kwargs):
|
||||||
if self.model is True:
|
if self.model is True:
|
||||||
self.cfg["pretrained_vectors"] = kwargs.get("pretrained_vectors")
|
self.cfg["pretrained_vectors"] = kwargs.get("pretrained_vectors")
|
||||||
|
self.require_labels()
|
||||||
self.model = self.Model(len(self.labels), **self.cfg)
|
self.model = self.Model(len(self.labels), **self.cfg)
|
||||||
link_vectors_to_models(self.vocab)
|
link_vectors_to_models(self.vocab)
|
||||||
if sgd is None:
|
if sgd is None:
|
||||||
|
|
Loading…
Reference in New Issue