From 84b7ed49e48ca75ca1fea4e1c76218e0643e087c Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sun, 20 Aug 2017 14:41:38 +0200 Subject: [PATCH] Ensure updates aren't made if no gold available --- spacy/gold.pyx | 4 ++-- spacy/syntax/ner.pyx | 2 +- spacy/syntax/nn_parser.pyx | 15 +++++++++++++++ 3 files changed, 18 insertions(+), 3 deletions(-) diff --git a/spacy/gold.pyx b/spacy/gold.pyx index 39951447c..096f265a9 100644 --- a/spacy/gold.pyx +++ b/spacy/gold.pyx @@ -406,11 +406,11 @@ cdef class GoldParse: if tags is None: tags = [None for _ in doc] if heads is None: - heads = [token.i for token in doc] + heads = [None for token in doc] if deps is None: deps = [None for _ in doc] if entities is None: - entities = ['-' for _ in doc] + entities = [None for _ in doc] elif len(entities) == 0: entities = ['O' for _ in doc] elif not isinstance(entities[0], basestring): diff --git a/spacy/syntax/ner.pyx b/spacy/syntax/ner.pyx index d15de0181..2f5cd4e48 100644 --- a/spacy/syntax/ner.pyx +++ b/spacy/syntax/ner.pyx @@ -113,7 +113,7 @@ cdef class BiluoPushDown(TransitionSystem): def has_gold(self, GoldParse gold, start=0, end=None): end = end or len(gold.ner) - if all([tag == '-' for tag in gold.ner[start:end]]): + if all([tag in ('-', None) for tag in gold.ner[start:end]]): return False else: return True diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx index 7412ebeee..f1a0bc91c 100644 --- a/spacy/syntax/nn_parser.pyx +++ b/spacy/syntax/nn_parser.pyx @@ -483,6 +483,9 @@ cdef class Parser: return beams def update(self, docs_tokvecs, golds, drop=0., sgd=None, losses=None): + docs_tokvecs, golds = self._filter_unlabelled(docs_tokvecs, golds) + if not golds: + return None if self.cfg.get('beam_width', 1) >= 2 and numpy.random.random() >= 0.5: return self.update_beam(docs_tokvecs, golds, self.cfg['beam_width'], self.cfg['beam_density'], @@ -555,6 +558,9 @@ cdef class Parser: def update_beam(self, docs_tokvecs, golds, width=None, density=None, drop=0., sgd=None, losses=None): + docs_tokvecs, golds = self._filter_unlabelled(docs_tokvecs, golds) + if not golds: + return None if width is None: width = self.cfg.get('beam_width', 2) if density is None: @@ -605,6 +611,15 @@ cdef class Parser: bp_my_tokvecs(d_tokvecs, sgd=sgd) return d_tokvecs + def _filter_unlabelled(self, docs_tokvecs, golds): + '''Remove inputs that have no relevant labels before update''' + has_golds = [self.moves.has_gold(gold) for gold in golds] + docs, tokvecs = docs_tokvecs + docs = [docs[i] for i, has_gold in enumerate(has_golds) if has_gold] + tokvecs = [tokvecs[i] for i, has_gold in enumerate(has_golds) if has_gold] + golds = [golds[i] for i, has_gold in enumerate(has_golds) if has_gold] + return (docs, tokvecs), golds + def _init_gold_batch(self, whole_docs, whole_golds): """Make a square batch, of length equal to the shortest doc. A long doc will get multiple states. Let's say we have a doc of length 2*N,