""" MALT-style dependency parser """ # cython: profile=True # cython: experimental_cpp_class_def=True # cython: cdivision=True # cython: infer_types=True # coding: utf-8 from __future__ import unicode_literals, print_function cimport cython from cpython.ref cimport PyObject, Py_INCREF, Py_XDECREF from libc.stdint cimport uint32_t, uint64_t from libc.string cimport memset, memcpy from libc.stdlib cimport rand from libc.math cimport log, exp, isnan, isinf from cymem.cymem cimport Pool, Address from murmurhash.mrmr cimport real_hash64 as hash64 from thinc.typedefs cimport weight_t, class_t, feat_t, atom_t, hash_t from thinc.linear.features cimport ConjunctionExtracter from thinc.structs cimport FeatureC, ExampleC from thinc.extra.search cimport Beam, MaxViolation from thinc.extra.eg cimport Example from thinc.extra.mb cimport Minibatch from ..structs cimport TokenC from ..tokens.doc cimport Doc from ..strings cimport StringStore from .transition_system cimport TransitionSystem, Transition from ..gold cimport GoldParse from . import _parse_features from ._parse_features cimport CONTEXT_SIZE from ._parse_features cimport fill_context from .stateclass cimport StateClass from .parser cimport Parser from ._beam_utils import is_gold DEBUG = False def set_debug(val): global DEBUG DEBUG = val def get_templates(name): pf = _parse_features if name == 'ner': return pf.ner elif name == 'debug': return pf.unigrams else: return (pf.unigrams + pf.s0_n0 + pf.s1_n0 + pf.s1_s0 + pf.s0_n1 + pf.n0_n1 + \ pf.tree_shape + pf.trigrams) cdef int BEAM_WIDTH = 16 cdef weight_t BEAM_DENSITY = 0.001 cdef class BeamParser(Parser): def __init__(self, *args, **kwargs): self.beam_width = kwargs.get('beam_width', BEAM_WIDTH) self.beam_density = kwargs.get('beam_density', BEAM_DENSITY) Parser.__init__(self, *args, **kwargs) cdef int parseC(self, TokenC* tokens, int length, int nr_feat) nogil: with gil: self._parseC(tokens, length, nr_feat, self.moves.n_moves) cdef int _parseC(self, TokenC* tokens, int length, int nr_feat, int nr_class) except -1: cdef Beam beam = Beam(self.moves.n_moves, self.beam_width, min_density=self.beam_density) # TODO: How do we handle new labels here? This increases nr_class beam.initialize(self.moves.init_beam_state, length, tokens) beam.check_done(_check_final_state, NULL) if beam.is_done: _cleanup(beam) return 0 while not beam.is_done: self._advance_beam(beam, None, False) state = beam.at(0) self.moves.finalize_state(state.c) for i in range(length): tokens[i] = state.c._sent[i] _cleanup(beam) def update(self, Doc tokens, GoldParse gold_parse, itn=0): self.moves.preprocess_gold(gold_parse) cdef Beam pred = Beam(self.moves.n_moves, self.beam_width) pred.initialize(self.moves.init_beam_state, tokens.length, tokens.c) pred.check_done(_check_final_state, NULL) # Hack for NER for i in range(pred.size): stcls = pred.at(i) self.moves.initialize_state(stcls.c) cdef Beam gold = Beam(self.moves.n_moves, self.beam_width, min_density=0.0) gold.initialize(self.moves.init_beam_state, tokens.length, tokens.c) gold.check_done(_check_final_state, NULL) violn = MaxViolation() while not pred.is_done and not gold.is_done: # We search separately here, to allow for ambiguity in the gold parse. self._advance_beam(pred, gold_parse, False) self._advance_beam(gold, gold_parse, True) violn.check_crf(pred, gold) if pred.loss > 0 and pred.min_score > (gold.score + self.model.time): break else: # The non-monotonic oracle makes it difficult to ensure final costs are # correct. Therefore do final correction for i in range(pred.size): if is_gold(pred.at(i), gold_parse, self.moves.strings): pred._states[i].loss = 0.0 elif pred._states[i].loss == 0.0: pred._states[i].loss = 1.0 violn.check_crf(pred, gold) if pred.size < 1: raise Exception("No candidates", tokens.length) if gold.size < 1: raise Exception("No gold", tokens.length) if pred.loss == 0: self.model.update_from_histories(self.moves, tokens, [(0.0, [])]) elif True: #_check_train_integrity(pred, gold, gold_parse, self.moves) histories = list(zip(violn.p_probs, violn.p_hist)) + \ list(zip(violn.g_probs, violn.g_hist)) self.model.update_from_histories(self.moves, tokens, histories, min_grad=0.001**(itn+1)) else: self.model.update_from_histories(self.moves, tokens, [(1.0, violn.p_hist[0]), (-1.0, violn.g_hist[0])]) _cleanup(pred) _cleanup(gold) return pred.loss def _advance_beam(self, Beam beam, GoldParse gold, bint follow_gold): cdef atom_t[CONTEXT_SIZE] context cdef Pool mem = Pool() features = mem.alloc(self.model.nr_feat, sizeof(FeatureC)) if False: mb = Minibatch(self.model.widths, beam.size) for i in range(beam.size): stcls = beam.at(i) if stcls.c.is_final(): nr_feat = 0 else: nr_feat = self.model.set_featuresC(context, features, stcls.c) self.moves.set_valid(beam.is_valid[i], stcls.c) mb.c.push_back(features, nr_feat, beam.costs[i], beam.is_valid[i], 0) self.model(mb) for i in range(beam.size): memcpy(beam.scores[i], mb.c.scores(i), mb.c.nr_out() * sizeof(beam.scores[i][0])) else: for i in range(beam.size): stcls = beam.at(i) if not stcls.is_final(): nr_feat = self.model.set_featuresC(context, features, stcls.c) self.moves.set_valid(beam.is_valid[i], stcls.c) self.model.set_scoresC(beam.scores[i], features, nr_feat) if gold is not None: n_gold = 0 lines = [] for i in range(beam.size): stcls = beam.at(i) if not stcls.c.is_final(): self.moves.set_costs(beam.is_valid[i], beam.costs[i], stcls, gold) if follow_gold: for j in range(self.moves.n_moves): if beam.costs[i][j] >= 1: beam.is_valid[i][j] = 0 lines.append((stcls.B(0), stcls.B(1), stcls.B_(0).ent_iob, stcls.B_(1).ent_iob, stcls.B_(1).sent_start, j, beam.is_valid[i][j], 'set invalid', beam.costs[i][j], self.moves.c[j].move, self.moves.c[j].label)) n_gold += 1 if beam.is_valid[i][j] else 0 if follow_gold and n_gold == 0: raise Exception("No gold") if follow_gold: beam.advance(_transition_state, NULL, self.moves.c) else: beam.advance(_transition_state, _hash_state, self.moves.c) beam.check_done(_check_final_state, NULL) # These are passed as callbacks to thinc.search.Beam cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1: dest = _dest src = _src moves = _moves dest.clone(src) moves[clas].do(dest.c, moves[clas].label) cdef int _check_final_state(void* _state, void* extra_args) except -1: return (_state).is_final() def _cleanup(Beam beam): for i in range(beam.width): Py_XDECREF(beam._states[i].content) Py_XDECREF(beam._parents[i].content) cdef hash_t _hash_state(void* _state, void* _) except 0: state = _state if state.c.is_final(): return 1 else: return state.c.hash() def _check_train_integrity(Beam pred, Beam gold, GoldParse gold_parse, TransitionSystem moves): for i in range(pred.size): if not pred._states[i].is_done or pred._states[i].loss == 0: continue state = pred.at(i) if is_gold(state, gold_parse, moves.strings) == True: for dep in gold_parse.orig_annot: print(dep[1], dep[3], dep[4]) print("Cost", pred._states[i].loss) for j in range(gold_parse.length): print(gold_parse.orig_annot[j][1], state.H(j), moves.strings[state.safe_get(j).dep]) acts = [moves.c[clas].move for clas in pred.histories[i]] labels = [moves.c[clas].label for clas in pred.histories[i]] print([moves.move_name(move, label) for move, label in zip(acts, labels)]) raise Exception("Predicted state is gold-standard") for i in range(gold.size): if not gold._states[i].is_done: continue state = gold.at(i) if is_gold(state, gold_parse, moves.strings) == False: print("Truth") for dep in gold_parse.orig_annot: print(dep[1], dep[3], dep[4]) print("Predicted good") for j in range(gold_parse.length): print(gold_parse.orig_annot[j][1], state.H(j), moves.strings[state.safe_get(j).dep]) raise Exception("Gold parse is not gold-standard")