mirror of https://github.com/explosion/spaCy.git
Set Doc.tensor attribute in parser
This commit is contained in:
parent
62ed58935a
commit
a5b05f85f0
|
@ -1,6 +1,7 @@
|
|||
# cython: infer_types=True
|
||||
# cython: cdivision=True
|
||||
# cython: boundscheck=False
|
||||
# cython: profile=True
|
||||
# coding: utf-8
|
||||
from __future__ import unicode_literals, print_function
|
||||
|
||||
|
@ -322,15 +323,17 @@ cdef class Parser:
|
|||
beam_density = self.cfg.get('beam_density', 0.0)
|
||||
cdef Beam beam
|
||||
if beam_width == 1:
|
||||
states = self.parse_batch([doc])
|
||||
self.set_annotations([doc], states)
|
||||
states, tokvecs = self.parse_batch([doc])
|
||||
self.set_annotations([doc], states, tensors=tokvecs)
|
||||
return doc
|
||||
else:
|
||||
beam = self.beam_parse([doc],
|
||||
beam_width=beam_width, beam_density=beam_density)[0]
|
||||
beams, tokvecs = self.beam_parse([doc],
|
||||
beam_width=beam_width,
|
||||
beam_density=beam_density)
|
||||
beam = beams[0]
|
||||
output = self.moves.get_beam_annot(beam)
|
||||
state = <StateClass>beam.at(0)
|
||||
self.set_annotations([doc], [state])
|
||||
self.set_annotations([doc], [state], tensors=tokvecs)
|
||||
_cleanup(beam)
|
||||
return output
|
||||
|
||||
|
@ -356,15 +359,16 @@ cdef class Parser:
|
|||
for subbatch in cytoolz.partition_all(8, by_length):
|
||||
subbatch = list(subbatch)
|
||||
if beam_width == 1:
|
||||
parse_states = self.parse_batch(subbatch)
|
||||
parse_states, tokvecs = self.parse_batch(subbatch)
|
||||
beams = []
|
||||
else:
|
||||
beams = self.beam_parse(subbatch, beam_width=beam_width,
|
||||
beam_density=beam_density)
|
||||
beams, tokvecs = self.beam_parse(subbatch,
|
||||
beam_width=beam_width,
|
||||
beam_density=beam_density)
|
||||
parse_states = []
|
||||
for beam in beams:
|
||||
parse_states.append(<StateClass>beam.at(0))
|
||||
self.set_annotations(subbatch, parse_states)
|
||||
self.set_annotations(subbatch, parse_states, tensors=tokvecs)
|
||||
yield from batch
|
||||
|
||||
def parse_batch(self, docs):
|
||||
|
@ -411,7 +415,9 @@ cdef class Parser:
|
|||
feat_weights, bias, hW, hb,
|
||||
nr_class, nr_hidden, nr_feat, nr_piece)
|
||||
PyErr_CheckSignals()
|
||||
return state_objs
|
||||
tokvecs = self.model[0].ops.unflatten(tokvecs,
|
||||
[len(doc) for doc in docs])
|
||||
return state_objs, tokvecs
|
||||
|
||||
cdef void _parseC(self, StateC* state,
|
||||
const float* feat_weights, const float* bias,
|
||||
|
@ -508,7 +514,9 @@ cdef class Parser:
|
|||
beam.advance(_transition_state, _hash_state, <void*>self.moves.c)
|
||||
beam.check_done(_check_final_state, NULL)
|
||||
beams.append(beam)
|
||||
return beams
|
||||
tokvecs = self.model[0].ops.unflatten(tokvecs,
|
||||
[len(doc) for doc in docs])
|
||||
return beams, tokvecs
|
||||
|
||||
def update(self, docs, golds, drop=0., sgd=None, losses=None):
|
||||
if not any(self.moves.has_gold(gold) for gold in golds):
|
||||
|
@ -730,13 +738,17 @@ cdef class Parser:
|
|||
c_d_scores += d_scores.shape[1]
|
||||
return d_scores
|
||||
|
||||
def set_annotations(self, docs, states):
|
||||
def set_annotations(self, docs, states, tensors=None):
|
||||
cdef StateClass state
|
||||
cdef Doc doc
|
||||
for state, doc in zip(states, docs):
|
||||
for i, (state, doc) in enumerate(zip(states, docs)):
|
||||
self.moves.finalize_state(state.c)
|
||||
for i in range(doc.length):
|
||||
doc.c[i] = state.c._sent[i]
|
||||
for j in range(doc.length):
|
||||
doc.c[j] = state.c._sent[j]
|
||||
if tensors is not None:
|
||||
print(doc.tensor.shape)
|
||||
|
||||
doc.extend_tensor(tensors[i])
|
||||
self.moves.finalize_doc(doc)
|
||||
for hook in self.postprocesses:
|
||||
for doc in docs:
|
||||
|
|
Loading…
Reference in New Issue