mirror of https://github.com/explosion/spaCy.git
* Bug fixes to NER
This commit is contained in:
parent
d7b2843643
commit
af9ed18cf1
|
@ -9,6 +9,7 @@ cdef struct Entity:
|
||||||
|
|
||||||
|
|
||||||
cdef struct State:
|
cdef struct State:
|
||||||
|
Entity curr
|
||||||
Entity* ents
|
Entity* ents
|
||||||
int* tags
|
int* tags
|
||||||
int i
|
int i
|
||||||
|
|
|
@ -2,13 +2,16 @@ from .moves cimport BEGIN, UNIT
|
||||||
|
|
||||||
|
|
||||||
cdef int begin_entity(State* s, label) except -1:
|
cdef int begin_entity(State* s, label) except -1:
|
||||||
s.j += 1
|
s.curr.start = s.i
|
||||||
s.ents[s.j].start = s.i
|
s.curr.label = label
|
||||||
s.ents[s.j].label = label
|
|
||||||
|
|
||||||
|
|
||||||
cdef int end_entity(State* s) except -1:
|
cdef int end_entity(State* s) except -1:
|
||||||
s.ents[s.j].end = s.i + 1
|
s.curr.end = s.i + 1
|
||||||
|
s.curr[s.j] = s.curr
|
||||||
|
s.curr.start = 0
|
||||||
|
s.curr.label = -1
|
||||||
|
s.curr.end = 0
|
||||||
|
|
||||||
|
|
||||||
cdef State* init_state(Pool mem, int sent_length) except NULL:
|
cdef State* init_state(Pool mem, int sent_length) except NULL:
|
||||||
|
@ -17,24 +20,24 @@ cdef State* init_state(Pool mem, int sent_length) except NULL:
|
||||||
s.ents = <Entity*>mem.alloc(sent_length, sizeof(Entity))
|
s.ents = <Entity*>mem.alloc(sent_length, sizeof(Entity))
|
||||||
for i in range(sent_length):
|
for i in range(sent_length):
|
||||||
s.ents[i].label = -1
|
s.ents[i].label = -1
|
||||||
|
s.curr.label = -1
|
||||||
s.tags = <int*>mem.alloc(sent_length, sizeof(int))
|
s.tags = <int*>mem.alloc(sent_length, sizeof(int))
|
||||||
s.length = sent_length
|
s.length = sent_length
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
|
||||||
cdef bint entity_is_open(State *s) except -1:
|
cdef bint entity_is_open(State *s) except -1:
|
||||||
return s.j >= 0 and s.ents[s.j].label != -1
|
return s.curr.label != -1
|
||||||
|
|
||||||
|
|
||||||
cdef bint entity_is_sunk(State *s, Move* golds) except -1:
|
cdef bint entity_is_sunk(State *s, Move* golds) except -1:
|
||||||
if not entity_is_open(s):
|
if not entity_is_open(s):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
cdef Entity* ent = &s.ents[s.j]
|
cdef Move* gold = &golds[s.curr.start]
|
||||||
cdef Move* gold = &golds[ent.start]
|
|
||||||
if gold.action != BEGIN and gold.action != UNIT:
|
if gold.action != BEGIN and gold.action != UNIT:
|
||||||
return True
|
return True
|
||||||
elif gold.label != ent.label:
|
elif gold.label != s.curr.label:
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
|
@ -1,8 +1,11 @@
|
||||||
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
from ._state cimport begin_entity
|
from ._state cimport begin_entity
|
||||||
from ._state cimport end_entity
|
from ._state cimport end_entity
|
||||||
from ._state cimport entity_is_open
|
from ._state cimport entity_is_open
|
||||||
from ._state cimport entity_is_sunk
|
from ._state cimport entity_is_sunk
|
||||||
|
|
||||||
|
|
||||||
ACTION_NAMES = ['' for _ in range(N_ACTIONS)]
|
ACTION_NAMES = ['' for _ in range(N_ACTIONS)]
|
||||||
ACTION_NAMES[<int>BEGIN] = 'B'
|
ACTION_NAMES[<int>BEGIN] = 'B'
|
||||||
ACTION_NAMES[<int>IN] = 'I'
|
ACTION_NAMES[<int>IN] = 'I'
|
||||||
|
@ -16,11 +19,11 @@ cdef bint can_begin(State* s, int label):
|
||||||
|
|
||||||
|
|
||||||
cdef bint can_in(State* s, int label):
|
cdef bint can_in(State* s, int label):
|
||||||
return entity_is_open(s) and s.ents[s.j].tag == label
|
return entity_is_open(s) and s.ents[s.j].label == label
|
||||||
|
|
||||||
|
|
||||||
cdef bint can_last(State* s, int label):
|
cdef bint can_last(State* s, int label):
|
||||||
return entity_is_open(s) and s.ents[s.j].tag == label
|
return entity_is_open(s) and s.ents[s.j].label == label
|
||||||
|
|
||||||
|
|
||||||
cdef bint can_unit(State* s, int label):
|
cdef bint can_unit(State* s, int label):
|
||||||
|
@ -119,6 +122,7 @@ cdef int set_accept_if_valid(Move* moves, int n_classes, State* s) except 0:
|
||||||
elif m.action == OUT:
|
elif m.action == OUT:
|
||||||
m.accept = can_out(s, m.label)
|
m.accept = can_out(s, m.label)
|
||||||
n_accept += m.accept
|
n_accept += m.accept
|
||||||
|
assert n_accept != 0
|
||||||
return n_accept
|
return n_accept
|
||||||
|
|
||||||
|
|
||||||
|
@ -133,6 +137,7 @@ cdef int set_accept_if_oracle(Move* moves, Move* golds, int n_classes, State* s)
|
||||||
m.accept = is_oracle(<ActionType>m.action, m.label, <ActionType>g.action,
|
m.accept = is_oracle(<ActionType>m.action, m.label, <ActionType>g.action,
|
||||||
g.label, next_act, is_sunk)
|
g.label, next_act, is_sunk)
|
||||||
n_accept += m.accept
|
n_accept += m.accept
|
||||||
|
assert n_accept != 0
|
||||||
return n_accept
|
return n_accept
|
||||||
|
|
||||||
|
|
||||||
|
@ -182,6 +187,7 @@ cdef int fill_moves(Move* moves, int n_tags) except -1:
|
||||||
for label in range(n_tags):
|
for label in range(n_tags):
|
||||||
moves[i].action = IN
|
moves[i].action = IN
|
||||||
moves[i].label = label
|
moves[i].label = label
|
||||||
|
i += 1
|
||||||
for label in range(n_tags):
|
for label in range(n_tags):
|
||||||
moves[i].action = LAST
|
moves[i].action = LAST
|
||||||
moves[i].label = label
|
moves[i].label = label
|
||||||
|
@ -190,4 +196,5 @@ cdef int fill_moves(Move* moves, int n_tags) except -1:
|
||||||
moves[i].action = UNIT
|
moves[i].action = UNIT
|
||||||
moves[i].label = label
|
moves[i].label = label
|
||||||
i += 1
|
i += 1
|
||||||
moves[i].label == OUT
|
moves[i].action = OUT
|
||||||
|
moves[i].label = 0
|
||||||
|
|
|
@ -12,3 +12,5 @@ cdef class PyState:
|
||||||
|
|
||||||
cdef Move* _moves
|
cdef Move* _moves
|
||||||
cdef State* _s
|
cdef State* _s
|
||||||
|
|
||||||
|
cdef Move* _get_move(self, unicode move_name) except NULL
|
||||||
|
|
|
@ -1,7 +1,10 @@
|
||||||
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
from ._state cimport init_state
|
from ._state cimport init_state
|
||||||
from ._state cimport entity_is_open
|
from ._state cimport entity_is_open
|
||||||
from .moves cimport fill_moves
|
from .moves cimport fill_moves
|
||||||
from .moves cimport transition
|
from .moves cimport transition
|
||||||
|
from .moves cimport set_accept_if_valid
|
||||||
from .moves import get_n_moves
|
from .moves import get_n_moves
|
||||||
from .moves import ACTION_NAMES
|
from .moves import ACTION_NAMES
|
||||||
|
|
||||||
|
@ -19,16 +22,23 @@ cdef class PyState:
|
||||||
for i in range(self.n_classes):
|
for i in range(self.n_classes):
|
||||||
m = &self._moves[i]
|
m = &self._moves[i]
|
||||||
action_name = ACTION_NAMES[m.action]
|
action_name = ACTION_NAMES[m.action]
|
||||||
tag_name = tag_names[m.label]
|
if action_name == 'O':
|
||||||
self.moves_by_name['%s-%s' % (action_name, tag_name)] = i
|
self.moves_by_name['O'] = i
|
||||||
|
else:
|
||||||
|
tag_name = tag_names[m.label]
|
||||||
|
self.moves_by_name['%s-%s' % (action_name, tag_name)] = i
|
||||||
|
|
||||||
|
cdef Move* _get_move(self, unicode move_name) except NULL:
|
||||||
|
return &self._moves[self.moves_by_name[move_name]]
|
||||||
|
|
||||||
def transition(self, unicode move_name):
|
def transition(self, unicode move_name):
|
||||||
cdef int m_i = self.moves_by_name[move_name]
|
cdef Move* m = self._get_move(move_name)
|
||||||
cdef Move* m = &self._moves[m_i]
|
|
||||||
transition(self._s, m)
|
transition(self._s, m)
|
||||||
|
|
||||||
def is_valid(self, unicode move_name):
|
def is_valid(self, unicode move_name):
|
||||||
pass
|
cdef Move* m = self._get_move(move_name)
|
||||||
|
set_accept_if_valid(self._moves, self.n_classes, self._s)
|
||||||
|
return m.accept
|
||||||
|
|
||||||
def is_gold(self, unicode move_name):
|
def is_gold(self, unicode move_name):
|
||||||
pass
|
pass
|
||||||
|
|
Loading…
Reference in New Issue