diff --git a/spacy/errors.py b/spacy/errors.py index 9c945fb0c..45cabc4ad 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -287,6 +287,8 @@ class Errors(object): E108 = ("As of spaCy v2.1, the pipe name `sbd` has been deprecated " "in favor of the pipe name `sentencizer`, which does the same " "thing. For example, use `nlp.create_pipeline('sentencizer')`") + E109 = ("Model for component '{name}' not initialized. Did you forget to load " + "a model, or forget to call begin_training()?") @add_codes diff --git a/spacy/pipeline.pyx b/spacy/pipeline.pyx index 81cf1ea69..38cdbbd92 100644 --- a/spacy/pipeline.pyx +++ b/spacy/pipeline.pyx @@ -293,10 +293,16 @@ class Pipe(object): Both __call__ and pipe should delegate to the `predict()` and `set_annotations()` methods. """ + self.require_model() scores, tensors = self.predict([doc]) self.set_annotations([doc], scores, tensors=tensors) return doc + def require_model(self): + """Raise an error if the component's model is not initialized.""" + if getattr(self, 'model', None) in (None, True, False): + raise ValueError(Errors.E109.format(name=self.name)) + def pipe(self, stream, batch_size=128, n_threads=-1): """Apply the pipe to a stream of documents. @@ -313,6 +319,7 @@ class Pipe(object): """Apply the pipeline's model to a batch of docs, without modifying them. """ + self.require_model() raise NotImplementedError def set_annotations(self, docs, scores, tensors=None): @@ -325,6 +332,7 @@ class Pipe(object): Delegates to predict() and get_loss(). """ + self.require_model() raise NotImplementedError def rehearse(self, docs, sgd=None, losses=None, **config): @@ -495,6 +503,7 @@ class Tensorizer(Pipe): docs (iterable): A sequence of `Doc` objects. RETURNS (object): Vector representations for each token in the docs. """ + self.require_model() inputs = self.model.ops.flatten([doc.tensor for doc in docs]) outputs = self.model(inputs) return self.model.ops.unflatten(outputs, [len(d) for d in docs]) @@ -519,6 +528,7 @@ class Tensorizer(Pipe): sgd (callable): An optimizer. RETURNS (dict): Results from the update. """ + self.require_model() if isinstance(docs, Doc): docs = [docs] inputs = [] @@ -600,6 +610,7 @@ class Tagger(Pipe): yield from docs def predict(self, docs): + self.require_model() if not any(len(doc) for doc in docs): # Handle case where there are no tokens in any docs. n_labels = len(self.labels) @@ -644,6 +655,7 @@ class Tagger(Pipe): doc.is_tagged = True def update(self, docs, golds, drop=0., sgd=None, losses=None): + self.require_model() if losses is not None and self.name not in losses: losses[self.name] = 0. @@ -904,6 +916,7 @@ class MultitaskObjective(Tagger): return model def predict(self, docs): + self.require_model() tokvecs = self.model.tok2vec(docs) scores = self.model.softmax(tokvecs) return tokvecs, scores @@ -1042,6 +1055,7 @@ class ClozeMultitask(Pipe): return sgd def predict(self, docs): + self.require_model() tokvecs = self.model.tok2vec(docs) vectors = self.model.output_layer(tokvecs) return tokvecs, vectors @@ -1061,6 +1075,7 @@ class ClozeMultitask(Pipe): pass def rehearse(self, docs, drop=0., sgd=None, losses=None): + self.require_model() if losses is not None and self.name not in losses: losses[self.name] = 0. predictions, bp_predictions = self.model.begin_update(docs, drop=drop) @@ -1105,9 +1120,11 @@ class SimilarityHook(Pipe): yield self(doc) def predict(self, doc1, doc2): + self.require_model() return self.model.predict([(doc1, doc2)]) def update(self, doc1_doc2, golds, sgd=None, drop=0.): + self.require_model() sims, bp_sims = self.model.begin_update(doc1_doc2, drop=drop) def begin_training(self, _=tuple(), pipeline=None, sgd=None, **kwargs): @@ -1171,6 +1188,7 @@ class TextCategorizer(Pipe): yield from docs def predict(self, docs): + self.require_model() scores = self.model(docs) scores = self.model.ops.asarray(scores) tensors = [doc.tensor for doc in docs] diff --git a/spacy/syntax/_parser_model.pyx b/spacy/syntax/_parser_model.pyx index fd87aab2b..657e30f41 100644 --- a/spacy/syntax/_parser_model.pyx +++ b/spacy/syntax/_parser_model.pyx @@ -205,7 +205,9 @@ class ParserModel(Model): return smaller = self.upper larger = Affine(new_output, smaller.nI) - larger.W *= 0 + # Set nan as value for unseen classes, to prevent prediction. + larger.W.fill(self.ops.xp.nan) + larger.b.fill(self.ops.xp.nan) # It seems very unhappy if I pass these as smaller.W? # Seems to segfault. Maybe it's a descriptor protocol thing? smaller_W = smaller.W @@ -254,8 +256,23 @@ class ParserStepModel(Model): if mask is not None: vector *= mask scores, get_d_vector = self.vec2scores.begin_update(vector, drop=drop) + # We can have nans from unseen classes. + # For backprop purposes, we want to treat unseen classes as having the + # lowest score. + # numpy's nan_to_num function doesn't take a value, and nan is replaced + # by 0...-inf is replaced by minimum, so we go via that. Ugly to the max. + scores[self.ops.xp.isnan(scores)] = -self.ops.xp.inf + self.ops.xp.nan_to_num(scores, copy=False) def backprop_parser_step(d_scores, sgd=None): + # If we have a non-zero gradient for a previously unseen class, + # replace the weight with 0. + new_classes = self.ops.xp.logical_and( + self.vec2scores.ops.xp.isnan(self.vec2scores.b), + d_scores.any(axis=0) + ) + self.vec2scores.b[new_classes] = 0. + self.vec2scores.W[new_classes] = 0. d_vector = get_d_vector(d_scores, sgd=sgd) if mask is not None: d_vector *= mask @@ -400,6 +417,8 @@ cdef class precompute_hiddens: state_vector, mask = self.ops.maxout(state_vector) def backprop_nonlinearity(d_best, sgd=None): + # Fix nans (which can occur from unseen classes.) + d_best[self.ops.xp.isnan(d_best)] = 0. if self.nP == 1: d_best *= mask d_best = d_best.reshape((d_best.shape + (1,))) diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx index 52bf89128..95fe5f997 100644 --- a/spacy/syntax/nn_parser.pyx +++ b/spacy/syntax/nn_parser.pyx @@ -226,8 +226,14 @@ cdef class Parser: self.set_annotations(subbatch, parse_states, tensors=None) for doc in batch_in_order: yield doc + + def require_model(self): + """Raise an error if the component's model is not initialized.""" + if getattr(self, 'model', None) in (None, True, False): + raise ValueError(Errors.E109.format(name=self.name)) def predict(self, docs, beam_width=1, beam_density=0.0, drop=0.): + self.require_model() if isinstance(docs, Doc): docs = [docs] if not any(len(doc) for doc in docs): @@ -375,6 +381,7 @@ cdef class Parser: return [b for b in beams if not b.is_done] def update(self, docs, golds, drop=0., sgd=None, losses=None): + self.require_model() if isinstance(docs, Doc) and isinstance(golds, GoldParse): docs = [docs] golds = [golds]