diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index e5a7406f6..2c1157216 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -42,8 +42,8 @@ cdef get_cost_func_t[N_MOVES] get_cost_funcs cdef class ArcEager(TransitionSystem): @classmethod def get_labels(cls, gold_parses): - move_labels = {SHIFT: {'ROOT': True}, REDUCE: {'ROOT': True}, RIGHT: {}, - LEFT: {}, BREAK: {'ROOT': True}} + move_labels = {SHIFT: {'': True}, REDUCE: {'': True}, RIGHT: {}, + LEFT: {}, BREAK: {'': True}} for raw_text, segmented, (ids, tags, heads, labels, iob) in gold_parses: for i, (head, label) in enumerate(zip(heads, labels)): if label != 'ROOT': diff --git a/spacy/syntax/ner.pyx b/spacy/syntax/ner.pyx index c4ead8258..38c211781 100644 --- a/spacy/syntax/ner.pyx +++ b/spacy/syntax/ner.pyx @@ -70,8 +70,8 @@ cdef int _is_valid(int act, int label, const State* s) except -1: cdef class BiluoPushDown(TransitionSystem): @classmethod def get_labels(cls, gold_tuples): - move_labels = {MISSING: {'ROOT': True}, BEGIN: {}, IN: {}, LAST: {}, UNIT: {}, - OUT: {'ROOT': True}} + move_labels = {MISSING: {'': True}, BEGIN: {}, IN: {}, LAST: {}, UNIT: {}, + OUT: {'': True}} moves = ('M', 'B', 'I', 'L', 'U') for (raw_text, toks, (ids, tags, heads, labels, biluo)) in gold_tuples: for i, ner_tag in enumerate(biluo): @@ -99,7 +99,7 @@ cdef class BiluoPushDown(TransitionSystem): label = 0 elif '-' in name: move_str, label_str = name.split('-', 1) - label = self.label_ids[label_str] + label = self.strings[label_str] else: move_str = name label = 0 diff --git a/spacy/syntax/transition_system.pyx b/spacy/syntax/transition_system.pyx index 72e9cedf8..820b61426 100644 --- a/spacy/syntax/transition_system.pyx +++ b/spacy/syntax/transition_system.pyx @@ -21,7 +21,7 @@ cdef class TransitionSystem: self.strings = string_table for action, label_strs in sorted(labels_by_action.items()): for label_str in sorted(label_strs): - label_id = self.strings[unicode(label_str)] + label_id = self.strings[unicode(label_str)] if label_str else 0 moves[i] = self.init_transition(i, int(action), label_id) i += 1 self.c = moves