diff --git a/requirements.txt b/requirements.txt index 8cc52dfe4..58761b95c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ # Our libraries cymem>=2.0.2,<2.1.0 preshed>=2.0.1,<2.1.0 -thinc>=7.0.2,<7.1.0 +thinc>=7.0.5,<7.1.0 blis>=0.2.2,<0.3.0 murmurhash>=0.28.0,<1.1.0 wasabi>=0.2.0,<1.1.0 diff --git a/spacy/errors.py b/spacy/errors.py index 8f2eab3a1..347ad1fca 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -403,6 +403,7 @@ class Errors(object): E140 = ("The list of entities, prior probabilities and entity vectors should be of equal length.") E141 = ("Entity vectors should be of length {required} instead of the provided {found}.") E142 = ("Unsupported loss_function '{loss_func}'. Use either 'L2' or 'cosine'") + E143 = ("Labels for component '{name}' not initialized. Did you forget to call add_label()?") @add_codes diff --git a/spacy/pipeline/pipes.pyx b/spacy/pipeline/pipes.pyx index d99a1f73e..891e8d4e3 100644 --- a/spacy/pipeline/pipes.pyx +++ b/spacy/pipeline/pipes.pyx @@ -902,6 +902,11 @@ class TextCategorizer(Pipe): def labels(self): return tuple(self.cfg.setdefault("labels", [])) + def require_labels(self): + """Raise an error if the component's model has no labels defined.""" + if not self.labels: + raise ValueError(Errors.E143.format(name=self.name)) + @labels.setter def labels(self, value): self.cfg["labels"] = tuple(value) @@ -931,6 +936,7 @@ class TextCategorizer(Pipe): doc.cats[label] = float(scores[i, j]) def update(self, docs, golds, state=None, drop=0., sgd=None, losses=None): + self.require_model() scores, bp_scores = self.model.begin_update(docs, drop=drop) loss, d_scores = self.get_loss(docs, golds, scores) bp_scores(d_scores, sgd=sgd) @@ -985,6 +991,7 @@ class TextCategorizer(Pipe): def begin_training(self, get_gold_tuples=lambda: [], pipeline=None, sgd=None, **kwargs): if self.model is True: self.cfg["pretrained_vectors"] = kwargs.get("pretrained_vectors") + self.require_labels() self.model = self.Model(len(self.labels), **self.cfg) link_vectors_to_models(self.vocab) if sgd is None: