mirror of https://github.com/explosion/spaCy.git
Use ordered dict to specify actions
This commit is contained in:
parent
655ca58c16
commit
99316fa631
|
@ -9,6 +9,7 @@ import ctypes
|
|||
from libc.stdint cimport uint32_t
|
||||
from libc.string cimport memcpy
|
||||
from cymem.cymem cimport Pool
|
||||
from collections import OrderedDict
|
||||
|
||||
from .stateclass cimport StateClass
|
||||
from ._state cimport StateC, is_space_token
|
||||
|
@ -312,12 +313,13 @@ cdef class ArcEager(TransitionSystem):
|
|||
@classmethod
|
||||
def get_actions(cls, **kwargs):
|
||||
actions = kwargs.get('actions',
|
||||
{
|
||||
SHIFT: [''],
|
||||
REDUCE: [''],
|
||||
RIGHT: [],
|
||||
LEFT: [],
|
||||
BREAK: ['ROOT']})
|
||||
OrderedDict((
|
||||
(SHIFT, ['']),
|
||||
(REDUCE, ['']),
|
||||
(RIGHT, []),
|
||||
(LEFT, []),
|
||||
(BREAK, ['ROOT'])
|
||||
)))
|
||||
seen_actions = set()
|
||||
for label in kwargs.get('left_labels', []):
|
||||
if label.upper() != 'ROOT':
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
from thinc.typedefs cimport weight_t
|
||||
from collections import OrderedDict
|
||||
|
||||
from .stateclass cimport StateClass
|
||||
from ._state cimport StateC
|
||||
|
@ -51,17 +52,29 @@ cdef bint _entity_is_sunk(StateClass st, Transition* golds) nogil:
|
|||
|
||||
|
||||
cdef class BiluoPushDown(TransitionSystem):
|
||||
def __init__(self, *args, **kwargs):
|
||||
TransitionSystem.__init__(self, *args, **kwargs)
|
||||
|
||||
def __reduce__(self):
|
||||
labels_by_action = OrderedDict()
|
||||
cdef Transition t
|
||||
for trans in self.c[:self.n_moves]:
|
||||
label_str = self.strings[trans.label]
|
||||
labels_by_action.setdefault(trans.move, []).append(label_str)
|
||||
return (BiluoPushDown, (self.strings, labels_by_action),
|
||||
None, None)
|
||||
|
||||
@classmethod
|
||||
def get_actions(cls, **kwargs):
|
||||
actions = kwargs.get('actions',
|
||||
{
|
||||
MISSING: [''],
|
||||
BEGIN: [],
|
||||
IN: [],
|
||||
LAST: [],
|
||||
UNIT: [],
|
||||
OUT: ['']
|
||||
})
|
||||
OrderedDict((
|
||||
(MISSING, ['']),
|
||||
(BEGIN, []),
|
||||
(IN, []),
|
||||
(LAST, []),
|
||||
(UNIT, []),
|
||||
(OUT, [''])
|
||||
)))
|
||||
seen_entities = set()
|
||||
for entity_type in kwargs.get('entity_types', []):
|
||||
if entity_type in seen_entities:
|
||||
|
@ -90,7 +103,7 @@ cdef class BiluoPushDown(TransitionSystem):
|
|||
def move_name(self, int move, int label):
|
||||
if move == OUT:
|
||||
return 'O'
|
||||
elif move == 'MISSING':
|
||||
elif move == MISSING:
|
||||
return 'M'
|
||||
else:
|
||||
return MOVE_NAMES[move] + '-' + self.strings[label]
|
||||
|
|
Loading…
Reference in New Issue