From 12de2638137c1c8c9f86b687d6296138e0aaa0ea Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sun, 13 Aug 2017 09:33:39 +0200 Subject: [PATCH] Bug fixes to beam parsing. Learns small sample --- spacy/syntax/_beam_utils.pyx | 81 +++++++++++++++++++++++++++--------- 1 file changed, 61 insertions(+), 20 deletions(-) diff --git a/spacy/syntax/_beam_utils.pyx b/spacy/syntax/_beam_utils.pyx index af4aff9fe..0a513531d 100644 --- a/spacy/syntax/_beam_utils.pyx +++ b/spacy/syntax/_beam_utils.pyx @@ -66,7 +66,7 @@ cdef class ParserBeam(object): for beam in self.beams: if beam is not None: _cleanup(beam) - + @property def is_done(self): return all(b.is_done for b in self.beams) @@ -80,6 +80,8 @@ cdef class ParserBeam(object): def advance(self, scores, follow_gold=False): cdef Beam beam for i, beam in enumerate(self.beams): + if beam.is_done: + continue self._set_scores(beam, scores[i]) if self.golds is not None: self._set_costs(beam, self.golds[i], follow_gold=follow_gold) @@ -108,7 +110,22 @@ cdef class ParserBeam(object): for j in range(beam.nr_class): if beam.costs[i][j] >= 1: beam.is_valid[i][j] = 0 - + + +def is_gold(StateClass state, GoldParse gold, strings): + predicted = set() + truth = set() + for i in range(gold.length): + if gold.cand_to_gold[i] is None: + continue + if state.safe_get(i).dep: + predicted.add((i, state.H(i), strings[state.safe_get(i).dep])) + else: + predicted.add((i, state.H(i), 'ROOT')) + id_, word, tag, head, dep, ner = gold.orig_annot[gold.cand_to_gold[i]] + truth.add((id_, head, dep)) + return truth == predicted + def get_token_ids(states, int n_tokens): cdef StateClass state @@ -123,11 +140,13 @@ def get_token_ids(states, int n_tokens): c_ids += ids.shape[1] return ids - +nr_update = 0 def update_beam(TransitionSystem moves, int nr_feature, int max_steps, states, tokvecs, golds, state2vec, vec2scores, drop=0., sgd=None, losses=None, int width=4, float density=0.001): + global nr_update + nr_update += 1 pbeam = ParserBeam(moves, states, golds, width=width, density=density) gbeam = ParserBeam(moves, states, golds, @@ -139,8 +158,9 @@ def update_beam(TransitionSystem moves, int nr_feature, int max_steps, if pbeam.is_done and gbeam.is_done: break beam_maps.append({}) - states, p_indices, g_indices = get_states(pbeam, gbeam, beam_maps[-1]) - + states, p_indices, g_indices = get_states(pbeam, gbeam, beam_maps[-1], nr_update) + if not states: + break token_ids = get_token_ids(states, nr_feature) vectors, bp_vectors = state2vec.begin_update(token_ids, drop=drop) scores, bp_scores = vec2scores.begin_update(vectors, drop=drop) @@ -154,6 +174,16 @@ def update_beam(TransitionSystem moves, int nr_feature, int max_steps, for i, violn in enumerate(violns): violn.check_crf(pbeam[i], gbeam[i]) + # The non-monotonic oracle makes it difficult to ensure final costs are + # correct. Therefore do final correction + cdef Beam pred + for i, (pred, gold_parse) in enumerate(zip(pbeam, golds)): + for j in range(pred.size): + if is_gold(pred.at(j), gold_parse, moves.strings): + pred._states[j].loss = 0.0 + elif pred._states[j].loss == 0.0: + pred._states[j].loss = 1.0 + violn.check_crf(pred, gbeam[i]) histories = [(v.p_hist + v.g_hist) for v in violns] losses = [(v.p_probs + v.g_probs) for v in violns] @@ -162,30 +192,35 @@ def update_beam(TransitionSystem moves, int nr_feature, int max_steps, return states_d_scores, backprops -def get_states(pbeams, gbeams, beam_map): +def get_states(pbeams, gbeams, beam_map, nr_update): seen = {} states = [] p_indices = [] g_indices = [] cdef Beam pbeam, gbeam for eg_id, (pbeam, gbeam) in enumerate(zip(pbeams, gbeams)): + if pbeam.loss > 0 and pbeam.min_score > (gbeam.score + nr_update): + continue p_indices.append([]) for j in range(pbeam.size): - key = tuple([eg_id] + pbeam.histories[j]) - seen[key] = len(states) - p_indices[-1].append(len(states)) - states.append(pbeam.at(j)) + state = pbeam.at(j) + if not state.is_final(): + key = tuple([eg_id] + pbeam.histories[j]) + seen[key] = len(states) + p_indices[-1].append(len(states)) + states.append(pbeam.at(j)) beam_map.update(seen) g_indices.append([]) for i in range(gbeam.size): - key = tuple([eg_id] + gbeam.histories[i]) - if key in seen: - g_indices[-1].append(seen[key]) - else: - g_indices[-1].append(len(states)) - beam_map[key] = len(states) - states.append(gbeam.at(i)) - + state = gbeam.at(j) + if not state.is_final(): + key = tuple([eg_id] + gbeam.histories[i]) + if key in seen: + g_indices[-1].append(seen[key]) + else: + g_indices[-1].append(len(states)) + beam_map[key] = len(states) + states.append(gbeam.at(i)) p_indices = [numpy.asarray(idx, dtype='i') for idx in p_indices] g_indices = [numpy.asarray(idx, dtype='i') for idx in g_indices] return states, p_indices, g_indices @@ -206,12 +241,18 @@ def get_gradient(nr_class, beam_maps, histories, losses): So history is list of lists of lists of ints """ nr_step = len(beam_maps) - grads = [numpy.zeros((max(beam_map.values())+1, nr_class), dtype='f') - for beam_map in beam_maps] + grads = [] + for beam_map in beam_maps: + if beam_map: + grads.append(numpy.zeros((max(beam_map.values())+1, nr_class), dtype='f')) + else: + grads.append(None) for eg_id, hists in enumerate(histories): for loss, hist in zip(losses[eg_id], hists): key = tuple([eg_id]) for j, clas in enumerate(hist): + if grads[j] is None: + continue i = beam_maps[j][key] # In step j, at state i action clas # resulted in loss