From af9ed18cf1f377cd662243b1f2bad491c7700889 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Mon, 10 Nov 2014 17:39:23 +1100 Subject: [PATCH] * Bug fixes to NER --- spacy/ner/_state.pxd | 1 + spacy/ner/_state.pyx | 19 +++++++++++-------- spacy/ner/moves.pyx | 13 ++++++++++--- spacy/ner/pystate.pxd | 2 ++ spacy/ner/pystate.pyx | 20 +++++++++++++++----- 5 files changed, 39 insertions(+), 16 deletions(-) diff --git a/spacy/ner/_state.pxd b/spacy/ner/_state.pxd index c522e748b..da89a9493 100644 --- a/spacy/ner/_state.pxd +++ b/spacy/ner/_state.pxd @@ -9,6 +9,7 @@ cdef struct Entity: cdef struct State: + Entity curr Entity* ents int* tags int i diff --git a/spacy/ner/_state.pyx b/spacy/ner/_state.pyx index dce8e4d45..ae1300e2f 100644 --- a/spacy/ner/_state.pyx +++ b/spacy/ner/_state.pyx @@ -2,13 +2,16 @@ from .moves cimport BEGIN, UNIT cdef int begin_entity(State* s, label) except -1: - s.j += 1 - s.ents[s.j].start = s.i - s.ents[s.j].label = label + s.curr.start = s.i + s.curr.label = label cdef int end_entity(State* s) except -1: - s.ents[s.j].end = s.i + 1 + s.curr.end = s.i + 1 + s.curr[s.j] = s.curr + s.curr.start = 0 + s.curr.label = -1 + s.curr.end = 0 cdef State* init_state(Pool mem, int sent_length) except NULL: @@ -17,24 +20,24 @@ cdef State* init_state(Pool mem, int sent_length) except NULL: s.ents = mem.alloc(sent_length, sizeof(Entity)) for i in range(sent_length): s.ents[i].label = -1 + s.curr.label = -1 s.tags = mem.alloc(sent_length, sizeof(int)) s.length = sent_length return s cdef bint entity_is_open(State *s) except -1: - return s.j >= 0 and s.ents[s.j].label != -1 + return s.curr.label != -1 cdef bint entity_is_sunk(State *s, Move* golds) except -1: if not entity_is_open(s): return False - cdef Entity* ent = &s.ents[s.j] - cdef Move* gold = &golds[ent.start] + cdef Move* gold = &golds[s.curr.start] if gold.action != BEGIN and gold.action != UNIT: return True - elif gold.label != ent.label: + elif gold.label != s.curr.label: return True else: return False diff --git a/spacy/ner/moves.pyx b/spacy/ner/moves.pyx index 595fc19e6..6cde4c23e 100644 --- a/spacy/ner/moves.pyx +++ b/spacy/ner/moves.pyx @@ -1,8 +1,11 @@ +from __future__ import unicode_literals + from ._state cimport begin_entity from ._state cimport end_entity from ._state cimport entity_is_open from ._state cimport entity_is_sunk + ACTION_NAMES = ['' for _ in range(N_ACTIONS)] ACTION_NAMES[BEGIN] = 'B' ACTION_NAMES[IN] = 'I' @@ -16,11 +19,11 @@ cdef bint can_begin(State* s, int label): cdef bint can_in(State* s, int label): - return entity_is_open(s) and s.ents[s.j].tag == label + return entity_is_open(s) and s.ents[s.j].label == label cdef bint can_last(State* s, int label): - return entity_is_open(s) and s.ents[s.j].tag == label + return entity_is_open(s) and s.ents[s.j].label == label cdef bint can_unit(State* s, int label): @@ -119,6 +122,7 @@ cdef int set_accept_if_valid(Move* moves, int n_classes, State* s) except 0: elif m.action == OUT: m.accept = can_out(s, m.label) n_accept += m.accept + assert n_accept != 0 return n_accept @@ -133,6 +137,7 @@ cdef int set_accept_if_oracle(Move* moves, Move* golds, int n_classes, State* s) m.accept = is_oracle(m.action, m.label, g.action, g.label, next_act, is_sunk) n_accept += m.accept + assert n_accept != 0 return n_accept @@ -182,6 +187,7 @@ cdef int fill_moves(Move* moves, int n_tags) except -1: for label in range(n_tags): moves[i].action = IN moves[i].label = label + i += 1 for label in range(n_tags): moves[i].action = LAST moves[i].label = label @@ -190,4 +196,5 @@ cdef int fill_moves(Move* moves, int n_tags) except -1: moves[i].action = UNIT moves[i].label = label i += 1 - moves[i].label == OUT + moves[i].action = OUT + moves[i].label = 0 diff --git a/spacy/ner/pystate.pxd b/spacy/ner/pystate.pxd index cc2333f39..db5543e4d 100644 --- a/spacy/ner/pystate.pxd +++ b/spacy/ner/pystate.pxd @@ -12,3 +12,5 @@ cdef class PyState: cdef Move* _moves cdef State* _s + + cdef Move* _get_move(self, unicode move_name) except NULL diff --git a/spacy/ner/pystate.pyx b/spacy/ner/pystate.pyx index 810d5d980..e7acc29cd 100644 --- a/spacy/ner/pystate.pyx +++ b/spacy/ner/pystate.pyx @@ -1,7 +1,10 @@ +from __future__ import unicode_literals + from ._state cimport init_state from ._state cimport entity_is_open from .moves cimport fill_moves from .moves cimport transition +from .moves cimport set_accept_if_valid from .moves import get_n_moves from .moves import ACTION_NAMES @@ -19,16 +22,23 @@ cdef class PyState: for i in range(self.n_classes): m = &self._moves[i] action_name = ACTION_NAMES[m.action] - tag_name = tag_names[m.label] - self.moves_by_name['%s-%s' % (action_name, tag_name)] = i + if action_name == 'O': + self.moves_by_name['O'] = i + else: + tag_name = tag_names[m.label] + self.moves_by_name['%s-%s' % (action_name, tag_name)] = i + + cdef Move* _get_move(self, unicode move_name) except NULL: + return &self._moves[self.moves_by_name[move_name]] def transition(self, unicode move_name): - cdef int m_i = self.moves_by_name[move_name] - cdef Move* m = &self._moves[m_i] + cdef Move* m = self._get_move(move_name) transition(self._s, m) def is_valid(self, unicode move_name): - pass + cdef Move* m = self._get_move(move_name) + set_accept_if_valid(self._moves, self.n_classes, self._s) + return m.accept def is_gold(self, unicode move_name): pass