mirror of https://github.com/explosion/spaCy.git
* Bug fixes to NER
This commit is contained in:
parent
d7b2843643
commit
af9ed18cf1
|
@ -9,6 +9,7 @@ cdef struct Entity:
|
|||
|
||||
|
||||
cdef struct State:
|
||||
Entity curr
|
||||
Entity* ents
|
||||
int* tags
|
||||
int i
|
||||
|
|
|
@ -2,13 +2,16 @@ from .moves cimport BEGIN, UNIT
|
|||
|
||||
|
||||
cdef int begin_entity(State* s, label) except -1:
|
||||
s.j += 1
|
||||
s.ents[s.j].start = s.i
|
||||
s.ents[s.j].label = label
|
||||
s.curr.start = s.i
|
||||
s.curr.label = label
|
||||
|
||||
|
||||
cdef int end_entity(State* s) except -1:
|
||||
s.ents[s.j].end = s.i + 1
|
||||
s.curr.end = s.i + 1
|
||||
s.curr[s.j] = s.curr
|
||||
s.curr.start = 0
|
||||
s.curr.label = -1
|
||||
s.curr.end = 0
|
||||
|
||||
|
||||
cdef State* init_state(Pool mem, int sent_length) except NULL:
|
||||
|
@ -17,24 +20,24 @@ cdef State* init_state(Pool mem, int sent_length) except NULL:
|
|||
s.ents = <Entity*>mem.alloc(sent_length, sizeof(Entity))
|
||||
for i in range(sent_length):
|
||||
s.ents[i].label = -1
|
||||
s.curr.label = -1
|
||||
s.tags = <int*>mem.alloc(sent_length, sizeof(int))
|
||||
s.length = sent_length
|
||||
return s
|
||||
|
||||
|
||||
cdef bint entity_is_open(State *s) except -1:
|
||||
return s.j >= 0 and s.ents[s.j].label != -1
|
||||
return s.curr.label != -1
|
||||
|
||||
|
||||
cdef bint entity_is_sunk(State *s, Move* golds) except -1:
|
||||
if not entity_is_open(s):
|
||||
return False
|
||||
|
||||
cdef Entity* ent = &s.ents[s.j]
|
||||
cdef Move* gold = &golds[ent.start]
|
||||
cdef Move* gold = &golds[s.curr.start]
|
||||
if gold.action != BEGIN and gold.action != UNIT:
|
||||
return True
|
||||
elif gold.label != ent.label:
|
||||
elif gold.label != s.curr.label:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
|
|
@ -1,8 +1,11 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
from ._state cimport begin_entity
|
||||
from ._state cimport end_entity
|
||||
from ._state cimport entity_is_open
|
||||
from ._state cimport entity_is_sunk
|
||||
|
||||
|
||||
ACTION_NAMES = ['' for _ in range(N_ACTIONS)]
|
||||
ACTION_NAMES[<int>BEGIN] = 'B'
|
||||
ACTION_NAMES[<int>IN] = 'I'
|
||||
|
@ -16,11 +19,11 @@ cdef bint can_begin(State* s, int label):
|
|||
|
||||
|
||||
cdef bint can_in(State* s, int label):
|
||||
return entity_is_open(s) and s.ents[s.j].tag == label
|
||||
return entity_is_open(s) and s.ents[s.j].label == label
|
||||
|
||||
|
||||
cdef bint can_last(State* s, int label):
|
||||
return entity_is_open(s) and s.ents[s.j].tag == label
|
||||
return entity_is_open(s) and s.ents[s.j].label == label
|
||||
|
||||
|
||||
cdef bint can_unit(State* s, int label):
|
||||
|
@ -119,6 +122,7 @@ cdef int set_accept_if_valid(Move* moves, int n_classes, State* s) except 0:
|
|||
elif m.action == OUT:
|
||||
m.accept = can_out(s, m.label)
|
||||
n_accept += m.accept
|
||||
assert n_accept != 0
|
||||
return n_accept
|
||||
|
||||
|
||||
|
@ -133,6 +137,7 @@ cdef int set_accept_if_oracle(Move* moves, Move* golds, int n_classes, State* s)
|
|||
m.accept = is_oracle(<ActionType>m.action, m.label, <ActionType>g.action,
|
||||
g.label, next_act, is_sunk)
|
||||
n_accept += m.accept
|
||||
assert n_accept != 0
|
||||
return n_accept
|
||||
|
||||
|
||||
|
@ -182,6 +187,7 @@ cdef int fill_moves(Move* moves, int n_tags) except -1:
|
|||
for label in range(n_tags):
|
||||
moves[i].action = IN
|
||||
moves[i].label = label
|
||||
i += 1
|
||||
for label in range(n_tags):
|
||||
moves[i].action = LAST
|
||||
moves[i].label = label
|
||||
|
@ -190,4 +196,5 @@ cdef int fill_moves(Move* moves, int n_tags) except -1:
|
|||
moves[i].action = UNIT
|
||||
moves[i].label = label
|
||||
i += 1
|
||||
moves[i].label == OUT
|
||||
moves[i].action = OUT
|
||||
moves[i].label = 0
|
||||
|
|
|
@ -12,3 +12,5 @@ cdef class PyState:
|
|||
|
||||
cdef Move* _moves
|
||||
cdef State* _s
|
||||
|
||||
cdef Move* _get_move(self, unicode move_name) except NULL
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
from ._state cimport init_state
|
||||
from ._state cimport entity_is_open
|
||||
from .moves cimport fill_moves
|
||||
from .moves cimport transition
|
||||
from .moves cimport set_accept_if_valid
|
||||
from .moves import get_n_moves
|
||||
from .moves import ACTION_NAMES
|
||||
|
||||
|
@ -19,16 +22,23 @@ cdef class PyState:
|
|||
for i in range(self.n_classes):
|
||||
m = &self._moves[i]
|
||||
action_name = ACTION_NAMES[m.action]
|
||||
tag_name = tag_names[m.label]
|
||||
self.moves_by_name['%s-%s' % (action_name, tag_name)] = i
|
||||
if action_name == 'O':
|
||||
self.moves_by_name['O'] = i
|
||||
else:
|
||||
tag_name = tag_names[m.label]
|
||||
self.moves_by_name['%s-%s' % (action_name, tag_name)] = i
|
||||
|
||||
cdef Move* _get_move(self, unicode move_name) except NULL:
|
||||
return &self._moves[self.moves_by_name[move_name]]
|
||||
|
||||
def transition(self, unicode move_name):
|
||||
cdef int m_i = self.moves_by_name[move_name]
|
||||
cdef Move* m = &self._moves[m_i]
|
||||
cdef Move* m = self._get_move(move_name)
|
||||
transition(self._s, m)
|
||||
|
||||
def is_valid(self, unicode move_name):
|
||||
pass
|
||||
cdef Move* m = self._get_move(move_name)
|
||||
set_accept_if_valid(self._moves, self.n_classes, self._s)
|
||||
return m.accept
|
||||
|
||||
def is_gold(self, unicode move_name):
|
||||
pass
|
||||
|
|
Loading…
Reference in New Issue