Use ordered dict to specify actions

This commit is contained in:
Matthew Honnibal 2017-05-27 15:50:21 -05:00
parent 655ca58c16
commit 99316fa631
2 changed files with 30 additions and 15 deletions

View File

@ -9,6 +9,7 @@ import ctypes
from libc.stdint cimport uint32_t
from libc.string cimport memcpy
from cymem.cymem cimport Pool
from collections import OrderedDict
from .stateclass cimport StateClass
from ._state cimport StateC, is_space_token
@ -312,12 +313,13 @@ cdef class ArcEager(TransitionSystem):
@classmethod
def get_actions(cls, **kwargs):
actions = kwargs.get('actions',
{
SHIFT: [''],
REDUCE: [''],
RIGHT: [],
LEFT: [],
BREAK: ['ROOT']})
OrderedDict((
(SHIFT, ['']),
(REDUCE, ['']),
(RIGHT, []),
(LEFT, []),
(BREAK, ['ROOT'])
)))
seen_actions = set()
for label in kwargs.get('left_labels', []):
if label.upper() != 'ROOT':

View File

@ -2,6 +2,7 @@
from __future__ import unicode_literals
from thinc.typedefs cimport weight_t
from collections import OrderedDict
from .stateclass cimport StateClass
from ._state cimport StateC
@ -51,17 +52,29 @@ cdef bint _entity_is_sunk(StateClass st, Transition* golds) nogil:
cdef class BiluoPushDown(TransitionSystem):
def __init__(self, *args, **kwargs):
TransitionSystem.__init__(self, *args, **kwargs)
def __reduce__(self):
labels_by_action = OrderedDict()
cdef Transition t
for trans in self.c[:self.n_moves]:
label_str = self.strings[trans.label]
labels_by_action.setdefault(trans.move, []).append(label_str)
return (BiluoPushDown, (self.strings, labels_by_action),
None, None)
@classmethod
def get_actions(cls, **kwargs):
actions = kwargs.get('actions',
{
MISSING: [''],
BEGIN: [],
IN: [],
LAST: [],
UNIT: [],
OUT: ['']
})
OrderedDict((
(MISSING, ['']),
(BEGIN, []),
(IN, []),
(LAST, []),
(UNIT, []),
(OUT, [''])
)))
seen_entities = set()
for entity_type in kwargs.get('entity_types', []):
if entity_type in seen_entities:
@ -90,7 +103,7 @@ cdef class BiluoPushDown(TransitionSystem):
def move_name(self, int move, int label):
if move == OUT:
return 'O'
elif move == 'MISSING':
elif move == MISSING:
return 'M'
else:
return MOVE_NAMES[move] + '-' + self.strings[label]