From 931feb33601364c3ed06d0c9eeaeaea85ca9cad3 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sat, 11 Mar 2017 11:12:01 -0600 Subject: [PATCH] Allow beam parsing for NER --- spacy/syntax/arc_eager.pyx | 17 +++++++++++ spacy/syntax/beam_parser.pyx | 47 ++++++++++++++++-------------- spacy/syntax/ner.pyx | 13 +++++---- spacy/syntax/transition_system.pxd | 2 ++ spacy/syntax/transition_system.pyx | 18 +++++++++--- 5 files changed, 65 insertions(+), 32 deletions(-) diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index ba6e3af04..7049b8595 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -2,6 +2,7 @@ # cython: cdivision=True # cython: infer_types=True from __future__ import unicode_literals +from cpython.ref cimport PyObject, Py_INCREF, Py_XDECREF import ctypes import os @@ -293,7 +294,23 @@ cdef int _get_root(int word, const GoldParseC* gold) nogil: return word +cdef void* _init_state(Pool mem, int length, void* tokens) except NULL: + cdef StateClass st = StateClass.init(tokens, length) + # Ensure sent_start is set to 0 throughout + for i in range(st.c.length): + st.c._sent[i].sent_start = False + st.c._sent[i].l_edge = i + st.c._sent[i].r_edge = i + st.fast_forward() + Py_INCREF(st) + return st + + cdef class ArcEager(TransitionSystem): + def __init__(self, *args, **kwargs): + TransitionSystem.__init__(self, *args, **kwargs) + self.init_beam_state = _init_state + @classmethod def get_actions(cls, **kwargs): actions = kwargs.get('actions', diff --git a/spacy/syntax/beam_parser.pyx b/spacy/syntax/beam_parser.pyx index 29362d845..8f4f65186 100644 --- a/spacy/syntax/beam_parser.pyx +++ b/spacy/syntax/beam_parser.pyx @@ -83,7 +83,7 @@ cdef class BeamParser(Parser): cdef int _parseC(self, TokenC* tokens, int length, int nr_feat, int nr_class) except -1: cdef Beam beam = Beam(self.moves.n_moves, self.beam_width, min_density=self.beam_density) - beam.initialize(_init_state, length, tokens) + beam.initialize(self.moves.init_beam_state, length, tokens) beam.check_done(_check_final_state, NULL) if beam.is_done: _cleanup(beam) @@ -99,14 +99,17 @@ cdef class BeamParser(Parser): def update(self, Doc tokens, GoldParse gold_parse, itn=0): self.moves.preprocess_gold(gold_parse) cdef Beam pred = Beam(self.moves.n_moves, self.beam_width) - pred.initialize(_init_state, tokens.length, tokens.c) + pred.initialize(self.moves.init_beam_state, tokens.length, tokens.c) pred.check_done(_check_final_state, NULL) + # Hack for NER + for i in range(pred.size): + stcls = pred.at(i) + self.moves.initialize_state(stcls.c) cdef Beam gold = Beam(self.moves.n_moves, self.beam_width, min_density=0.0) - gold.initialize(_init_state, tokens.length, tokens.c) + gold.initialize(self.moves.init_beam_state, tokens.length, tokens.c) gold.check_done(_check_final_state, NULL) violn = MaxViolation() - itn = 0 while not pred.is_done and not gold.is_done: # We search separately here, to allow for ambiguity in the gold parse. self._advance_beam(pred, gold_parse, False) @@ -114,7 +117,6 @@ cdef class BeamParser(Parser): violn.check_crf(pred, gold) if pred.loss > 0 and pred.min_score > (gold.score + self.model.time): break - itn += 1 else: # The non-monotonic oracle makes it difficult to ensure final costs are # correct. Therefore do final correction @@ -124,8 +126,10 @@ cdef class BeamParser(Parser): elif pred._states[i].loss == 0.0: pred._states[i].loss = 1.0 violn.check_crf(pred, gold) - assert pred.size >= 1 - assert gold.size >= 1 + if pred.size < 1: + raise Exception("No candidates", tokens.length) + if gold.size < 1: + raise Exception("No gold", tokens.length) if pred.loss == 0: self.model.update_from_histories(self.moves, tokens, [(0.0, [])]) elif True: @@ -164,17 +168,29 @@ cdef class BeamParser(Parser): self.moves.set_valid(beam.is_valid[i], stcls.c) self.model.set_scoresC(beam.scores[i], features, nr_feat) if gold is not None: + n_gold = 0 + lines = [] for i in range(beam.size): stcls = beam.at(i) if not stcls.c.is_final(): self.moves.set_costs(beam.is_valid[i], beam.costs[i], stcls, gold) if follow_gold: for j in range(self.moves.n_moves): - beam.is_valid[i][j] *= beam.costs[i][j] <= 0 + if beam.costs[i][j] >= 1: + beam.is_valid[i][j] = 0 + lines.append((stcls.B(0), stcls.B(1), + stcls.B_(0).ent_iob, stcls.B_(1).ent_iob, + stcls.B_(1).sent_start, + j, + beam.is_valid[i][j], 'set invalid', + beam.costs[i][j], self.moves.c[j].move, self.moves.c[j].label)) + n_gold += 1 if beam.is_valid[i][j] else 0 + if follow_gold and n_gold == 0: + raise Exception("No gold") if follow_gold: beam.advance(_transition_state, NULL, self.moves.c) else: - beam.advance(_transition_state, _hash_state, self.moves.c) + beam.advance(_transition_state, NULL, self.moves.c) beam.check_done(_check_final_state, NULL) @@ -187,19 +203,6 @@ cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) moves[clas].do(dest.c, moves[clas].label) -cdef void* _init_state(Pool mem, int length, void* tokens) except NULL: - cdef StateClass st = StateClass.init(tokens, length) - # Ensure sent_start is set to 0 throughout - # Ensure sent_start is set to 0 throughout - for i in range(st.c.length): - st.c._sent[i].sent_start = False - st.c._sent[i].l_edge = i - st.c._sent[i].r_edge = i - st.fast_forward() - Py_INCREF(st) - return st - - cdef int _check_final_state(void* _state, void* extra_args) except -1: return (_state).is_final() diff --git a/spacy/syntax/ner.pyx b/spacy/syntax/ner.pyx index 8f6ecde9f..dcd53f694 100644 --- a/spacy/syntax/ner.pyx +++ b/spacy/syntax/ner.pyx @@ -52,7 +52,7 @@ cdef bint _entity_is_sunk(StateClass st, Transition* golds) nogil: cdef class BiluoPushDown(TransitionSystem): @classmethod def get_actions(cls, **kwargs): - actions = kwargs.get('actions', + actions = kwargs.get('actions', { MISSING: {'': True}, BEGIN: {}, @@ -159,6 +159,7 @@ cdef class BiluoPushDown(TransitionSystem): return t cdef int initialize_state(self, StateC* st) nogil: + # This is especially necessary when we use limited training data. for i in range(st.length): if st._sent[i].ent_type != 0: with gil: @@ -248,7 +249,7 @@ cdef class In: elif st.B_(1).sent_start: return False return st.entity_is_open() and label != 0 and st.E_(0).ent_type == label - + @staticmethod cdef int transition(StateC* st, int label) nogil: st.set_ent_tag(st.B(0), 1, label) @@ -262,7 +263,7 @@ cdef class In: cdef int g_act = gold.ner[s.B(0)].move cdef int g_tag = gold.ner[s.B(0)].label cdef bint is_sunk = _entity_is_sunk(s, gold.ner) - + if g_act == MISSING: return 0 elif g_act == BEGIN: @@ -304,7 +305,7 @@ cdef class Last: cdef int g_act = gold.ner[s.B(0)].move cdef int g_tag = gold.ner[s.B(0)].label - + if g_act == MISSING: return 0 elif g_act == BEGIN: @@ -381,7 +382,7 @@ cdef class Out: st.set_ent_tag(st.B(0), 2, 0) st.push() st.pop() - + @staticmethod cdef weight_t cost(StateClass s, const GoldParseC* gold, int label) nogil: cdef int g_act = gold.ner[s.B(0)].move @@ -406,7 +407,7 @@ cdef class Out: return 1 else: return 1 - + class OracleError(Exception): pass diff --git a/spacy/syntax/transition_system.pxd b/spacy/syntax/transition_system.pxd index b985498dc..5169ff7ca 100644 --- a/spacy/syntax/transition_system.pxd +++ b/spacy/syntax/transition_system.pxd @@ -28,6 +28,7 @@ ctypedef weight_t (*label_cost_func_t)(StateClass state, const GoldParseC* gold, ctypedef int (*do_func_t)(StateC* state, int label) nogil +ctypedef void* (*init_state_t)(Pool mem, int length, void* tokens) except NULL cdef class TransitionSystem: cdef Pool mem @@ -37,6 +38,7 @@ cdef class TransitionSystem: cdef int _size cdef public int root_label cdef public freqs + cdef init_state_t init_beam_state cdef int initialize_state(self, StateC* state) nogil cdef int finalize_state(self, StateC* state) nogil diff --git a/spacy/syntax/transition_system.pyx b/spacy/syntax/transition_system.pyx index a4bb02753..3c222288d 100644 --- a/spacy/syntax/transition_system.pyx +++ b/spacy/syntax/transition_system.pyx @@ -6,6 +6,7 @@ from collections import defaultdict from ..structs cimport TokenC from .stateclass cimport StateClass from ..attrs cimport TAG, HEAD, DEP, ENT_TYPE, ENT_IOB +from cpython.ref cimport PyObject, Py_INCREF, Py_XDECREF cdef weight_t MIN_SCORE = -90000 @@ -15,19 +16,27 @@ class OracleError(Exception): pass +cdef void* _init_state(Pool mem, int length, void* tokens) except NULL: + cdef StateClass st = StateClass.init(tokens, length) + # Ensure sent_start is set to 0 throughout + for i in range(st.c.length): + st.c._sent[i].sent_start = False + Py_INCREF(st) + return st + + cdef class TransitionSystem: def __init__(self, StringStore string_table, dict labels_by_action, _freqs=None): self.mem = Pool() self.strings = string_table self.n_moves = 0 self._size = 100 - + self.c = self.mem.alloc(self._size, sizeof(Transition)) - + for action, label_strs in sorted(labels_by_action.items()): for label_str in sorted(label_strs): self.add_action(int(action), label_str) - self.root_label = self.strings['ROOT'] self.freqs = {} if _freqs is None else _freqs for attr in (TAG, HEAD, DEP, ENT_TYPE, ENT_IOB): @@ -37,6 +46,7 @@ cdef class TransitionSystem: for i in range(10024): self.freqs[HEAD][i] = 1 self.freqs[HEAD][-i] = 1 + self.init_beam_state = _init_state def __reduce__(self): labels_by_action = {} @@ -96,7 +106,7 @@ cdef class TransitionSystem: if self.n_moves >= self._size: self._size *= 2 self.c = self.mem.realloc(self.c, self._size * sizeof(self.c[0])) - + self.c[self.n_moves] = self.init_transition(self.n_moves, action, label) self.n_moves += 1 return 1