mirror of https://github.com/explosion/spaCy.git
* Fix NER oracle
This commit is contained in:
parent
c04e6ebca6
commit
0114e7600d
|
@ -1,6 +1,7 @@
|
|||
from .transition_system cimport TransitionSystem
|
||||
from .transition_system cimport Transition
|
||||
from ._state cimport State
|
||||
from ..gold cimport GoldParseC
|
||||
|
||||
|
||||
cdef class BiluoPushDown(TransitionSystem):
|
||||
|
|
|
@ -186,8 +186,13 @@ cdef class Begin:
|
|||
|
||||
@staticmethod
|
||||
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
if not Begin.is_valid(s, label):
|
||||
return 9000
|
||||
cdef int g_act = gold.ner[s.i].move
|
||||
cdef int g_tag = gold.ner[s.i].label
|
||||
|
||||
if g_act == MISSING:
|
||||
return 0
|
||||
if g_act == BEGIN:
|
||||
# B, Gold B --> Label match
|
||||
return label != g_tag
|
||||
|
@ -211,12 +216,17 @@ cdef class In:
|
|||
|
||||
@staticmethod
|
||||
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
if not In.is_valid(s, label):
|
||||
return 9000
|
||||
move = IN
|
||||
cdef int next_act = gold.ner[s.i+1].move if s.i < s.sent_len else OUT
|
||||
cdef int g_act = gold.ner[s.i].move
|
||||
cdef int g_tag = gold.ner[s.i].label
|
||||
cdef bint is_sunk = _entity_is_sunk(s, gold.ner)
|
||||
|
||||
if g_act == BEGIN:
|
||||
|
||||
if g_act == MISSING:
|
||||
return 0
|
||||
elif g_act == BEGIN:
|
||||
# I, Gold B --> True (P of bad open entity sunk, R of this entity sunk)
|
||||
return 0
|
||||
elif g_act == IN:
|
||||
|
@ -231,6 +241,8 @@ cdef class In:
|
|||
elif g_act == UNIT:
|
||||
# I, Gold U --> True iff next tag == O
|
||||
return next_act != OUT
|
||||
else:
|
||||
return 1
|
||||
|
||||
|
||||
|
||||
|
@ -248,10 +260,16 @@ cdef class Last:
|
|||
|
||||
@staticmethod
|
||||
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
if not Last.is_valid(s, label):
|
||||
return 9000
|
||||
move = LAST
|
||||
|
||||
cdef int g_act = gold.ner[s.i].move
|
||||
cdef int g_tag = gold.ner[s.i].label
|
||||
|
||||
if g_act == BEGIN:
|
||||
|
||||
if g_act == MISSING:
|
||||
return 0
|
||||
elif g_act == BEGIN:
|
||||
# L, Gold B --> True
|
||||
return 0
|
||||
elif g_act == IN:
|
||||
|
@ -266,6 +284,8 @@ cdef class Last:
|
|||
elif g_act == UNIT:
|
||||
# L, Gold U --> True
|
||||
return 0
|
||||
else:
|
||||
return 1
|
||||
|
||||
|
||||
cdef class Unit:
|
||||
|
@ -286,10 +306,14 @@ cdef class Unit:
|
|||
|
||||
@staticmethod
|
||||
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
if not Unit.is_valid(s, label):
|
||||
return 9000
|
||||
cdef int g_act = gold.ner[s.i].move
|
||||
cdef int g_tag = gold.ner[s.i].label
|
||||
|
||||
if g_act == UNIT:
|
||||
if g_act == MISSING:
|
||||
return 0
|
||||
elif g_act == UNIT:
|
||||
# U, Gold U --> True iff tag match
|
||||
return label != g_tag
|
||||
else:
|
||||
|
@ -312,10 +336,16 @@ cdef class Out:
|
|||
|
||||
@staticmethod
|
||||
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
if not Out.is_valid(s, label):
|
||||
return 9000
|
||||
|
||||
cdef int g_act = gold.ner[s.i].move
|
||||
cdef int g_tag = gold.ner[s.i].label
|
||||
|
||||
if g_act == BEGIN:
|
||||
|
||||
if g_act == MISSING:
|
||||
return 0
|
||||
elif g_act == BEGIN:
|
||||
# O, Gold B --> False
|
||||
return 1
|
||||
elif g_act == IN:
|
||||
|
@ -330,6 +360,93 @@ cdef class Out:
|
|||
elif g_act == UNIT:
|
||||
# O, Gold U --> False
|
||||
return 1
|
||||
else:
|
||||
return 1
|
||||
|
||||
"""
|
||||
|
||||
# TODO: Move this logic into the cost functions
|
||||
cdef int _get_cost(int move, int label, const State* s, const GoldParseC* gold) except -1:
|
||||
cdef bint is_sunk = _entity_is_sunk(s, gold.ner)
|
||||
cdef int next_act = gold.ner[s.i+1].move if s.i < s.sent_len else OUT
|
||||
cdef bint is_gold = _is_gold(move, label, gold.ner[s.i].move,
|
||||
gold.ner[s.i].label, next_act, is_sunk)
|
||||
return not is_gold
|
||||
|
||||
|
||||
cdef bint _is_gold(int act, int tag, int g_act, int g_tag,
|
||||
int next_act, bint is_sunk):
|
||||
if g_act == MISSING:
|
||||
return True
|
||||
if act == BEGIN:
|
||||
if g_act == BEGIN:
|
||||
# B, Gold B --> Label match
|
||||
return tag == g_tag
|
||||
else:
|
||||
# B, Gold I --> False (P)
|
||||
# B, Gold L --> False (P)
|
||||
# B, Gold O --> False (P)
|
||||
# B, Gold U --> False (P)
|
||||
return False
|
||||
elif act == IN:
|
||||
if g_act == BEGIN:
|
||||
# I, Gold B --> True (P of bad open entity sunk, R of this entity sunk)
|
||||
return True
|
||||
elif g_act == IN:
|
||||
# I, Gold I --> True (label forced by prev, if mismatch, P and R both sunk)
|
||||
return True
|
||||
elif g_act == LAST:
|
||||
# I, Gold L --> True iff this entity sunk and next tag == O
|
||||
return is_sunk and (next_act == OUT or next_act == MISSING)
|
||||
elif g_act == OUT:
|
||||
# I, Gold O --> True iff next tag == O
|
||||
return next_act == OUT or next_act == MISSING
|
||||
elif g_act == UNIT:
|
||||
# I, Gold U --> True iff next tag == O
|
||||
return next_act == OUT
|
||||
elif act == LAST:
|
||||
if g_act == BEGIN:
|
||||
# L, Gold B --> True
|
||||
return True
|
||||
elif g_act == IN:
|
||||
# L, Gold I --> True iff this entity sunk
|
||||
return is_sunk
|
||||
elif g_act == LAST:
|
||||
# L, Gold L --> True
|
||||
return True
|
||||
elif g_act == OUT:
|
||||
# L, Gold O --> True
|
||||
return True
|
||||
elif g_act == UNIT:
|
||||
# L, Gold U --> True
|
||||
return True
|
||||
elif act == OUT:
|
||||
if g_act == BEGIN:
|
||||
# O, Gold B --> False
|
||||
return False
|
||||
elif g_act == IN:
|
||||
# O, Gold I --> True
|
||||
return True
|
||||
elif g_act == LAST:
|
||||
# O, Gold L --> True
|
||||
return True
|
||||
elif g_act == OUT:
|
||||
# O, Gold O --> True
|
||||
return True
|
||||
elif g_act == UNIT:
|
||||
# O, Gold U --> False
|
||||
return False
|
||||
elif act == UNIT:
|
||||
if g_act == UNIT:
|
||||
# U, Gold U --> True iff tag match
|
||||
return tag == g_tag
|
||||
else:
|
||||
# U, Gold B --> False
|
||||
# U, Gold I --> False
|
||||
# U, Gold L --> False
|
||||
# U, Gold O --> False
|
||||
return False
|
||||
"""
|
||||
|
||||
|
||||
class OracleError(Exception):
|
||||
|
|
Loading…
Reference in New Issue