diff --git a/spacy/syntax/_parser_model.pyx b/spacy/syntax/_parser_model.pyx index 07bd0cb5c..5a7d3609b 100644 --- a/spacy/syntax/_parser_model.pyx +++ b/spacy/syntax/_parser_model.pyx @@ -37,6 +37,7 @@ from ..errors import Errors, TempErrors from .. import util from .stateclass cimport StateClass from .transition_system cimport Transition +from . import _beam_utils from . import nonproj @@ -196,26 +197,6 @@ class ParserModel(Model): Model.__init__(self) self._layers = [tok2vec, lower_model, upper_model] - @property - def nO(self): - return self._layers[-1].nO - - @property - def nI(self): - return self._layers[1].nI - - @property - def nH(self): - return self._layers[1].nO - - @property - def nF(self): - return self._layers[1].nF - - @property - def nP(self): - return self._layers[1].nP - def begin_update(self, docs, drop=0.): step_model = ParserStepModel(docs, self._layers, drop=drop) def finish_parser_update(golds, sgd=None): @@ -223,6 +204,15 @@ class ParserModel(Model): return None return step_model, finish_parser_update + def resize_output(self, new_output): + # Weights are stored in (nr_out, nr_in) format, so we're basically + # just adding rows here. + smaller = self._layers[-1]._layers[-1] + larger = Affine(self.moves.n_moves, smaller.nI) + copy_array(larger.W[:smaller.nO], smaller.W) + copy_array(larger.b[:smaller.nO], smaller.b) + self._layers[-1]._layers[-1] = larger + @property def tok2vec(self): return self._layers[0] @@ -274,15 +264,15 @@ class ParserStepModel(Model): return None return scores, backprop_parser_step - def get_token_ids(self, states): - cdef StateClass state - cdef int n_tokens = self.state2vec.nF - cdef np.ndarray ids = numpy.zeros((len(states), n_tokens), + def get_token_ids(self, batch): + states = _beam_utils.collect_states(batch) + cdef np.ndarray ids = numpy.zeros((len(states), self.state2vec.nF), dtype='i', order='C') c_ids = ids.data - for i, state in enumerate(states): - if not state.is_final(): - state.c.set_context_tokens(c_ids, n_tokens) + cdef StateClass state + for state in states: + if not state.c.is_final(): + state.c.set_context_tokens(c_ids, ids.shape[1]) c_ids += ids.shape[1] return ids diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx index ac6025ab2..e3d83d4c1 100644 --- a/spacy/syntax/nn_parser.pyx +++ b/spacy/syntax/nn_parser.pyx @@ -43,6 +43,8 @@ from .. import util from .stateclass cimport StateClass from ._state cimport StateC from .transition_system cimport Transition +from . cimport _beam_utils +from . import _beam_utils from . import nonproj @@ -172,7 +174,7 @@ cdef class Parser: with self.model.use_params(params): yield - def __call__(self, Doc doc, beam_width=None, beam_density=None): + def __call__(self, Doc doc, beam_width=None): """Apply the parser or entity recognizer, setting the annotations onto the `Doc` object. @@ -180,14 +182,11 @@ cdef class Parser: """ if beam_width is None: beam_width = self.cfg.get('beam_width', 1) - if beam_density is None: - beam_density = self.cfg.get('beam_density', 0.0) - states = self.predict([doc]) + states = self.predict([doc], beam_width=beam_width) self.set_annotations([doc], states, tensors=None) return doc - def pipe(self, docs, int batch_size=256, int n_threads=2, - beam_width=None, beam_density=None): + def pipe(self, docs, int batch_size=256, int n_threads=2, beam_width=None): """Process a stream of documents. stream: The sequence of documents to process. @@ -198,38 +197,40 @@ cdef class Parser: """ if beam_width is None: beam_width = self.cfg.get('beam_width', 1) - if beam_density is None: - beam_density = self.cfg.get('beam_density', 0.0) cdef Doc doc for batch in cytoolz.partition_all(batch_size, docs): batch_in_order = list(batch) by_length = sorted(batch_in_order, key=lambda doc: len(doc)) for subbatch in cytoolz.partition_all(8, by_length): subbatch = list(subbatch) - parse_states = self.predict(subbatch, - beam_width=beam_width, - beam_density=beam_density) + parse_states = self.predict(subbatch, beam_width=beam_width) self.set_annotations(subbatch, parse_states, tensors=None) for doc in batch_in_order: yield doc - def predict(self, docs, beam_width=1, beam_density=0.): + def predict(self, docs, beam_width=1): if isinstance(docs, Doc): docs = [docs] cdef vector[StateC*] states cdef StateClass state - state_objs = self.moves.init_batch(docs) - for state in state_objs: - states.push_back(state.c) - # Prepare the stepwise model, and get the callback for finishing the batch model = self.model(docs) - weights = get_c_weights(model) - sizes = get_c_sizes(model, states.size()) - with nogil: - self._parseC(&states[0], - weights, sizes) - return state_objs + if beam_width == 1: + batch = self.moves.init_batch(docs) + weights = get_c_weights(model) + sizes = get_c_sizes(model, states.size()) + for state in batch: + states.push_back(state.c) + with nogil: + self._parseC(&states[0], + weights, sizes) + else: + batch = self.moves.init_beams(docs, beam_width) + unfinished = list(batch) + while unfinished: + scores = model.predict(unfinished) + unfinished = self.transition_beams(batch, scores) + return batch cdef void _parseC(self, StateC** states, WeightsC weights, SizesC sizes) nogil: @@ -250,10 +251,21 @@ cdef class Parser: states[i] = unfinished[i] sizes.states = unfinished.size() unfinished.clear() - - def set_annotations(self, docs, states, tensors=None): + + def set_annotations(self, docs, states_or_beams, tensors=None): cdef StateClass state + cdef Beam beam cdef Doc doc + states = [] + beams = [] + for state_or_beam in states_or_beams: + if isinstance(state_or_beam, StateClass): + states.append(state_or_beam) + else: + beam = state_or_beam + state = StateClass.borrow(beam.at(0)) + states.append(state) + beams.append(beam) for i, (state, doc) in enumerate(zip(states, docs)): self.moves.finalize_state(state.c) for j in range(doc.length): @@ -262,14 +274,17 @@ cdef class Parser: for hook in self.postprocesses: for doc in docs: hook(doc) + for beam in beams: + _beam_utils.cleanup_beam(beam) - def transition_batch(self, states, float[:, ::1] scores): + def transition_states(self, states, float[:, ::1] scores): cdef StateClass state cdef float* c_scores = &scores[0, 0] cdef vector[StateC*] c_states for state in states: c_states.push_back(state.c) self.c_transition_batch(&c_states[0], c_scores, scores.shape[1], scores.shape[0]) + return [state for state in states if not state.c.is_final()] cdef void c_transition_batch(self, StateC** states, const float* scores, int nr_class, int batch_size) nogil: @@ -282,6 +297,20 @@ cdef class Parser: action = self.moves.c[guess] action.do(states[i], action.label) states[i].push_hist(guess) + + def transition_beams(self, beams, float[:, ::1] scores): + cdef Beam beam + cdef float* c_scores = &scores[0, 0] + for beam in beams: + for i in range(beam.size): + state = beam.at(i) + if not state.is_final(): + self.moves.set_valid(beam.is_valid[i], state) + memcpy(beam.scores[i], c_scores, scores.shape[1] * sizeof(float)) + c_scores += scores.shape[1] + beam.advance(_beam_utils.transition_state, NULL, self.moves.c) + beam.check_done(_beam_utils.check_final_state, NULL) + return [b for b in beams if not b.is_done] def update(self, docs, golds, drop=0., sgd=None, losses=None): if isinstance(docs, Doc) and isinstance(golds, GoldParse): @@ -290,6 +319,13 @@ cdef class Parser: if len(docs) != len(golds): raise ValueError(Errors.E077.format(value='update', n_docs=len(docs), n_golds=len(golds))) + # The probability we use beam update, instead of falling back to + # a greedy update + beam_update_prob = 1-self.cfg.get('beam_update_prob', 0.5) + if self.cfg.get('beam_width', 1) >= 2 and numpy.random.random() >= beam_update_prob: + return self.update_beam(docs, golds, + self.cfg['beam_width'], self.cfg['beam_density'], + drop=drop, sgd=sgd, losses=losses) # Chop sequences into lengths of this many transitions, to make the # batch uniform length. cut_gold = numpy.random.choice(range(20, 100)) @@ -307,11 +343,36 @@ cdef class Parser: d_scores = self.get_batch_loss(states, golds, scores, losses) backprop(d_scores, sgd=sgd) # Follow the predicted action - self.transition_batch(states, scores) + self.transition_states(states, scores) states_golds = [eg for eg in states_golds if not eg[0].is_final()] # Do the backprop finish_update(golds, sgd=sgd) return losses + + def update_beam(self, docs, golds, width, drop=0., sgd=None, losses=None): + lengths = [len(d) for d in docs] + states = self.moves.init_batch(docs) + for gold in golds: + self.moves.preprocess_gold(gold) + model, finish_update = self.model.begin_update(docs, drop=drop) + states_d_scores, backprops, beams = _beam_utils.update_beam( + self.moves, self.nr_feature, 500, states, golds, model.state2vec, + model.vec2scores, width, drop=drop, losses=losses) + for i, d_scores in enumerate(states_d_scores): + ids, bp_vectors, bp_scores = backprops[i] + d_vector = bp_scores(d_scores, sgd=sgd) + if isinstance(model.ops, CupyOps) \ + and not isinstance(ids, model.state2vec.ops.xp.ndarray): + model.backprops.append(( + util.get_async(model.cuda_stream, ids), + util.get_async(model.cuda_stream, d_vector), + bp_vectors)) + else: + model.backprops.append((ids, d_vector, bp_vectors)) + model.make_updates(sgd) + cdef Beam beam + for beam in beams: + _beam_utils.cleanup_beam(beam) def _init_gold_batch(self, whole_docs, whole_golds, min_length=5, max_length=500): """Make a square batch, of length equal to the shortest doc. A long diff --git a/spacy/syntax/transition_system.pyx b/spacy/syntax/transition_system.pyx index 2ffaaf30a..29ac6cf82 100644 --- a/spacy/syntax/transition_system.pyx +++ b/spacy/syntax/transition_system.pyx @@ -5,9 +5,12 @@ from __future__ import unicode_literals from cpython.ref cimport Py_INCREF from cymem.cymem cimport Pool from thinc.typedefs cimport weight_t +from thinc.extra.search cimport Beam from collections import OrderedDict, Counter import ujson +from . cimport _beam_utils +from ..tokens.doc cimport Doc from ..structs cimport TokenC from .stateclass cimport StateClass from ..typedefs cimport attr_t @@ -57,6 +60,21 @@ cdef class TransitionSystem: offset += len(doc) return states + def init_beams(self, docs, beam_width): + cdef Doc doc + beams = [] + cdef int offset = 0 + for doc in docs: + beam = Beam(self.n_moves, beam_width) + beam.initialize(self.init_beam_state, doc.length, doc.c) + for i in range(beam.width): + state = beam.at(i) + state.offset = offset + offset += len(doc) + beam.check_done(_beam_utils.check_final_state, NULL) + beams.append(beam) + return beams + def get_oracle_sequence(self, doc, GoldParse gold): cdef Pool mem = Pool() costs = mem.alloc(self.n_moves, sizeof(float))