From 99316fa631efd86a5ab5d68b11654c7366ece650 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sat, 27 May 2017 15:50:21 -0500 Subject: [PATCH] Use ordered dict to specify actions --- spacy/syntax/arc_eager.pyx | 14 ++++++++------ spacy/syntax/ner.pyx | 31 ++++++++++++++++++++++--------- 2 files changed, 30 insertions(+), 15 deletions(-) diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index f7c1c7922..2e424c1a9 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -9,6 +9,7 @@ import ctypes from libc.stdint cimport uint32_t from libc.string cimport memcpy from cymem.cymem cimport Pool +from collections import OrderedDict from .stateclass cimport StateClass from ._state cimport StateC, is_space_token @@ -312,12 +313,13 @@ cdef class ArcEager(TransitionSystem): @classmethod def get_actions(cls, **kwargs): actions = kwargs.get('actions', - { - SHIFT: [''], - REDUCE: [''], - RIGHT: [], - LEFT: [], - BREAK: ['ROOT']}) + OrderedDict(( + (SHIFT, ['']), + (REDUCE, ['']), + (RIGHT, []), + (LEFT, []), + (BREAK, ['ROOT']) + ))) seen_actions = set() for label in kwargs.get('left_labels', []): if label.upper() != 'ROOT': diff --git a/spacy/syntax/ner.pyx b/spacy/syntax/ner.pyx index af42eded4..f8db0a433 100644 --- a/spacy/syntax/ner.pyx +++ b/spacy/syntax/ner.pyx @@ -2,6 +2,7 @@ from __future__ import unicode_literals from thinc.typedefs cimport weight_t +from collections import OrderedDict from .stateclass cimport StateClass from ._state cimport StateC @@ -51,17 +52,29 @@ cdef bint _entity_is_sunk(StateClass st, Transition* golds) nogil: cdef class BiluoPushDown(TransitionSystem): + def __init__(self, *args, **kwargs): + TransitionSystem.__init__(self, *args, **kwargs) + + def __reduce__(self): + labels_by_action = OrderedDict() + cdef Transition t + for trans in self.c[:self.n_moves]: + label_str = self.strings[trans.label] + labels_by_action.setdefault(trans.move, []).append(label_str) + return (BiluoPushDown, (self.strings, labels_by_action), + None, None) + @classmethod def get_actions(cls, **kwargs): actions = kwargs.get('actions', - { - MISSING: [''], - BEGIN: [], - IN: [], - LAST: [], - UNIT: [], - OUT: [''] - }) + OrderedDict(( + (MISSING, ['']), + (BEGIN, []), + (IN, []), + (LAST, []), + (UNIT, []), + (OUT, ['']) + ))) seen_entities = set() for entity_type in kwargs.get('entity_types', []): if entity_type in seen_entities: @@ -90,7 +103,7 @@ cdef class BiluoPushDown(TransitionSystem): def move_name(self, int move, int label): if move == OUT: return 'O' - elif move == 'MISSING': + elif move == MISSING: return 'M' else: return MOVE_NAMES[move] + '-' + self.strings[label]