diff --git a/spacy/ner/pystate.pxd b/spacy/ner/pystate.pxd index db5543e4d..2b5b4cdbe 100644 --- a/spacy/ner/pystate.pxd +++ b/spacy/ner/pystate.pxd @@ -11,6 +11,7 @@ cdef class PyState: cdef readonly dict moves_by_name cdef Move* _moves + cdef Move* _golds cdef State* _s cdef Move* _get_move(self, unicode move_name) except NULL diff --git a/spacy/ner/pystate.pyx b/spacy/ner/pystate.pyx index e7acc29cd..b66219a90 100644 --- a/spacy/ner/pystate.pyx +++ b/spacy/ner/pystate.pyx @@ -4,7 +4,7 @@ from ._state cimport init_state from ._state cimport entity_is_open from .moves cimport fill_moves from .moves cimport transition -from .moves cimport set_accept_if_valid +from .moves cimport set_accept_if_valid, set_accept_if_oracle from .moves import get_n_moves from .moves import ACTION_NAMES @@ -27,10 +27,18 @@ cdef class PyState: else: tag_name = tag_names[m.label] self.moves_by_name['%s-%s' % (action_name, tag_name)] = i + # TODO + self._golds = self.mem.alloc(n_tokens, sizeof(Move)) cdef Move* _get_move(self, unicode move_name) except NULL: return &self._moves[self.moves_by_name[move_name]] + def set_golds(self, list gold_names): + cdef Move* m + for i, name in enumerate(gold_names): + m = self._get_move(name) + self._golds[i] = m[0] + def transition(self, unicode move_name): cdef Move* m = self._get_move(move_name) transition(self._s, m) @@ -41,15 +49,17 @@ cdef class PyState: return m.accept def is_gold(self, unicode move_name): - pass + set_accept_if_oracle(self._moves, self._golds, self.n_classes, self._s) + cdef Move* m = self._get_move(move_name) + return m.accept property ent: def __get__(self): - return self._s.ents[self._s.j] + return self._s.curr property n_ents: def __get__(self): - return self._s.j + 1 + return self._s.j property i: def __get__(self): @@ -58,5 +68,3 @@ cdef class PyState: property open_entity: def __get__(self): return entity_is_open(self._s) - -