diff --git a/spacy/syntax/ner.pxd b/spacy/syntax/ner.pxd index 00ac8f6f8..3687bbb27 100644 --- a/spacy/syntax/ner.pxd +++ b/spacy/syntax/ner.pxd @@ -1,6 +1,7 @@ from .transition_system cimport TransitionSystem from .transition_system cimport Transition from ._state cimport State +from ..gold cimport GoldParseC cdef class BiluoPushDown(TransitionSystem): diff --git a/spacy/syntax/ner.pyx b/spacy/syntax/ner.pyx index 1787aaf27..8e9dcffe4 100644 --- a/spacy/syntax/ner.pyx +++ b/spacy/syntax/ner.pyx @@ -186,8 +186,13 @@ cdef class Begin: @staticmethod cdef int cost(const State* s, const GoldParseC* gold, int label) except -1: + if not Begin.is_valid(s, label): + return 9000 cdef int g_act = gold.ner[s.i].move cdef int g_tag = gold.ner[s.i].label + + if g_act == MISSING: + return 0 if g_act == BEGIN: # B, Gold B --> Label match return label != g_tag @@ -211,12 +216,17 @@ cdef class In: @staticmethod cdef int cost(const State* s, const GoldParseC* gold, int label) except -1: + if not In.is_valid(s, label): + return 9000 + move = IN cdef int next_act = gold.ner[s.i+1].move if s.i < s.sent_len else OUT cdef int g_act = gold.ner[s.i].move cdef int g_tag = gold.ner[s.i].label cdef bint is_sunk = _entity_is_sunk(s, gold.ner) - - if g_act == BEGIN: + + if g_act == MISSING: + return 0 + elif g_act == BEGIN: # I, Gold B --> True (P of bad open entity sunk, R of this entity sunk) return 0 elif g_act == IN: @@ -231,6 +241,8 @@ cdef class In: elif g_act == UNIT: # I, Gold U --> True iff next tag == O return next_act != OUT + else: + return 1 @@ -248,10 +260,16 @@ cdef class Last: @staticmethod cdef int cost(const State* s, const GoldParseC* gold, int label) except -1: + if not Last.is_valid(s, label): + return 9000 + move = LAST + cdef int g_act = gold.ner[s.i].move cdef int g_tag = gold.ner[s.i].label - - if g_act == BEGIN: + + if g_act == MISSING: + return 0 + elif g_act == BEGIN: # L, Gold B --> True return 0 elif g_act == IN: @@ -266,6 +284,8 @@ cdef class Last: elif g_act == UNIT: # L, Gold U --> True return 0 + else: + return 1 cdef class Unit: @@ -286,10 +306,14 @@ cdef class Unit: @staticmethod cdef int cost(const State* s, const GoldParseC* gold, int label) except -1: + if not Unit.is_valid(s, label): + return 9000 cdef int g_act = gold.ner[s.i].move cdef int g_tag = gold.ner[s.i].label - if g_act == UNIT: + if g_act == MISSING: + return 0 + elif g_act == UNIT: # U, Gold U --> True iff tag match return label != g_tag else: @@ -312,10 +336,16 @@ cdef class Out: @staticmethod cdef int cost(const State* s, const GoldParseC* gold, int label) except -1: + if not Out.is_valid(s, label): + return 9000 + cdef int g_act = gold.ner[s.i].move cdef int g_tag = gold.ner[s.i].label - if g_act == BEGIN: + + if g_act == MISSING: + return 0 + elif g_act == BEGIN: # O, Gold B --> False return 1 elif g_act == IN: @@ -330,6 +360,93 @@ cdef class Out: elif g_act == UNIT: # O, Gold U --> False return 1 + else: + return 1 + +""" + +# TODO: Move this logic into the cost functions +cdef int _get_cost(int move, int label, const State* s, const GoldParseC* gold) except -1: + cdef bint is_sunk = _entity_is_sunk(s, gold.ner) + cdef int next_act = gold.ner[s.i+1].move if s.i < s.sent_len else OUT + cdef bint is_gold = _is_gold(move, label, gold.ner[s.i].move, + gold.ner[s.i].label, next_act, is_sunk) + return not is_gold + + +cdef bint _is_gold(int act, int tag, int g_act, int g_tag, + int next_act, bint is_sunk): + if g_act == MISSING: + return True + if act == BEGIN: + if g_act == BEGIN: + # B, Gold B --> Label match + return tag == g_tag + else: + # B, Gold I --> False (P) + # B, Gold L --> False (P) + # B, Gold O --> False (P) + # B, Gold U --> False (P) + return False + elif act == IN: + if g_act == BEGIN: + # I, Gold B --> True (P of bad open entity sunk, R of this entity sunk) + return True + elif g_act == IN: + # I, Gold I --> True (label forced by prev, if mismatch, P and R both sunk) + return True + elif g_act == LAST: + # I, Gold L --> True iff this entity sunk and next tag == O + return is_sunk and (next_act == OUT or next_act == MISSING) + elif g_act == OUT: + # I, Gold O --> True iff next tag == O + return next_act == OUT or next_act == MISSING + elif g_act == UNIT: + # I, Gold U --> True iff next tag == O + return next_act == OUT + elif act == LAST: + if g_act == BEGIN: + # L, Gold B --> True + return True + elif g_act == IN: + # L, Gold I --> True iff this entity sunk + return is_sunk + elif g_act == LAST: + # L, Gold L --> True + return True + elif g_act == OUT: + # L, Gold O --> True + return True + elif g_act == UNIT: + # L, Gold U --> True + return True + elif act == OUT: + if g_act == BEGIN: + # O, Gold B --> False + return False + elif g_act == IN: + # O, Gold I --> True + return True + elif g_act == LAST: + # O, Gold L --> True + return True + elif g_act == OUT: + # O, Gold O --> True + return True + elif g_act == UNIT: + # O, Gold U --> False + return False + elif act == UNIT: + if g_act == UNIT: + # U, Gold U --> True iff tag match + return tag == g_tag + else: + # U, Gold B --> False + # U, Gold I --> False + # U, Gold L --> False + # U, Gold O --> False + return False +""" class OracleError(Exception):