mirror of https://github.com/explosion/spaCy.git
Add labels implicitly for parser and ner
This commit is contained in:
parent
68b1c2984d
commit
1d20e21f3e
|
@ -614,10 +614,22 @@ cdef class ArcEager(TransitionSystem):
|
|||
actions[LEFT].setdefault('dep', 0)
|
||||
return actions
|
||||
|
||||
@property
|
||||
def builtin_labels(self):
|
||||
return ["ROOT", "dep"]
|
||||
|
||||
@property
|
||||
def action_types(self):
|
||||
return (SHIFT, REDUCE, LEFT, RIGHT, BREAK)
|
||||
|
||||
def get_doc_labels(self, doc):
|
||||
"""Get the labels required for a given Doc."""
|
||||
labels = set(self.builtin_labels)
|
||||
for token in doc:
|
||||
if token.dep_:
|
||||
labels.add(token.dep_)
|
||||
return labels
|
||||
|
||||
def transition(self, StateClass state, action):
|
||||
cdef Transition t = self.lookup_transition(action)
|
||||
t.do(state.c, t.label)
|
||||
|
|
|
@ -126,6 +126,13 @@ cdef class BiluoPushDown(TransitionSystem):
|
|||
def action_types(self):
|
||||
return (BEGIN, IN, LAST, UNIT, OUT)
|
||||
|
||||
def get_doc_labels(self, doc):
|
||||
labels = set()
|
||||
for token in doc:
|
||||
if token.ent_type:
|
||||
labels.add(token.ent_type_)
|
||||
return labels
|
||||
|
||||
def move_name(self, int move, attr_t label):
|
||||
if move == OUT:
|
||||
return 'O'
|
||||
|
|
|
@ -132,6 +132,23 @@ cdef class Parser(TrainablePipe):
|
|||
return 1
|
||||
return 0
|
||||
|
||||
def _ensure_labels_are_added(self, docs):
|
||||
"""Ensure that all labels for a batch of docs are added."""
|
||||
resized = False
|
||||
labels = set()
|
||||
for doc in docs:
|
||||
labels.update(self.moves.get_doc_labels(doc))
|
||||
for label in labels:
|
||||
for action in self.moves.action_types:
|
||||
added = self.moves.add_action(action, label)
|
||||
if added:
|
||||
self.vocab.strings.add(label)
|
||||
resized = True
|
||||
if resized:
|
||||
self._resize()
|
||||
return 1
|
||||
return 0
|
||||
|
||||
def _resize(self):
|
||||
self.model.attrs["resize_output"](self.model, self.moves.n_moves)
|
||||
if self._rehearsal_model not in (True, False, None):
|
||||
|
@ -188,9 +205,9 @@ cdef class Parser(TrainablePipe):
|
|||
def predict(self, docs):
|
||||
if isinstance(docs, Doc):
|
||||
docs = [docs]
|
||||
self._ensure_labels_are_added(docs)
|
||||
if not any(len(doc) for doc in docs):
|
||||
result = self.moves.init_batch(docs)
|
||||
self._resize()
|
||||
return result
|
||||
if self.cfg["beam_width"] == 1:
|
||||
return self.greedy_parse(docs, drop=0.0)
|
||||
|
@ -207,10 +224,6 @@ cdef class Parser(TrainablePipe):
|
|||
cdef StateClass state
|
||||
set_dropout_rate(self.model, drop)
|
||||
batch = self.moves.init_batch(docs)
|
||||
# This is pretty dirty, but the NER can resize itself in init_batch,
|
||||
# if labels are missing. We therefore have to check whether we need to
|
||||
# expand our model output.
|
||||
self._resize()
|
||||
model = self.model.predict(docs)
|
||||
weights = get_c_weights(model)
|
||||
for state in batch:
|
||||
|
@ -234,10 +247,6 @@ cdef class Parser(TrainablePipe):
|
|||
beam_width,
|
||||
density=beam_density
|
||||
)
|
||||
# This is pretty dirty, but the NER can resize itself in init_batch,
|
||||
# if labels are missing. We therefore have to check whether we need to
|
||||
# expand our model output.
|
||||
self._resize()
|
||||
model = self.model.predict(docs)
|
||||
while not batch.is_done:
|
||||
states = batch.get_unfinished_states()
|
||||
|
@ -314,6 +323,9 @@ cdef class Parser(TrainablePipe):
|
|||
losses = {}
|
||||
losses.setdefault(self.name, 0.)
|
||||
validate_examples(examples, "Parser.update")
|
||||
self._ensure_labels_are_added(
|
||||
[eg.x for eg in examples] + [eg.y for eg in examples]
|
||||
)
|
||||
for multitask in self._multitasks:
|
||||
multitask.update(examples, drop=drop, sgd=sgd)
|
||||
|
||||
|
|
Loading…
Reference in New Issue