diff --git a/requirements.txt b/requirements.txt index 13c34d601..44f53bdb4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -27,3 +27,4 @@ pytest>=4.6.5 pytest-timeout>=1.3.0,<2.0.0 mock>=2.0.0,<3.0.0 flake8>=3.5.0,<3.6.0 +hypothesis diff --git a/setup.py b/setup.py index 6e6e08988..14f8486ca 100755 --- a/setup.py +++ b/setup.py @@ -48,6 +48,7 @@ MOD_NAMES = [ "spacy.pipeline._parser_internals._state", "spacy.pipeline._parser_internals.stateclass", "spacy.pipeline._parser_internals.transition_system", + "spacy.pipeline._parser_internals._beam_utils", "spacy.tokenizer", "spacy.training.align", "spacy.training.gold_io", diff --git a/spacy/pipeline/_parser_internals/__init__.pxd b/spacy/pipeline/_parser_internals/__init__.pxd new file mode 100644 index 000000000..e69de29bb diff --git a/spacy/pipeline/_parser_internals/_beam_utils.pxd b/spacy/pipeline/_parser_internals/_beam_utils.pxd new file mode 100644 index 000000000..de3573fbc --- /dev/null +++ b/spacy/pipeline/_parser_internals/_beam_utils.pxd @@ -0,0 +1,6 @@ +from ...typedefs cimport class_t, hash_t + +# These are passed as callbacks to thinc.search.Beam +cdef int transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1 + +cdef int check_final_state(void* _state, void* extra_args) except -1 diff --git a/spacy/pipeline/_parser_internals/_beam_utils.pyx b/spacy/pipeline/_parser_internals/_beam_utils.pyx new file mode 100644 index 000000000..a7f34daaf --- /dev/null +++ b/spacy/pipeline/_parser_internals/_beam_utils.pyx @@ -0,0 +1,296 @@ +# cython: infer_types=True +# cython: profile=True +cimport numpy as np +import numpy +from cpython.ref cimport PyObject, Py_XDECREF +from thinc.extra.search cimport Beam +from thinc.extra.search import MaxViolation +from thinc.extra.search cimport MaxViolation + +from ...typedefs cimport hash_t, class_t +from .transition_system cimport TransitionSystem, Transition +from ...errors import Errors +from .stateclass cimport StateC, StateClass + + +# These are passed as callbacks to thinc.search.Beam +cdef int transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1: + dest = _dest + src = _src + moves = _moves + dest.clone(src) + moves[clas].do(dest, moves[clas].label) + + +cdef int check_final_state(void* _state, void* extra_args) except -1: + state = _state + return state.is_final() + + +cdef class BeamBatch(object): + cdef public TransitionSystem moves + cdef public object states + cdef public object docs + cdef public object golds + cdef public object beams + + def __init__(self, TransitionSystem moves, states, golds, + int width, float density=0.): + cdef StateClass state + self.moves = moves + self.states = states + self.docs = [state.doc for state in states] + self.golds = golds + self.beams = [] + cdef Beam beam + cdef StateC* st + for state in states: + beam = Beam(self.moves.n_moves, width, min_density=density) + beam.initialize(self.moves.init_beam_state, + self.moves.del_beam_state, state.c.length, + state.c._sent) + for i in range(beam.width): + st = beam.at(i) + st.offset = state.c.offset + beam.check_done(check_final_state, NULL) + self.beams.append(beam) + + @property + def is_done(self): + return all(b.is_done for b in self.beams) + + def __getitem__(self, i): + return self.beams[i] + + def __len__(self): + return len(self.beams) + + def get_states(self): + cdef Beam beam + cdef StateC* state + cdef StateClass stcls + states = [] + for beam, doc in zip(self, self.docs): + for i in range(beam.size): + state = beam.at(i) + stcls = StateClass.borrow(state, doc) + states.append(stcls) + return states + + def get_unfinished_states(self): + return [st for st in self.get_states() if not st.is_final()] + + def advance(self, float[:, ::1] scores, follow_gold=False): + cdef Beam beam + cdef int nr_class = scores.shape[1] + cdef const float* c_scores = &scores[0, 0] + docs = self.docs + for i, beam in enumerate(self): + if not beam.is_done: + nr_state = self._set_scores(beam, c_scores, nr_class) + assert nr_state + if self.golds is not None: + self._set_costs( + beam, + docs[i], + self.golds[i], + follow_gold=follow_gold + ) + c_scores += nr_state * nr_class + beam.advance(transition_state, NULL, self.moves.c) + beam.check_done(check_final_state, NULL) + + cdef int _set_scores(self, Beam beam, const float* scores, int nr_class) except -1: + cdef int nr_state = 0 + for i in range(beam.size): + state = beam.at(i) + if not state.is_final(): + for j in range(nr_class): + beam.scores[i][j] = scores[nr_state * nr_class + j] + self.moves.set_valid(beam.is_valid[i], state) + nr_state += 1 + else: + for j in range(beam.nr_class): + beam.scores[i][j] = 0 + beam.costs[i][j] = 0 + return nr_state + + def _set_costs(self, Beam beam, doc, gold, int follow_gold=False): + cdef const StateC* state + for i in range(beam.size): + state = beam.at(i) + if state.is_final(): + for j in range(beam.nr_class): + beam.is_valid[i][j] = 0 + beam.costs[i][j] = 9000 + else: + self.moves.set_costs(beam.is_valid[i], beam.costs[i], + state, gold) + if follow_gold: + min_cost = 0 + for j in range(beam.nr_class): + if beam.is_valid[i][j] and beam.costs[i][j] < min_cost: + min_cost = beam.costs[i][j] + for j in range(beam.nr_class): + if beam.costs[i][j] > min_cost: + beam.is_valid[i][j] = 0 + + +def update_beam(TransitionSystem moves, states, golds, model, int width, beam_density=0.0): + cdef MaxViolation violn + pbeam = BeamBatch(moves, states, golds, width=width, density=beam_density) + gbeam = BeamBatch(moves, states, golds, width=width, density=0.0) + cdef StateClass state + beam_maps = [] + backprops = [] + violns = [MaxViolation() for _ in range(len(states))] + dones = [False for _ in states] + while not pbeam.is_done or not gbeam.is_done: + # The beam maps let us find the right row in the flattened scores + # array for each state. States are identified by (example id, + # history). We keep a different beam map for each step (since we'll + # have a flat scores array for each step). The beam map will let us + # take the per-state losses, and compute the gradient for each (step, + # state, class). + # Gather all states from the two beams in a list. Some stats may occur + # in both beams. To figure out which beam each state belonged to, + # we keep two lists of indices, p_indices and g_indices + states, p_indices, g_indices, beam_map = get_unique_states(pbeam, gbeam) + beam_maps.append(beam_map) + if not states: + break + # Now that we have our flat list of states, feed them through the model + scores, bp_scores = model.begin_update(states) + assert scores.size != 0 + # Store the callbacks for the backward pass + backprops.append(bp_scores) + # Unpack the scores for the two beams. The indices arrays + # tell us which example and state the scores-row refers to. + # Now advance the states in the beams. The gold beam is constrained to + # to follow only gold analyses. + if not pbeam.is_done: + pbeam.advance(model.ops.as_contig(scores[p_indices])) + if not gbeam.is_done: + gbeam.advance(model.ops.as_contig(scores[g_indices]), follow_gold=True) + # Track the "maximum violation", to use in the update. + for i, violn in enumerate(violns): + if not dones[i]: + violn.check_crf(pbeam[i], gbeam[i]) + if pbeam[i].is_done and gbeam[i].is_done: + dones[i] = True + histories = [] + grads = [] + for violn in violns: + if violn.p_hist: + histories.append(violn.p_hist + violn.g_hist) + d_loss = [d_l * violn.cost for d_l in violn.p_probs + violn.g_probs] + grads.append(d_loss) + else: + histories.append([]) + grads.append([]) + loss = 0.0 + states_d_scores = get_gradient(moves.n_moves, beam_maps, histories, grads) + for i, (d_scores, bp_scores) in enumerate(zip(states_d_scores, backprops)): + loss += (d_scores**2).mean() + bp_scores(d_scores) + return loss + + +def collect_states(beams, docs): + cdef StateClass state + cdef Beam beam + states = [] + for state_or_beam, doc in zip(beams, docs): + if isinstance(state_or_beam, StateClass): + states.append(state_or_beam) + else: + beam = state_or_beam + state = StateClass.borrow(beam.at(0), doc) + states.append(state) + return states + + +def get_unique_states(pbeams, gbeams): + seen = {} + states = [] + p_indices = [] + g_indices = [] + beam_map = {} + docs = pbeams.docs + cdef Beam pbeam, gbeam + if len(pbeams) != len(gbeams): + raise ValueError(Errors.E079.format(pbeams=len(pbeams), gbeams=len(gbeams))) + for eg_id, (pbeam, gbeam, doc) in enumerate(zip(pbeams, gbeams, docs)): + if not pbeam.is_done: + for i in range(pbeam.size): + state = StateClass.borrow(pbeam.at(i), doc) + if not state.is_final(): + key = tuple([eg_id] + pbeam.histories[i]) + if key in seen: + raise ValueError(Errors.E080.format(key=key)) + seen[key] = len(states) + p_indices.append(len(states)) + states.append(state) + beam_map.update(seen) + if not gbeam.is_done: + for i in range(gbeam.size): + state = StateClass.borrow(gbeam.at(i), doc) + if not state.is_final(): + key = tuple([eg_id] + gbeam.histories[i]) + if key in seen: + g_indices.append(seen[key]) + else: + g_indices.append(len(states)) + beam_map[key] = len(states) + states.append(state) + p_indices = numpy.asarray(p_indices, dtype='i') + g_indices = numpy.asarray(g_indices, dtype='i') + return states, p_indices, g_indices, beam_map + + +def get_gradient(nr_class, beam_maps, histories, losses): + """The global model assigns a loss to each parse. The beam scores + are additive, so the same gradient is applied to each action + in the history. This gives the gradient of a single *action* + for a beam state -- so we have "the gradient of loss for taking + action i given history H." + + Histories: Each hitory is a list of actions + Each candidate has a history + Each beam has multiple candidates + Each batch has multiple beams + So history is list of lists of lists of ints + """ + grads = [] + nr_steps = [] + for eg_id, hists in enumerate(histories): + nr_step = 0 + for loss, hist in zip(losses[eg_id], hists): + assert not numpy.isnan(loss) + if loss != 0.0: + nr_step = max(nr_step, len(hist)) + nr_steps.append(nr_step) + for i in range(max(nr_steps)): + grads.append(numpy.zeros((max(beam_maps[i].values())+1, nr_class), + dtype='f')) + if len(histories) != len(losses): + raise ValueError(Errors.E081.format(n_hist=len(histories), losses=len(losses))) + for eg_id, hists in enumerate(histories): + for loss, hist in zip(losses[eg_id], hists): + assert not numpy.isnan(loss) + if loss == 0.0: + continue + key = tuple([eg_id]) + # Adjust loss for length + # We need to do this because each state in a short path is scored + # multiple times, as we add in the average cost when we run out + # of actions. + avg_loss = loss / len(hist) + loss += avg_loss * (nr_steps[eg_id] - len(hist)) + for step, clas in enumerate(hist): + i = beam_maps[step][key] + # In step j, at state i action clas + # resulted in loss + grads[step][i, clas] += loss + key = key + tuple([clas]) + return grads diff --git a/spacy/pipeline/_parser_internals/_state.pxd b/spacy/pipeline/_parser_internals/_state.pxd index 0d0dd8c05..a6bf926f9 100644 --- a/spacy/pipeline/_parser_internals/_state.pxd +++ b/spacy/pipeline/_parser_internals/_state.pxd @@ -1,6 +1,9 @@ from libc.string cimport memcpy, memset from libc.stdlib cimport calloc, free from libc.stdint cimport uint32_t, uint64_t +cimport libcpp +from libcpp.vector cimport vector +from libcpp.set cimport set from cpython.exc cimport PyErr_CheckSignals, PyErr_SetFromErrno from murmurhash.mrmr cimport hash64 @@ -14,89 +17,48 @@ from ...typedefs cimport attr_t cdef inline bint is_space_token(const TokenC* token) nogil: return Lexeme.c_check_flag(token.lex, IS_SPACE) -cdef struct RingBufferC: - int[8] data - int i - int default - -cdef inline int ring_push(RingBufferC* ring, int value) nogil: - ring.data[ring.i] = value - ring.i += 1 - if ring.i >= 8: - ring.i = 0 - -cdef inline int ring_get(RingBufferC* ring, int i) nogil: - if i >= ring.i: - return ring.default - else: - return ring.data[ring.i-i] +cdef struct ArcC: + int head + int child + attr_t label cdef cppclass StateC: - int* _stack - int* _buffer - bint* shifted - TokenC* _sent - SpanC* _ents + int* _heads + const TokenC* _sent + vector[int] _stack + vector[int] _rebuffer + vector[SpanC] _ents + vector[ArcC] _left_arcs + vector[ArcC] _right_arcs + vector[libcpp.bool] _unshiftable + set[int] _sent_starts TokenC _empty_token - RingBufferC _hist int length int offset - int _s_i int _b_i - int _e_i - int _break __init__(const TokenC* sent, int length) nogil: - cdef int PADDING = 5 - this._buffer = calloc(length + (PADDING * 2), sizeof(int)) - this._stack = calloc(length + (PADDING * 2), sizeof(int)) - this.shifted = calloc(length + (PADDING * 2), sizeof(bint)) - this._sent = calloc(length + (PADDING * 2), sizeof(TokenC)) - this._ents = calloc(length + (PADDING * 2), sizeof(SpanC)) - if not (this._buffer and this._stack and this.shifted - and this._sent and this._ents): + this._sent = sent + this._heads = calloc(length, sizeof(int)) + if not (this._sent and this._heads): with gil: PyErr_SetFromErrno(MemoryError) PyErr_CheckSignals() - memset(&this._hist, 0, sizeof(this._hist)) this.offset = 0 - cdef int i - for i in range(length + (PADDING * 2)): - this._ents[i].end = -1 - this._sent[i].l_edge = i - this._sent[i].r_edge = i - for i in range(PADDING): - this._sent[i].lex = &EMPTY_LEXEME - this._sent += PADDING - this._ents += PADDING - this._buffer += PADDING - this._stack += PADDING - this.shifted += PADDING this.length = length - this._break = -1 - this._s_i = 0 this._b_i = 0 - this._e_i = 0 for i in range(length): - this._buffer[i] = i + this._heads[i] = -1 + this._unshiftable.push_back(0) memset(&this._empty_token, 0, sizeof(TokenC)) this._empty_token.lex = &EMPTY_LEXEME - for i in range(length): - this._sent[i] = sent[i] - this._buffer[i] = i - for i in range(length, length+PADDING): - this._sent[i].lex = &EMPTY_LEXEME __dealloc__(): - cdef int PADDING = 5 - free(this._sent - PADDING) - free(this._ents - PADDING) - free(this._buffer - PADDING) - free(this._stack - PADDING) - free(this.shifted - PADDING) + free(this._heads) void set_context_tokens(int* ids, int n) nogil: + cdef int i, j if n == 1: if this.B(0) >= 0: ids[0] = this.B(0) @@ -145,22 +107,18 @@ cdef cppclass StateC: ids[11] = this.R(this.S(1), 1) ids[12] = this.R(this.S(1), 2) elif n == 6: + for i in range(6): + ids[i] = -1 if this.B(0) >= 0: ids[0] = this.B(0) - ids[1] = this.B(0)-1 - else: - ids[0] = -1 - ids[1] = -1 - ids[2] = this.B(1) - ids[3] = this.E(0) - if ids[3] >= 1: - ids[4] = this.E(0)-1 - else: - ids[4] = -1 - if (ids[3]+1) < this.length: - ids[5] = this.E(0)+1 - else: - ids[5] = -1 + if this.entity_is_open(): + ent = this.get_ent() + j = 1 + for i in range(ent.start, this.B(0)): + ids[j] = i + j += 1 + if j >= 6: + break else: # TODO error =/ pass @@ -171,329 +129,256 @@ cdef cppclass StateC: ids[i] = -1 int S(int i) nogil const: - if i >= this._s_i: + if i >= this._stack.size(): return -1 - return this._stack[this._s_i - (i+1)] + elif i < 0: + return -1 + return this._stack.at(this._stack.size() - (i+1)) int B(int i) nogil const: - if (i + this._b_i) >= this.length: + if i < 0: return -1 - return this._buffer[this._b_i + i] - - const TokenC* S_(int i) nogil const: - return this.safe_get(this.S(i)) + elif i < this._rebuffer.size(): + return this._rebuffer.at(this._rebuffer.size() - (i+1)) + else: + b_i = this._b_i + (i - this._rebuffer.size()) + if b_i >= this.length: + return -1 + else: + return b_i const TokenC* B_(int i) nogil const: return this.safe_get(this.B(i)) - const TokenC* H_(int i) nogil const: - return this.safe_get(this.H(i)) - const TokenC* E_(int i) nogil const: return this.safe_get(this.E(i)) - const TokenC* L_(int i, int idx) nogil const: - return this.safe_get(this.L(i, idx)) - - const TokenC* R_(int i, int idx) nogil const: - return this.safe_get(this.R(i, idx)) - const TokenC* safe_get(int i) nogil const: if i < 0 or i >= this.length: return &this._empty_token else: return &this._sent[i] - int H(int i) nogil const: - if i < 0 or i >= this.length: + void get_arcs(vector[ArcC]* arcs) nogil const: + for i in range(this._left_arcs.size()): + arc = this._left_arcs.at(i) + if arc.head != -1 and arc.child != -1: + arcs.push_back(arc) + for i in range(this._right_arcs.size()): + arc = this._right_arcs.at(i) + if arc.head != -1 and arc.child != -1: + arcs.push_back(arc) + + int H(int child) nogil const: + if child >= this.length or child < 0: return -1 - return this._sent[i].head + i + else: + return this._heads[child] int E(int i) nogil const: - if this._e_i <= 0 or this._e_i >= this.length: + if this._ents.size() == 0: return -1 - if i < 0 or i >= this._e_i: - return -1 - return this._ents[this._e_i - (i+1)].start + else: + return this._ents.back().start - int L(int i, int idx) nogil const: - if idx < 1: + int L(int head, int idx) nogil const: + if idx < 1 or this._left_arcs.size() == 0: return -1 - if i < 0 or i >= this.length: + cdef vector[int] lefts + for i in range(this._left_arcs.size()): + arc = this._left_arcs.at(i) + if arc.head == head and arc.child != -1 and arc.child < head: + lefts.push_back(arc.child) + idx = (lefts.size()) - idx + if idx < 0: return -1 - cdef const TokenC* target = &this._sent[i] - if target.l_kids < idx: - return -1 - cdef const TokenC* ptr = &this._sent[target.l_edge] + else: + return lefts.at(idx) - while ptr < target: - # If this head is still to the right of us, we can skip to it - # No token that's between this token and this head could be our - # child. - if (ptr.head >= 1) and (ptr + ptr.head) < target: - ptr += ptr.head - - elif ptr + ptr.head == target: - idx -= 1 - if idx == 0: - return ptr - this._sent - ptr += 1 - else: - ptr += 1 - return -1 - - int R(int i, int idx) nogil const: - if idx < 1: + int R(int head, int idx) nogil const: + if idx < 1 or this._right_arcs.size() == 0: return -1 - if i < 0 or i >= this.length: + cdef vector[int] rights + for i in range(this._right_arcs.size()): + arc = this._right_arcs.at(i) + if arc.head == head and arc.child != -1 and arc.child > head: + rights.push_back(arc.child) + idx = (rights.size()) - idx + if idx < 0: return -1 - cdef const TokenC* target = &this._sent[i] - if target.r_kids < idx: - return -1 - cdef const TokenC* ptr = &this._sent[target.r_edge] - while ptr > target: - # If this head is still to the right of us, we can skip to it - # No token that's between this token and this head could be our - # child. - if (ptr.head < 0) and ((ptr + ptr.head) > target): - ptr += ptr.head - elif ptr + ptr.head == target: - idx -= 1 - if idx == 0: - return ptr - this._sent - ptr -= 1 - else: - ptr -= 1 - return -1 + else: + return rights.at(idx) bint empty() nogil const: - return this._s_i <= 0 + return this._stack.size() == 0 bint eol() nogil const: return this.buffer_length() == 0 - bint at_break() nogil const: - return this._break != -1 - bint is_final() nogil const: - return this.stack_depth() <= 0 and this._b_i >= this.length + return this.stack_depth() <= 0 and this.eol() - bint has_head(int i) nogil const: - return this.safe_get(i).head != 0 + int cannot_sent_start(int word) nogil const: + if word < 0 or word >= this.length: + return 0 + elif this._sent[word].sent_start == -1: + return 1 + else: + return 0 - int n_L(int i) nogil const: - return this.safe_get(i).l_kids + int is_sent_start(int word) nogil const: + if word < 0 or word >= this.length: + return 0 + elif this._sent[word].sent_start == 1: + return 1 + elif this._sent_starts.count(word) >= 1: + return 1 + else: + return 0 - int n_R(int i) nogil const: - return this.safe_get(i).r_kids + void set_sent_start(int word, int value) nogil: + if value >= 1: + this._sent_starts.insert(word) + + bint has_head(int child) nogil const: + return this._heads[child] >= 0 + + int l_edge(int word) nogil const: + return word + + int r_edge(int word) nogil const: + return word + + int n_L(int head) nogil const: + cdef int n = 0 + for i in range(this._left_arcs.size()): + arc = this._left_arcs.at(i) + if arc.head == head and arc.child != -1 and arc.child < arc.head: + n += 1 + return n + + int n_R(int head) nogil const: + cdef int n = 0 + for i in range(this._right_arcs.size()): + arc = this._right_arcs.at(i) + if arc.head == head and arc.child != -1 and arc.child > arc.head: + n += 1 + return n bint stack_is_connected() nogil const: return False bint entity_is_open() nogil const: - if this._e_i < 1: + if this._ents.size() == 0: return False - return this._ents[this._e_i-1].end == -1 + else: + return this._ents.back().end == -1 int stack_depth() nogil const: - return this._s_i + return this._stack.size() int buffer_length() nogil const: - if this._break != -1: - return this._break - this._b_i - else: - return this.length - this._b_i - - uint64_t hash() nogil const: - cdef TokenC[11] sig - sig[0] = this.S_(2)[0] - sig[1] = this.S_(1)[0] - sig[2] = this.R_(this.S(1), 1)[0] - sig[3] = this.L_(this.S(0), 1)[0] - sig[4] = this.L_(this.S(0), 2)[0] - sig[5] = this.S_(0)[0] - sig[6] = this.R_(this.S(0), 2)[0] - sig[7] = this.R_(this.S(0), 1)[0] - sig[8] = this.B_(0)[0] - sig[9] = this.E_(0)[0] - sig[10] = this.E_(1)[0] - return hash64(sig, sizeof(sig), this._s_i) \ - + hash64(&this._hist, sizeof(RingBufferC), 1) - - void push_hist(int act) nogil: - ring_push(&this._hist, act+1) - - int get_hist(int i) nogil: - return ring_get(&this._hist, i) + return this.length - this._b_i void push() nogil: - if this.B(0) != -1: - this._stack[this._s_i] = this.B(0) - this._s_i += 1 - this._b_i += 1 - if this.safe_get(this.B_(0).l_edge).sent_start == 1: - this.set_break(this.B_(0).l_edge) - if this._b_i > this._break: - this._break = -1 + b0 = this.B(0) + if this._rebuffer.size(): + b0 = this._rebuffer.back() + this._rebuffer.pop_back() + else: + b0 = this._b_i + this._b_i += 1 + this._stack.push_back(b0) void pop() nogil: - if this._s_i >= 1: - this._s_i -= 1 + this._stack.pop_back() void force_final() nogil: # This should only be used in desperate situations, as it may leave # the analysis in an unexpected state. - this._s_i = 0 + this._stack.clear() this._b_i = this.length void unshift() nogil: - this._b_i -= 1 - this._buffer[this._b_i] = this.S(0) - this._s_i -= 1 - this.shifted[this.B(0)] = True + s0 = this._stack.back() + this._unshiftable[s0] = 1 + this._rebuffer.push_back(s0) + this._stack.pop_back() + + int is_unshiftable(int item) nogil const: + if item >= this._unshiftable.size(): + return 0 + else: + return this._unshiftable.at(item) + + void set_reshiftable(int item) nogil: + if item < this._unshiftable.size(): + this._unshiftable[item] = 0 void add_arc(int head, int child, attr_t label) nogil: if this.has_head(child): this.del_arc(this.H(child), child) - - cdef int dist = head - child - this._sent[child].head = dist - this._sent[child].dep = label - cdef int i - if child > head: - this._sent[head].r_kids += 1 - # Some transition systems can have a word in the buffer have a - # rightward child, e.g. from Unshift. - this._sent[head].r_edge = this._sent[child].r_edge - i = 0 - while this.has_head(head) and i < this.length: - head = this.H(head) - this._sent[head].r_edge = this._sent[child].r_edge - i += 1 # Guard against infinite loops + cdef ArcC arc + arc.head = head + arc.child = child + arc.label = label + if head > child: + this._left_arcs.push_back(arc) else: - this._sent[head].l_kids += 1 - this._sent[head].l_edge = this._sent[child].l_edge + this._right_arcs.push_back(arc) + this._heads[child] = head void del_arc(int h_i, int c_i) nogil: - cdef int dist = h_i - c_i - cdef TokenC* h = &this._sent[h_i] - cdef int i = 0 - if c_i > h_i: - # this.R_(h_i, 2) returns the second-rightmost child token of h_i - # If we have more than 2 rightmost children, our 2nd rightmost child's - # rightmost edge is going to be our new rightmost edge. - h.r_edge = this.R_(h_i, 2).r_edge if h.r_kids >= 2 else h_i - h.r_kids -= 1 - new_edge = h.r_edge - # Correct upwards in the tree --- see Issue #251 - while h.head < 0 and i < this.length: # Guard infinite loop - h += h.head - h.r_edge = new_edge - i += 1 + cdef vector[ArcC]* arcs + if h_i > c_i: + arcs = &this._left_arcs else: - # Same logic applies for left edge, but we don't need to walk up - # the tree, as the head is off the stack. - h.l_edge = this.L_(h_i, 2).l_edge if h.l_kids >= 2 else h_i - h.l_kids -= 1 + arcs = &this._right_arcs + if arcs.size() == 0: + return + arc = arcs.back() + if arc.head == h_i and arc.child == c_i: + arcs.pop_back() + else: + for i in range(arcs.size()-1): + arc = arcs.at(i) + if arc.head == h_i and arc.child == c_i: + arc.head = -1 + arc.child = -1 + arc.label = 0 + break + + SpanC get_ent() nogil const: + cdef SpanC ent + if this._ents.size() == 0: + ent.start = 0 + ent.end = 0 + ent.label = 0 + return ent + else: + return this._ents.back() void open_ent(attr_t label) nogil: - this._ents[this._e_i].start = this.B(0) - this._ents[this._e_i].label = label - this._ents[this._e_i].end = -1 - this._e_i += 1 + cdef SpanC ent + ent.start = this.B(0) + ent.label = label + ent.end = -1 + this._ents.push_back(ent) void close_ent() nogil: - # Note that we don't decrement _e_i here! We want to maintain all - # entities, not over-write them... - this._ents[this._e_i-1].end = this.B(0)+1 - this._sent[this.B(0)].ent_iob = 1 - - void set_ent_tag(int i, int ent_iob, attr_t ent_type) nogil: - if 0 <= i < this.length: - this._sent[i].ent_iob = ent_iob - this._sent[i].ent_type = ent_type - - void set_break(int i) nogil: - if 0 <= i < this.length: - this._sent[i].sent_start = 1 - this._break = this._b_i + this._ents.back().end = this.B(0)+1 void clone(const StateC* src) nogil: this.length = src.length - memcpy(this._sent, src._sent, this.length * sizeof(TokenC)) - memcpy(this._stack, src._stack, this.length * sizeof(int)) - memcpy(this._buffer, src._buffer, this.length * sizeof(int)) - memcpy(this._ents, src._ents, this.length * sizeof(SpanC)) - memcpy(this.shifted, src.shifted, this.length * sizeof(this.shifted[0])) + this._sent = src._sent + this._stack = src._stack + this._rebuffer = src._rebuffer + this._sent_starts = src._sent_starts + this._unshiftable = src._unshiftable + memcpy(this._heads, src._heads, this.length * sizeof(this._heads[0])) + this._ents = src._ents + this._left_arcs = src._left_arcs + this._right_arcs = src._right_arcs this._b_i = src._b_i - this._s_i = src._s_i - this._e_i = src._e_i - this._break = src._break this.offset = src.offset this._empty_token = src._empty_token - - void fast_forward() nogil: - # space token attachement policy: - # - attach space tokens always to the last preceding real token - # - except if it's the beginning of a sentence, then attach to the first following - # - boundary case: a document containing multiple space tokens but nothing else, - # then make the last space token the head of all others - - while is_space_token(this.B_(0)) \ - or this.buffer_length() == 0 \ - or this.stack_depth() == 0: - if this.buffer_length() == 0: - # remove the last sentence's root from the stack - if this.stack_depth() == 1: - this.pop() - # parser got stuck: reduce stack or unshift - elif this.stack_depth() > 1: - if this.has_head(this.S(0)): - this.pop() - else: - this.unshift() - # stack is empty but there is another sentence on the buffer - elif (this.length - this._b_i) >= 1: - this.push() - else: # stack empty and nothing else coming - break - - elif is_space_token(this.B_(0)): - # the normal case: we're somewhere inside a sentence - if this.stack_depth() > 0: - # assert not is_space_token(this.S_(0)) - # attach all coming space tokens to their last preceding - # real token (which should be on the top of the stack) - while is_space_token(this.B_(0)): - this.add_arc(this.S(0),this.B(0),0) - this.push() - this.pop() - # the rare case: we're at the beginning of a document: - # space tokens are attached to the first real token on the buffer - elif this.stack_depth() == 0: - # store all space tokens on the stack until a real token shows up - # or the last token on the buffer is reached - while is_space_token(this.B_(0)) and this.buffer_length() > 1: - this.push() - # empty the stack by attaching all space tokens to the - # first token on the buffer - # boundary case: if all tokens are space tokens, the last one - # becomes the head of all others - while this.stack_depth() > 0: - this.add_arc(this.B(0),this.S(0),0) - this.pop() - # move the first token onto the stack - this.push() - - elif this.stack_depth() == 0: - # for one token sentences (?) - if this.buffer_length() == 1: - this.push() - this.pop() - # with an empty stack and a non-empty buffer - # only shift is valid anyway - elif (this.length - this._b_i) >= 1: - this.push() - - else: # can this even happen? - break diff --git a/spacy/pipeline/_parser_internals/arc_eager.pxd b/spacy/pipeline/_parser_internals/arc_eager.pxd index e05a34f56..3732dd1b7 100644 --- a/spacy/pipeline/_parser_internals/arc_eager.pxd +++ b/spacy/pipeline/_parser_internals/arc_eager.pxd @@ -1,11 +1,7 @@ -from .stateclass cimport StateClass +from ._state cimport StateC from ...typedefs cimport weight_t, attr_t from .transition_system cimport Transition, TransitionSystem cdef class ArcEager(TransitionSystem): pass - - -cdef weight_t push_cost(StateClass stcls, const void* _gold, int target) nogil -cdef weight_t arc_cost(StateClass stcls, const void* _gold, int head, int child) nogil diff --git a/spacy/pipeline/_parser_internals/arc_eager.pyx b/spacy/pipeline/_parser_internals/arc_eager.pyx index 69f015bda..cddb6cbd9 100644 --- a/spacy/pipeline/_parser_internals/arc_eager.pyx +++ b/spacy/pipeline/_parser_internals/arc_eager.pyx @@ -14,16 +14,11 @@ from ._state cimport StateC from ...errors import Errors -# Calculate cost as gold/not gold. We don't use scalar value anyway. -cdef int BINARY_COSTS = 1 cdef weight_t MIN_SCORE = -90000 cdef attr_t SUBTOK_LABEL = hash_string(u'subtok') DEF NON_MONOTONIC = True -DEF USE_BREAK = True -# Break transition from here -# http://www.aclweb.org/anthology/P13-1074 cdef enum: SHIFT REDUCE @@ -61,9 +56,11 @@ cdef struct GoldParseStateC: int32_t* n_kids int32_t length int32_t stride + weight_t push_cost + weight_t pop_cost -cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls, +cdef GoldParseStateC create_gold_state(Pool mem, const StateC* state, heads, labels, sent_starts) except *: cdef GoldParseStateC gs gs.length = len(heads) @@ -142,10 +139,12 @@ cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls, if head != i: gs.kids[head][js[head]] = i js[head] += 1 + gs.push_cost = push_cost(state, &gs) + gs.pop_cost = pop_cost(state, &gs) return gs -cdef void update_gold_state(GoldParseStateC* gs, StateClass stcls) nogil: +cdef void update_gold_state(GoldParseStateC* gs, const StateC* s) nogil: for i in range(gs.length): gs.state_bits[i] = set_state_flag( gs.state_bits[i], @@ -160,9 +159,9 @@ cdef void update_gold_state(GoldParseStateC* gs, StateClass stcls) nogil: gs.n_kids_in_stack[i] = 0 gs.n_kids_in_buffer[i] = 0 - for i in range(stcls.stack_depth()): - s_i = stcls.S(i) - if not is_head_unknown(gs, s_i): + for i in range(s.stack_depth()): + s_i = s.S(i) + if not is_head_unknown(gs, s_i) and gs.heads[s_i] != s_i: gs.n_kids_in_stack[gs.heads[s_i]] += 1 for kid in gs.kids[s_i][:gs.n_kids[s_i]]: gs.state_bits[kid] = set_state_flag( @@ -170,9 +169,11 @@ cdef void update_gold_state(GoldParseStateC* gs, StateClass stcls) nogil: HEAD_IN_STACK, 1 ) - for i in range(stcls.buffer_length()): - b_i = stcls.B(i) - if not is_head_unknown(gs, b_i): + for i in range(s.buffer_length()): + b_i = s.B(i) + if s.is_sent_start(b_i): + break + if not is_head_unknown(gs, b_i) and gs.heads[b_i] != b_i: gs.n_kids_in_buffer[gs.heads[b_i]] += 1 for kid in gs.kids[b_i][:gs.n_kids[b_i]]: gs.state_bits[kid] = set_state_flag( @@ -180,6 +181,8 @@ cdef void update_gold_state(GoldParseStateC* gs, StateClass stcls) nogil: HEAD_IN_BUFFER, 1 ) + gs.push_cost = push_cost(s, gs) + gs.pop_cost = pop_cost(s, gs) cdef class ArcEagerGold: @@ -191,17 +194,17 @@ cdef class ArcEagerGold: heads, labels = example.get_aligned_parse(projectivize=True) labels = [label if label is not None else "" for label in labels] labels = [example.x.vocab.strings.add(label) for label in labels] - sent_starts = example.get_aligned("SENT_START") - assert len(heads) == len(labels) == len(sent_starts) - self.c = create_gold_state(self.mem, stcls, heads, labels, sent_starts) + sent_starts = example.get_aligned_sent_starts() + assert len(heads) == len(labels) == len(sent_starts), (len(heads), len(labels), len(sent_starts)) + self.c = create_gold_state(self.mem, stcls.c, heads, labels, sent_starts) def update(self, StateClass stcls): - update_gold_state(&self.c, stcls) + update_gold_state(&self.c, stcls.c) cdef int check_state_gold(char state_bits, char flag) nogil: cdef char one = 1 - return state_bits & (one << flag) + return 1 if (state_bits & (one << flag)) else 0 cdef int set_state_flag(char state_bits, char flag, int value) nogil: @@ -232,41 +235,30 @@ cdef int is_sent_start_unknown(const GoldParseStateC* gold, int i) nogil: # Helper functions for the arc-eager oracle -cdef weight_t push_cost(StateClass stcls, const void* _gold, int target) nogil: - gold = _gold +cdef weight_t push_cost(const StateC* state, const GoldParseStateC* gold) nogil: cdef weight_t cost = 0 - if is_head_in_stack(gold, target): + b0 = state.B(0) + if b0 < 0: + return 9000 + if is_head_in_stack(gold, b0): cost += 1 - cost += gold.n_kids_in_stack[target] - if Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0: + cost += gold.n_kids_in_stack[b0] + if Break.is_valid(state, 0) and is_sent_start(gold, state.B(1)): cost += 1 return cost -cdef weight_t pop_cost(StateClass stcls, const void* _gold, int target) nogil: - gold = _gold +cdef weight_t pop_cost(const StateC* state, const GoldParseStateC* gold) nogil: cdef weight_t cost = 0 - if is_head_in_buffer(gold, target): - cost += 1 - cost += gold[0].n_kids_in_buffer[target] - if Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0: + s0 = state.S(0) + if s0 < 0: + return 9000 + if is_head_in_buffer(gold, s0): cost += 1 + cost += gold.n_kids_in_buffer[s0] return cost -cdef weight_t arc_cost(StateClass stcls, const void* _gold, int head, int child) nogil: - gold = _gold - if arc_is_gold(gold, head, child): - return 0 - elif stcls.H(child) == gold.heads[child]: - return 1 - # Head in buffer - elif is_head_in_buffer(gold, child): - return 1 - else: - return 0 - - cdef bint arc_is_gold(const GoldParseStateC* gold, int head, int child) nogil: if is_head_unknown(gold, child): return True @@ -276,7 +268,7 @@ cdef bint arc_is_gold(const GoldParseStateC* gold, int head, int child) nogil: return False -cdef bint label_is_gold(const GoldParseStateC* gold, int head, int child, attr_t label) nogil: +cdef bint label_is_gold(const GoldParseStateC* gold, int child, attr_t label) nogil: if is_head_unknown(gold, child): return True elif label == 0: @@ -292,218 +284,251 @@ cdef bint _is_gold_root(const GoldParseStateC* gold, int word) nogil: cdef class Shift: + """Move the first word of the buffer onto the stack and mark it as "shifted" + + Validity: + * If stack is empty + * At least two words in sentence + * Word has not been shifted before + + Cost: push_cost + + Action: + * Mark B[0] as 'shifted' + * Push stack + * Advance buffer + """ @staticmethod cdef bint is_valid(const StateC* st, attr_t label) nogil: - sent_start = st._sent[st.B_(0).l_edge].sent_start - return st.buffer_length() >= 2 and not st.shifted[st.B(0)] and sent_start != 1 + if st.stack_depth() == 0: + return 1 + elif st.buffer_length() < 2: + return 0 + elif st.is_sent_start(st.B(0)): + return 0 + elif st.is_unshiftable(st.B(0)): + return 0 + else: + return 1 @staticmethod cdef int transition(StateC* st, attr_t label) nogil: st.push() - st.fast_forward() @staticmethod - cdef weight_t cost(StateClass st, const void* _gold, attr_t label) nogil: + cdef weight_t cost(const StateC* state, const void* _gold, attr_t label) nogil: gold = _gold - return Shift.move_cost(st, gold) + Shift.label_cost(st, gold, label) - - @staticmethod - cdef inline weight_t move_cost(StateClass s, const void* _gold) nogil: - gold = _gold - return push_cost(s, gold, s.B(0)) - - @staticmethod - cdef inline weight_t label_cost(StateClass s, const void* _gold, attr_t label) nogil: - return 0 + return gold.push_cost cdef class Reduce: + """ + Pop from the stack. If it has no head and the stack isn't empty, place + it back on the buffer. + + Validity: + * Stack not empty + * Buffer nt empty + * Stack depth 1 and cannot sent start l_edge(st.B(0)) + + Cost: + * If B[0] is the start of a sentence, cost is 0 + * Arcs between stack and buffer + * If arc has no head, we're saving arcs between S[0] and S[1:], so decrement + cost by those arcs. + """ @staticmethod cdef bint is_valid(const StateC* st, attr_t label) nogil: - return st.stack_depth() >= 2 - - @staticmethod - cdef int transition(StateC* st, attr_t label) nogil: - if st.has_head(st.S(0)): - st.pop() - else: - st.unshift() - st.fast_forward() - - @staticmethod - cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil: - gold = _gold - return Reduce.move_cost(s, gold) + Reduce.label_cost(s, gold, label) - - @staticmethod - cdef inline weight_t move_cost(StateClass st, const void* _gold) nogil: - gold = _gold - s0 = st.S(0) - cost = pop_cost(st, gold, s0) - return_to_buffer = not st.has_head(s0) - if return_to_buffer: - # Decrement cost for the arcs we save, as we'll be putting this - # back to the buffer - if is_head_in_stack(gold, s0): - cost -= 1 - cost -= gold.n_kids_in_stack[s0] - if Break.is_valid(st.c, 0) and Break.move_cost(st, gold) == 0: - cost -= 1 - return cost - - @staticmethod - cdef inline weight_t label_cost(StateClass s, const void* gold, attr_t label) nogil: - return 0 - - -cdef class LeftArc: - @staticmethod - cdef bint is_valid(const StateC* st, attr_t label) nogil: - if label == SUBTOK_LABEL and st.S(0) != (st.B(0)-1): - return 0 - sent_start = st._sent[st.B_(0).l_edge].sent_start - return sent_start != 1 - - @staticmethod - cdef int transition(StateC* st, attr_t label) nogil: - st.add_arc(st.B(0), st.S(0), label) - st.pop() - st.fast_forward() - - @staticmethod - cdef inline weight_t cost(StateClass s, const void* _gold, attr_t label) nogil: - gold = _gold - return LeftArc.move_cost(s, gold) + LeftArc.label_cost(s, gold, label) - - @staticmethod - cdef inline weight_t move_cost(StateClass s, const GoldParseStateC* gold) nogil: - cdef weight_t cost = 0 - s0 = s.S(0) - b0 = s.B(0) - if arc_is_gold(gold, b0, s0): - # Have a negative cost if we 'recover' from the wrong dependency - return 0 if not s.has_head(s0) else -1 - else: - # Account for deps we might lose between S0 and stack - if not s.has_head(s0): - cost += gold.n_kids_in_stack[s0] - if is_head_in_buffer(gold, s0): - cost += 1 - return cost + pop_cost(s, gold, s.S(0)) + arc_cost(s, gold, s.B(0), s.S(0)) - - @staticmethod - cdef inline weight_t label_cost(StateClass s, const GoldParseStateC* gold, attr_t label) nogil: - return arc_is_gold(gold, s.B(0), s.S(0)) and not label_is_gold(gold, s.B(0), s.S(0), label) - - -cdef class RightArc: - @staticmethod - cdef bint is_valid(const StateC* st, attr_t label) nogil: - # If there's (perhaps partial) parse pre-set, don't allow cycle. - if label == SUBTOK_LABEL and st.S(0) != (st.B(0)-1): - return 0 - sent_start = st._sent[st.B_(0).l_edge].sent_start - return sent_start != 1 and st.H(st.S(0)) != st.B(0) - - @staticmethod - cdef int transition(StateC* st, attr_t label) nogil: - st.add_arc(st.S(0), st.B(0), label) - st.push() - st.fast_forward() - - @staticmethod - cdef inline weight_t cost(StateClass s, const void* _gold, attr_t label) nogil: - gold = _gold - return RightArc.move_cost(s, gold) + RightArc.label_cost(s, gold, label) - - @staticmethod - cdef inline weight_t move_cost(StateClass s, const void* _gold) nogil: - gold = _gold - if arc_is_gold(gold, s.S(0), s.B(0)): - return 0 - elif s.c.shifted[s.B(0)]: - return push_cost(s, gold, s.B(0)) - else: - return push_cost(s, gold, s.B(0)) + arc_cost(s, gold, s.S(0), s.B(0)) - - @staticmethod - cdef weight_t label_cost(StateClass s, const void* _gold, attr_t label) nogil: - gold = _gold - return arc_is_gold(gold, s.S(0), s.B(0)) and not label_is_gold(gold, s.S(0), s.B(0), label) - - -cdef class Break: - @staticmethod - cdef bint is_valid(const StateC* st, attr_t label) nogil: - cdef int i - if not USE_BREAK: + if st.stack_depth() == 0: return False - elif st.at_break(): - return False - elif st.stack_depth() < 1: - return False - elif st.B_(0).l_edge < 0: - return False - elif st._sent[st.B_(0).l_edge].sent_start < 0: + elif st.buffer_length() == 0: + return True + elif st.stack_depth() == 1 and st.cannot_sent_start(st.l_edge(st.B(0))): return False else: return True @staticmethod cdef int transition(StateC* st, attr_t label) nogil: - st.set_break(st.B_(0).l_edge) - st.fast_forward() - - @staticmethod - cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil: - gold = _gold - return Break.move_cost(s, gold) + Break.label_cost(s, gold, label) - - @staticmethod - cdef inline weight_t move_cost(StateClass s, const void* _gold) nogil: - gold = _gold - cost = 0 - for i in range(s.stack_depth()): - S_i = s.S(i) - cost += gold.n_kids_in_buffer[S_i] - if is_head_in_buffer(gold, S_i): - cost += 1 - # It's weird not to check the gold sentence boundaries but if we do, - # we can't account for "sunk costs", i.e. situations where we're already - # wrong. - s0_root = _get_root(s.S(0), gold) - b0_root = _get_root(s.B(0), gold) - if s0_root != b0_root or s0_root == -1 or b0_root == -1: - return cost + if st.has_head(st.S(0)) or st.stack_depth() == 1: + st.pop() else: - return cost + 1 + st.unshift() @staticmethod - cdef inline weight_t label_cost(StateClass s, const void* gold, attr_t label) nogil: - return 0 + cdef weight_t cost(const StateC* state, const void* _gold, attr_t label) nogil: + gold = _gold + if state.is_sent_start(state.B(0)): + return 0 + s0 = state.S(0) + cost = gold.pop_cost + if not state.has_head(s0): + # Decrement cost for the arcs we save, as we'll be putting this + # back to the buffer + if is_head_in_stack(gold, s0): + cost -= 1 + cost -= gold.n_kids_in_stack[s0] + return cost -cdef int _get_root(int word, const GoldParseStateC* gold) nogil: - if is_head_unknown(gold, word): - return -1 - while gold.heads[word] != word and word >= 0: - word = gold.heads[word] - if is_head_unknown(gold, word): - return -1 - else: - return word + +cdef class LeftArc: + """Add an arc between B[0] and S[0], replacing the previous head of S[0] if + one is set. Pop S[0] from the stack. + + Validity: + * len(S) >= 1 + * len(B) >= 1 + * not is_sent_start(B[0]) + + Cost: + pop_cost - Arc(B[0], S[0], label) + (Arc(S[1], S[0]) if H(S[0]) else Arcs(S, S[0])) + """ + @staticmethod + cdef bint is_valid(const StateC* st, attr_t label) nogil: + if st.stack_depth() == 0: + return 0 + elif st.buffer_length() == 0: + return 0 + elif st.is_sent_start(st.B(0)): + return 0 + elif label == SUBTOK_LABEL and st.S(0) != (st.B(0)-1): + return 0 + else: + return 1 + + @staticmethod + cdef int transition(StateC* st, attr_t label) nogil: + st.add_arc(st.B(0), st.S(0), label) + # If we change the stack, it's okay to remove the shifted mark, as + # we can't get in an infinite loop this way. + st.set_reshiftable(st.B(0)) + st.pop() + + @staticmethod + cdef inline weight_t cost(const StateC* state, const void* _gold, attr_t label) nogil: + gold = _gold + cdef weight_t cost = gold.pop_cost + s0 = state.S(0) + s1 = state.S(1) + b0 = state.B(0) + if state.has_head(s0): + # Increment cost if we're clobbering a correct arc + cost += gold.heads[s0] == s1 + else: + # If there's no head, we're losing arcs between S0 and S[1:]. + cost += is_head_in_stack(gold, s0) + cost += gold.n_kids_in_stack[s0] + if b0 != -1 and s0 != -1 and gold.heads[s0] == b0: + cost -= 1 + cost += not label_is_gold(gold, s0, label) + return cost + + +cdef class RightArc: + """ + Add an arc from S[0] to B[0]. Push B[0]. + + Validity: + * len(S) >= 1 + * len(B) >= 1 + * not is_sent_start(B[0]) + + Cost: + push_cost + (not shifted[b0] and Arc(B[1:], B[0])) - Arc(S[0], B[0], label) + """ + @staticmethod + cdef bint is_valid(const StateC* st, attr_t label) nogil: + if st.stack_depth() == 0: + return 0 + elif st.buffer_length() == 0: + return 0 + elif st.is_sent_start(st.B(0)): + return 0 + elif label == SUBTOK_LABEL and st.S(0) != (st.B(0)-1): + # If there's (perhaps partial) parse pre-set, don't allow cycle. + return 0 + else: + return 1 + + @staticmethod + cdef int transition(StateC* st, attr_t label) nogil: + st.add_arc(st.S(0), st.B(0), label) + st.push() + + @staticmethod + cdef inline weight_t cost(const StateC* state, const void* _gold, attr_t label) nogil: + gold = _gold + cost = gold.push_cost + s0 = state.S(0) + b0 = state.B(0) + if s0 != -1 and b0 != -1 and gold.heads[b0] == s0: + cost -= 1 + cost += not label_is_gold(gold, b0, label) + elif is_head_in_buffer(gold, b0) and not state.is_unshiftable(b0): + cost += 1 + return cost + + +cdef class Break: + """Mark the second word of the buffer as the start of a + sentence. + + Validity: + * len(buffer) >= 2 + * B[1] == B[0] + 1 + * not is_sent_start(B[1]) + * not cannot_sent_start(B[1]) + + Action: + * mark_sent_start(B[1]) + + Cost: + * not is_sent_start(B[1]) + * Arcs between B[0] and B[1:] + * Arcs between S and B[1] + """ + @staticmethod + cdef bint is_valid(const StateC* st, attr_t label) nogil: + cdef int i + if st.buffer_length() < 2: + return False + elif st.B(1) != st.B(0) + 1: + return False + elif st.is_sent_start(st.B(1)): + return False + elif st.cannot_sent_start(st.B(1)): + return False + else: + return True + + @staticmethod + cdef int transition(StateC* st, attr_t label) nogil: + st.set_sent_start(st.B(1), 1) + + @staticmethod + cdef weight_t cost(const StateC* state, const void* _gold, attr_t label) nogil: + gold = _gold + cdef int b0 = state.B(0) + cdef int cost = 0 + cdef int si + for i in range(state.stack_depth()): + si = state.S(i) + if is_head_in_buffer(gold, si): + cost += 1 + cost += gold.n_kids_in_buffer[si] + # We need to score into B[1:], so subtract deps that are at b0 + if gold.heads[b0] == si: + cost -= 1 + if gold.heads[si] == b0: + cost -= 1 + if not is_sent_start(gold, state.B(1)) \ + and not is_sent_start_unknown(gold, state.B(1)): + cost += 1 + return cost cdef void* _init_state(Pool mem, int length, void* tokens) except NULL: st = new StateC(tokens, length) - for i in range(st.length): - if st._sent[i].dep == 0: - st._sent[i].l_edge = i - st._sent[i].r_edge = i - st._sent[i].head = 0 - st._sent[i].dep = 0 - st._sent[i].l_kids = 0 - st._sent[i].r_kids = 0 - st.fast_forward() return st @@ -515,6 +540,8 @@ cdef int _del_state(Pool mem, void* state, void* x) except -1: cdef class ArcEager(TransitionSystem): def __init__(self, *args, **kwargs): TransitionSystem.__init__(self, *args, **kwargs) + self.init_beam_state = _init_state + self.del_beam_state = _del_state @classmethod def get_actions(cls, **kwargs): @@ -537,7 +564,7 @@ cdef class ArcEager(TransitionSystem): label = 'ROOT' if head == child: actions[BREAK][label] += 1 - elif head < child: + if head < child: actions[RIGHT][label] += 1 actions[REDUCE][''] += 1 elif head > child: @@ -567,8 +594,14 @@ cdef class ArcEager(TransitionSystem): t.do(state.c, t.label) return state - def is_gold_parse(self, StateClass state, gold): - raise NotImplementedError + def is_gold_parse(self, StateClass state, ArcEagerGold gold): + for i in range(state.c.length): + token = state.c.safe_get(i) + if not arc_is_gold(&gold.c, i, i+token.head): + return False + elif not label_is_gold(&gold.c, i, token.dep): + return False + return True def init_gold(self, StateClass state, Example example): gold = ArcEagerGold(self, state, example) @@ -576,6 +609,7 @@ cdef class ArcEager(TransitionSystem): return gold def init_gold_batch(self, examples): + # TODO: Projectivitity? all_states = self.init_batch([eg.predicted for eg in examples]) golds = [] states = [] @@ -662,24 +696,13 @@ cdef class ArcEager(TransitionSystem): raise ValueError(Errors.E019.format(action=move, src='arc_eager')) return t - cdef int initialize_state(self, StateC* st) nogil: - for i in range(st.length): - if st._sent[i].dep == 0: - st._sent[i].l_edge = i - st._sent[i].r_edge = i - st._sent[i].head = 0 - st._sent[i].dep = 0 - st._sent[i].l_kids = 0 - st._sent[i].r_kids = 0 - st.fast_forward() - - cdef int finalize_state(self, StateC* st) nogil: - cdef int i - for i in range(st.length): - if st._sent[i].head == 0: - st._sent[i].dep = self.root_label - - def finalize_doc(self, Doc doc): + def set_annotations(self, StateClass state, Doc doc): + for arc in state.arcs: + doc.c[arc["child"]].head = arc["head"] - arc["child"] + doc.c[arc["child"]].dep = arc["label"] + for i in range(doc.length): + if doc.c[i].head == 0: + doc.c[i].dep = self.root_label set_children_from_heads(doc.c, 0, doc.length) def has_gold(self, Example eg, start=0, end=None): @@ -690,7 +713,7 @@ cdef class ArcEager(TransitionSystem): return False cdef int set_valid(self, int* output, const StateC* st) nogil: - cdef bint[N_MOVES] is_valid + cdef int[N_MOVES] is_valid is_valid[SHIFT] = Shift.is_valid(st, 0) is_valid[REDUCE] = Reduce.is_valid(st, 0) is_valid[LEFT] = LeftArc.is_valid(st, 0) @@ -710,29 +733,31 @@ cdef class ArcEager(TransitionSystem): gold_state = gold_.c n_gold = 0 if self.c[i].is_valid(stcls.c, self.c[i].label): - cost = self.c[i].get_cost(stcls, &gold_state, self.c[i].label) + cost = self.c[i].get_cost(stcls.c, &gold_state, self.c[i].label) else: cost = 9000 return cost cdef int set_costs(self, int* is_valid, weight_t* costs, - StateClass stcls, gold) except -1: + const StateC* state, gold) except -1: if not isinstance(gold, ArcEagerGold): raise TypeError(Errors.E909.format(name="ArcEagerGold")) cdef ArcEagerGold gold_ = gold - gold_.update(stcls) gold_state = gold_.c + update_gold_state(&gold_state, state) + self.set_valid(is_valid, state) cdef int n_gold = 0 for i in range(self.n_moves): - if self.c[i].is_valid(stcls.c, self.c[i].label): - is_valid[i] = True - costs[i] = self.c[i].get_cost(stcls, &gold_state, self.c[i].label) + if is_valid[i]: + costs[i] = self.c[i].get_cost(state, &gold_state, self.c[i].label) if costs[i] <= 0: n_gold += 1 else: - is_valid[i] = False costs[i] = 9000 if n_gold < 1: + for i in range(self.n_moves): + print(self.get_class_name(i), is_valid[i], costs[i]) + print("Gold sent starts?", is_sent_start(&gold_state, state.B(0)), is_sent_start(&gold_state, state.B(1))) raise ValueError def get_oracle_sequence_from_state(self, StateClass state, ArcEagerGold gold, _debug=None): @@ -748,12 +773,13 @@ cdef class ArcEager(TransitionSystem): failed = False while not state.is_final(): try: - self.set_costs(is_valid, costs, state, gold) + self.set_costs(is_valid, costs, state.c, gold) except ValueError: failed = True break + min_cost = min(costs[i] for i in range(self.n_moves)) for i in range(self.n_moves): - if is_valid[i] and costs[i] <= 0: + if is_valid[i] and costs[i] <= min_cost: action = self.c[i] history.append(i) s0 = state.S(0) @@ -762,9 +788,7 @@ cdef class ArcEager(TransitionSystem): example = _debug debug_log.append(" ".join(( self.get_class_name(i), - "S0=", (example.x[s0].text if s0 >= 0 else "__"), - "B0=", (example.x[b0].text if b0 >= 0 else "__"), - "S0 head?", str(state.has_head(state.S(0))), + state.print_state() ))) action.do(state.c, action.label) break @@ -783,6 +807,8 @@ cdef class ArcEager(TransitionSystem): print("Aligned heads") for i, head in enumerate(aligned_heads): print(example.x[i], example.x[head] if head is not None else "__") + print("Aligned sent starts") + print(example.get_aligned_sent_starts()) print("Predicted tokens") print([(w.i, w.text) for w in example.x]) diff --git a/spacy/pipeline/_parser_internals/ner.pyx b/spacy/pipeline/_parser_internals/ner.pyx index 4f142caaf..7f4d332db 100644 --- a/spacy/pipeline/_parser_internals/ner.pyx +++ b/spacy/pipeline/_parser_internals/ner.pyx @@ -3,9 +3,12 @@ from cymem.cymem cimport Pool from collections import Counter +from ...tokens.doc cimport Doc +from ...tokens.span import Span from ...typedefs cimport weight_t, attr_t from ...lexeme cimport Lexeme from ...attrs cimport IS_SPACE +from ...structs cimport TokenC from ...training.example cimport Example from .stateclass cimport StateClass from ._state cimport StateC @@ -46,17 +49,17 @@ cdef class BiluoGold: def __init__(self, BiluoPushDown moves, StateClass stcls, Example example): self.mem = Pool() - self.c = create_gold_state(self.mem, moves, stcls, example) + self.c = create_gold_state(self.mem, moves, stcls.c, example) def update(self, StateClass stcls): - update_gold_state(&self.c, stcls) + update_gold_state(&self.c, stcls.c) cdef GoldNERStateC create_gold_state( Pool mem, BiluoPushDown moves, - StateClass stcls, + const StateC* stcls, Example example ) except *: cdef GoldNERStateC gs @@ -67,7 +70,7 @@ cdef GoldNERStateC create_gold_state( return gs -cdef void update_gold_state(GoldNERStateC* gs, StateClass stcls) except *: +cdef void update_gold_state(GoldNERStateC* gs, const StateC* state) except *: # We don't need to update each time, unlike the parser. pass @@ -75,14 +78,15 @@ cdef void update_gold_state(GoldNERStateC* gs, StateClass stcls) except *: cdef do_func_t[N_MOVES] do_funcs -cdef bint _entity_is_sunk(StateClass st, Transition* golds) nogil: - if not st.entity_is_open(): +cdef bint _entity_is_sunk(const StateC* state, Transition* golds) nogil: + if not state.entity_is_open(): return False - cdef const Transition* gold = &golds[st.E(0)] + cdef const Transition* gold = &golds[state.E(0)] + ent = state.get_ent() if gold.move != BEGIN and gold.move != UNIT: return True - elif gold.label != st.E_(0).ent_type: + elif gold.label != ent.label: return True else: return False @@ -228,15 +232,18 @@ cdef class BiluoPushDown(TransitionSystem): self.labels[action][label_name] = -1 return 1 - cdef int initialize_state(self, StateC* st) nogil: - # This is especially necessary when we use limited training data. - for i in range(st.length): - if st._sent[i].ent_type != 0: - with gil: - self.add_action(BEGIN, st._sent[i].ent_type) - self.add_action(IN, st._sent[i].ent_type) - self.add_action(UNIT, st._sent[i].ent_type) - self.add_action(LAST, st._sent[i].ent_type) + def set_annotations(self, StateClass state, Doc doc): + cdef int i + ents = [] + for i in range(state.c._ents.size()): + ent = state.c._ents.at(i) + if ent.start != -1 and ent.end != -1: + ents.append(Span(doc, ent.start, ent.end, label=ent.label)) + doc.set_ents(ents, default="unmodified") + # Set non-blocked tokens to O + for i in range(doc.length): + if doc.c[i].ent_iob == 0: + doc.c[i].ent_iob = 2 def init_gold(self, StateClass state, Example example): return BiluoGold(self, state, example) @@ -255,26 +262,25 @@ cdef class BiluoPushDown(TransitionSystem): gold_state = gold_.c n_gold = 0 if self.c[i].is_valid(stcls.c, self.c[i].label): - cost = self.c[i].get_cost(stcls, &gold_state, self.c[i].label) + cost = self.c[i].get_cost(stcls.c, &gold_state, self.c[i].label) else: cost = 9000 return cost cdef int set_costs(self, int* is_valid, weight_t* costs, - StateClass stcls, gold) except -1: + const StateC* state, gold) except -1: if not isinstance(gold, BiluoGold): raise TypeError(Errors.E909.format(name="BiluoGold")) cdef BiluoGold gold_ = gold - gold_.update(stcls) gold_state = gold_.c + update_gold_state(&gold_state, state) n_gold = 0 + self.set_valid(is_valid, state) for i in range(self.n_moves): - if self.c[i].is_valid(stcls.c, self.c[i].label): - is_valid[i] = 1 - costs[i] = self.c[i].get_cost(stcls, &gold_state, self.c[i].label) + if is_valid[i]: + costs[i] = self.c[i].get_cost(state, &gold_state, self.c[i].label) n_gold += costs[i] <= 0 else: - is_valid[i] = 0 costs[i] = 9000 if n_gold < 1: raise ValueError @@ -290,7 +296,7 @@ cdef class Missing: pass @staticmethod - cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil: + cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil: return 9000 @@ -299,10 +305,10 @@ cdef class Begin: cdef bint is_valid(const StateC* st, attr_t label) nogil: cdef int preset_ent_iob = st.B_(0).ent_iob cdef attr_t preset_ent_label = st.B_(0).ent_type - # If we're the last token of the input, we can't B -- must U or O. - if st.B(1) == -1: + if st.entity_is_open(): return False - elif st.entity_is_open(): + if st.buffer_length() < 2: + # If we're the last token of the input, we can't B -- must U or O. return False elif label == 0: return False @@ -337,12 +343,11 @@ cdef class Begin: @staticmethod cdef int transition(StateC* st, attr_t label) nogil: st.open_ent(label) - st.set_ent_tag(st.B(0), 3, label) st.push() st.pop() @staticmethod - cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil: + cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil: gold = _gold cdef int g_act = gold.ner[s.B(0)].move cdef attr_t g_tag = gold.ner[s.B(0)].label @@ -366,16 +371,17 @@ cdef class Begin: cdef class In: @staticmethod cdef bint is_valid(const StateC* st, attr_t label) nogil: + if not st.entity_is_open(): + return False + if st.buffer_length() < 2: + # If we're at the end, we can't I. + return False + ent = st.get_ent() cdef int preset_ent_iob = st.B_(0).ent_iob cdef attr_t preset_ent_label = st.B_(0).ent_type if label == 0: return False - elif st.E_(0).ent_type != label: - return False - elif not st.entity_is_open(): - return False - elif st.B(1) == -1: - # If we're at the end, we can't I. + elif ent.label != label: return False elif preset_ent_iob == 3: return False @@ -401,12 +407,11 @@ cdef class In: @staticmethod cdef int transition(StateC* st, attr_t label) nogil: - st.set_ent_tag(st.B(0), 1, label) st.push() st.pop() @staticmethod - cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil: + cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil: gold = _gold move = IN cdef int next_act = gold.ner[s.B(1)].move if s.B(1) >= 0 else OUT @@ -457,7 +462,7 @@ cdef class Last: # Otherwise, force acceptance, even if we're across a sentence # boundary or the token is whitespace. return True - elif st.E_(0).ent_type != label: + elif st.get_ent().label != label: return False elif st.B_(1).ent_iob == 1: # If a preset entity has I next, we can't L here. @@ -468,12 +473,11 @@ cdef class Last: @staticmethod cdef int transition(StateC* st, attr_t label) nogil: st.close_ent() - st.set_ent_tag(st.B(0), 1, label) st.push() st.pop() @staticmethod - cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil: + cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil: gold = _gold move = LAST @@ -537,12 +541,11 @@ cdef class Unit: cdef int transition(StateC* st, attr_t label) nogil: st.open_ent(label) st.close_ent() - st.set_ent_tag(st.B(0), 3, label) st.push() st.pop() @staticmethod - cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil: + cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil: gold = _gold cdef int g_act = gold.ner[s.B(0)].move cdef attr_t g_tag = gold.ner[s.B(0)].label @@ -578,12 +581,11 @@ cdef class Out: @staticmethod cdef int transition(StateC* st, attr_t label) nogil: - st.set_ent_tag(st.B(0), 2, 0) st.push() st.pop() @staticmethod - cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil: + cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil: gold = _gold cdef int g_act = gold.ner[s.B(0)].move cdef attr_t g_tag = gold.ner[s.B(0)].label diff --git a/spacy/pipeline/_parser_internals/stateclass.pxd b/spacy/pipeline/_parser_internals/stateclass.pxd index 1d9f05538..54ff344b9 100644 --- a/spacy/pipeline/_parser_internals/stateclass.pxd +++ b/spacy/pipeline/_parser_internals/stateclass.pxd @@ -2,30 +2,24 @@ from cymem.cymem cimport Pool from ...structs cimport TokenC, SpanC from ...typedefs cimport attr_t +from ...tokens.doc cimport Doc from ._state cimport StateC cdef class StateClass: - cdef Pool mem cdef StateC* c + cdef readonly Doc doc cdef int _borrowed @staticmethod - cdef inline StateClass init(const TokenC* sent, int length): + cdef inline StateClass borrow(StateC* ptr, Doc doc): cdef StateClass self = StateClass() - self.c = new StateC(sent, length) - return self - - @staticmethod - cdef inline StateClass borrow(StateC* ptr): - cdef StateClass self = StateClass() - del self.c self.c = ptr self._borrowed = 1 + self.doc = doc return self - @staticmethod cdef inline StateClass init_offset(const TokenC* sent, int length, int offset): @@ -33,105 +27,3 @@ cdef class StateClass: self.c = new StateC(sent, length) self.c.offset = offset return self - - cdef inline int S(self, int i) nogil: - return self.c.S(i) - - cdef inline int B(self, int i) nogil: - return self.c.B(i) - - cdef inline const TokenC* S_(self, int i) nogil: - return self.c.S_(i) - - cdef inline const TokenC* B_(self, int i) nogil: - return self.c.B_(i) - - cdef inline const TokenC* H_(self, int i) nogil: - return self.c.H_(i) - - cdef inline const TokenC* E_(self, int i) nogil: - return self.c.E_(i) - - cdef inline const TokenC* L_(self, int i, int idx) nogil: - return self.c.L_(i, idx) - - cdef inline const TokenC* R_(self, int i, int idx) nogil: - return self.c.R_(i, idx) - - cdef inline const TokenC* safe_get(self, int i) nogil: - return self.c.safe_get(i) - - cdef inline int H(self, int i) nogil: - return self.c.H(i) - - cdef inline int E(self, int i) nogil: - return self.c.E(i) - - cdef inline int L(self, int i, int idx) nogil: - return self.c.L(i, idx) - - cdef inline int R(self, int i, int idx) nogil: - return self.c.R(i, idx) - - cdef inline bint empty(self) nogil: - return self.c.empty() - - cdef inline bint eol(self) nogil: - return self.c.eol() - - cdef inline bint at_break(self) nogil: - return self.c.at_break() - - cdef inline bint has_head(self, int i) nogil: - return self.c.has_head(i) - - cdef inline int n_L(self, int i) nogil: - return self.c.n_L(i) - - cdef inline int n_R(self, int i) nogil: - return self.c.n_R(i) - - cdef inline bint stack_is_connected(self) nogil: - return False - - cdef inline bint entity_is_open(self) nogil: - return self.c.entity_is_open() - - cdef inline int stack_depth(self) nogil: - return self.c.stack_depth() - - cdef inline int buffer_length(self) nogil: - return self.c.buffer_length() - - cdef inline void push(self) nogil: - self.c.push() - - cdef inline void pop(self) nogil: - self.c.pop() - - cdef inline void unshift(self) nogil: - self.c.unshift() - - cdef inline void add_arc(self, int head, int child, attr_t label) nogil: - self.c.add_arc(head, child, label) - - cdef inline void del_arc(self, int head, int child) nogil: - self.c.del_arc(head, child) - - cdef inline void open_ent(self, attr_t label) nogil: - self.c.open_ent(label) - - cdef inline void close_ent(self) nogil: - self.c.close_ent() - - cdef inline void set_ent_tag(self, int i, int ent_iob, attr_t ent_type) nogil: - self.c.set_ent_tag(i, ent_iob, ent_type) - - cdef inline void set_break(self, int i) nogil: - self.c.set_break(i) - - cdef inline void clone(self, StateClass src) nogil: - self.c.clone(src.c) - - cdef inline void fast_forward(self) nogil: - self.c.fast_forward() diff --git a/spacy/pipeline/_parser_internals/stateclass.pyx b/spacy/pipeline/_parser_internals/stateclass.pyx index 880cf6cc5..4eaddd997 100644 --- a/spacy/pipeline/_parser_internals/stateclass.pyx +++ b/spacy/pipeline/_parser_internals/stateclass.pyx @@ -1,17 +1,20 @@ # cython: infer_types=True import numpy +from libcpp.vector cimport vector +from ._state cimport ArcC from ...tokens.doc cimport Doc cdef class StateClass: def __init__(self, Doc doc=None, int offset=0): - cdef Pool mem = Pool() - self.mem = mem self._borrowed = 0 if doc is not None: self.c = new StateC(doc.c, doc.length) self.c.offset = offset + self.doc = doc + else: + self.doc = None def __dealloc__(self): if self._borrowed != 1: @@ -19,36 +22,157 @@ cdef class StateClass: @property def stack(self): - return {self.S(i) for i in range(self.c._s_i)} + return [self.S(i) for i in range(self.c.stack_depth())] @property def queue(self): - return {self.B(i) for i in range(self.c.buffer_length())} + return [self.B(i) for i in range(self.c.buffer_length())] @property def token_vector_lenth(self): return self.doc.tensor.shape[1] @property - def history(self): - hist = numpy.ndarray((8,), dtype='i') - for i in range(8): - hist[i] = self.c.get_hist(i+1) - return hist + def arcs(self): + cdef vector[ArcC] arcs + self.c.get_arcs(&arcs) + return list(arcs) + #py_arcs = [] + #for arc in arcs: + # if arc.head != -1 and arc.child != -1: + # py_arcs.append((arc.head, arc.child, arc.label)) + #return arcs + + def add_arc(self, int head, int child, int label): + self.c.add_arc(head, child, label) + + def del_arc(self, int head, int child): + self.c.del_arc(head, child) + + def H(self, int child): + return self.c.H(child) + + def L(self, int head, int idx): + return self.c.L(head, idx) + + def R(self, int head, int idx): + return self.c.R(head, idx) + + @property + def _b_i(self): + return self.c._b_i + + @property + def length(self): + return self.c.length def is_final(self): return self.c.is_final() def copy(self): - cdef StateClass new_state = StateClass.init(self.c._sent, self.c.length) + cdef StateClass new_state = StateClass(doc=self.doc, offset=self.c.offset) new_state.c.clone(self.c) return new_state - def print_state(self, words): + def print_state(self): + words = [token.text for token in self.doc] words = list(words) + ['_'] - top = f"{words[self.S(0)]}_{self.S_(0).head}" - second = f"{words[self.S(1)]}_{self.S_(1).head}" - third = f"{words[self.S(2)]}_{self.S_(2).head}" - n0 = words[self.B(0)] - n1 = words[self.B(1)] - return ' '.join((third, second, top, '|', n0, n1)) + bools = ["F", "T"] + sent_starts = [bools[self.c.is_sent_start(i)] for i in range(len(self.doc))] + shifted = [1 if self.c.is_unshiftable(i) else 0 for i in range(self.c.length)] + shifted.append("") + sent_starts.append("") + top = f"{self.S(0)}{words[self.S(0)]}_{words[self.H(self.S(0))]}_{shifted[self.S(0)]}" + second = f"{self.S(1)}{words[self.S(1)]}_{words[self.H(self.S(1))]}_{shifted[self.S(1)]}" + third = f"{self.S(2)}{words[self.S(2)]}_{words[self.H(self.S(2))]}_{shifted[self.S(2)]}" + n0 = f"{self.B(0)}{words[self.B(0)]}_{sent_starts[self.B(0)]}_{shifted[self.B(0)]}" + n1 = f"{self.B(1)}{words[self.B(1)]}_{sent_starts[self.B(1)]}_{shifted[self.B(1)]}" + return ' '.join((str(self.stack_depth()), str(self.buffer_length()), third, second, top, '|', n0, n1)) + + def S(self, int i): + return self.c.S(i) + + def B(self, int i): + return self.c.B(i) + + def H(self, int i): + return self.c.H(i) + + def E(self, int i): + return self.c.E(i) + + def L(self, int i, int idx): + return self.c.L(i, idx) + + def R(self, int i, int idx): + return self.c.R(i, idx) + + def S_(self, int i): + return self.doc[self.c.S(i)] + + def B_(self, int i): + return self.doc[self.c.B(i)] + + def H_(self, int i): + return self.doc[self.c.H(i)] + + def E_(self, int i): + return self.doc[self.c.E(i)] + + def L_(self, int i, int idx): + return self.doc[self.c.L(i, idx)] + + def R_(self, int i, int idx): + return self.doc[self.c.R(i, idx)] + + def empty(self): + return self.c.empty() + + def eol(self): + return self.c.eol() + + def at_break(self): + return False + #return self.c.at_break() + + def has_head(self, int i): + return self.c.has_head(i) + + def n_L(self, int i): + return self.c.n_L(i) + + def n_R(self, int i): + return self.c.n_R(i) + + def entity_is_open(self): + return self.c.entity_is_open() + + def stack_depth(self): + return self.c.stack_depth() + + def buffer_length(self): + return self.c.buffer_length() + + def push(self): + self.c.push() + + def pop(self): + self.c.pop() + + def unshift(self): + self.c.unshift() + + def add_arc(self, int head, int child, attr_t label): + self.c.add_arc(head, child, label) + + def del_arc(self, int head, int child): + self.c.del_arc(head, child) + + def open_ent(self, attr_t label): + self.c.open_ent(label) + + def close_ent(self): + self.c.close_ent() + + def clone(self, StateClass src): + self.c.clone(src.c) diff --git a/spacy/pipeline/_parser_internals/transition_system.pxd b/spacy/pipeline/_parser_internals/transition_system.pxd index 458f1d5f9..eed347b98 100644 --- a/spacy/pipeline/_parser_internals/transition_system.pxd +++ b/spacy/pipeline/_parser_internals/transition_system.pxd @@ -16,14 +16,14 @@ cdef struct Transition: weight_t score bint (*is_valid)(const StateC* state, attr_t label) nogil - weight_t (*get_cost)(StateClass state, const void* gold, attr_t label) nogil + weight_t (*get_cost)(const StateC* state, const void* gold, attr_t label) nogil int (*do)(StateC* state, attr_t label) nogil -ctypedef weight_t (*get_cost_func_t)(StateClass state, const void* gold, +ctypedef weight_t (*get_cost_func_t)(const StateC* state, const void* gold, attr_tlabel) nogil -ctypedef weight_t (*move_cost_func_t)(StateClass state, const void* gold) nogil -ctypedef weight_t (*label_cost_func_t)(StateClass state, const void* +ctypedef weight_t (*move_cost_func_t)(const StateC* state, const void* gold) nogil +ctypedef weight_t (*label_cost_func_t)(const StateC* state, const void* gold, attr_t label) nogil ctypedef int (*do_func_t)(StateC* state, attr_t label) nogil @@ -41,9 +41,8 @@ cdef class TransitionSystem: cdef public attr_t root_label cdef public freqs cdef public object labels - - cdef int initialize_state(self, StateC* state) nogil - cdef int finalize_state(self, StateC* state) nogil + cdef init_state_t init_beam_state + cdef del_state_t del_beam_state cdef Transition lookup_transition(self, object name) except * @@ -52,4 +51,4 @@ cdef class TransitionSystem: cdef int set_valid(self, int* output, const StateC* st) nogil cdef int set_costs(self, int* is_valid, weight_t* costs, - StateClass state, gold) except -1 + const StateC* state, gold) except -1 diff --git a/spacy/pipeline/_parser_internals/transition_system.pyx b/spacy/pipeline/_parser_internals/transition_system.pyx index 7694e7f34..9bb4f7f5f 100644 --- a/spacy/pipeline/_parser_internals/transition_system.pyx +++ b/spacy/pipeline/_parser_internals/transition_system.pyx @@ -5,6 +5,7 @@ from cymem.cymem cimport Pool from collections import Counter import srsly +from . cimport _beam_utils from ...typedefs cimport weight_t, attr_t from ...tokens.doc cimport Doc from ...structs cimport TokenC @@ -44,6 +45,8 @@ cdef class TransitionSystem: if labels_by_action: self.initialize_actions(labels_by_action, min_freq=min_freq) self.root_label = self.strings.add('ROOT') + self.init_beam_state = _init_state + self.del_beam_state = _del_state def __reduce__(self): return (self.__class__, (self.strings, self.labels), None, None) @@ -54,7 +57,6 @@ cdef class TransitionSystem: offset = 0 for doc in docs: state = StateClass(doc, offset=offset) - self.initialize_state(state.c) states.append(state) offset += len(doc) return states @@ -80,7 +82,7 @@ cdef class TransitionSystem: history = [] debug_log = [] while not state.is_final(): - self.set_costs(is_valid, costs, state, gold) + self.set_costs(is_valid, costs, state.c, gold) for i in range(self.n_moves): if is_valid[i] and costs[i] <= 0: action = self.c[i] @@ -124,15 +126,6 @@ cdef class TransitionSystem: action = self.lookup_transition(name) action.do(state.c, action.label) - cdef int initialize_state(self, StateC* state) nogil: - pass - - cdef int finalize_state(self, StateC* state) nogil: - pass - - def finalize_doc(self, doc): - pass - cdef Transition lookup_transition(self, object name) except *: raise NotImplementedError @@ -151,7 +144,7 @@ cdef class TransitionSystem: is_valid[i] = self.c[i].is_valid(st, self.c[i].label) cdef int set_costs(self, int* is_valid, weight_t* costs, - StateClass stcls, gold) except -1: + const StateC* state, gold) except -1: raise NotImplementedError def get_class_name(self, int clas): diff --git a/spacy/pipeline/dep_parser.pyx b/spacy/pipeline/dep_parser.pyx index a9dcd705e..724eb6cd1 100644 --- a/spacy/pipeline/dep_parser.pyx +++ b/spacy/pipeline/dep_parser.pyx @@ -105,6 +105,93 @@ def make_parser( update_with_oracle_cut_size=update_with_oracle_cut_size, multitasks=[], learn_tokens=learn_tokens, + min_action_freq=min_action_freq, + beam_width=1, + beam_density=0.0, + beam_update_prob=0.0, + ) + +@Language.factory( + "beam_parser", + assigns=["token.dep", "token.head", "token.is_sent_start", "doc.sents"], + default_config={ + "beam_width": 8, + "beam_density": 0.01, + "beam_update_prob": 0.5, + "moves": None, + "update_with_oracle_cut_size": 100, + "learn_tokens": False, + "min_action_freq": 30, + "model": DEFAULT_PARSER_MODEL, + }, + default_score_weights={ + "dep_uas": 0.5, + "dep_las": 0.5, + "dep_las_per_type": None, + "sents_p": None, + "sents_r": None, + "sents_f": 0.0, + }, +) +def make_beam_parser( + nlp: Language, + name: str, + model: Model, + moves: Optional[list], + update_with_oracle_cut_size: int, + learn_tokens: bool, + min_action_freq: int, + beam_width: int, + beam_density: float, + beam_update_prob: float, +): + """Create a transition-based DependencyParser component that uses beam-search. + The dependency parser jointly learns sentence segmentation and labelled + dependency parsing, and can optionally learn to merge tokens that had been + over-segmented by the tokenizer. + + The parser uses a variant of the non-monotonic arc-eager transition-system + described by Honnibal and Johnson (2014), with the addition of a "break" + transition to perform the sentence segmentation. Nivre's pseudo-projective + dependency transformation is used to allow the parser to predict + non-projective parses. + + The parser is trained using a global objective. That is, it learns to assign + probabilities to whole parses. + + model (Model): The model for the transition-based parser. The model needs + to have a specific substructure of named components --- see the + spacy.ml.tb_framework.TransitionModel for details. + moves (List[str]): A list of transition names. Inferred from the data if not + provided. + beam_width (int): The number of candidate analyses to maintain. + beam_density (float): The minimum ratio between the scores of the first and + last candidates in the beam. This allows the parser to avoid exploring + candidates that are too far behind. This is mostly intended to improve + efficiency, but it can also improve accuracy as deeper search is not + always better. + beam_update_prob (float): The chance of making a beam update, instead of a + greedy update. Greedy updates are an approximation for the beam updates, + and are faster to compute. + learn_tokens (bool): Whether to learn to merge subtokens that are split + relative to the gold standard. Experimental. + min_action_freq (int): The minimum frequency of labelled actions to retain. + Rarer labelled actions have their label backed-off to "dep". While this + primarily affects the label accuracy, it can also affect the attachment + structure, as the labels are used to represent the pseudo-projectivity + transformation. + """ + return DependencyParser( + nlp.vocab, + model, + name, + moves=moves, + update_with_oracle_cut_size=update_with_oracle_cut_size, + beam_width=beam_width, + beam_density=beam_density, + beam_update_prob=beam_update_prob, + multitasks=[], + learn_tokens=learn_tokens, min_action_freq=min_action_freq ) diff --git a/spacy/pipeline/ner.pyx b/spacy/pipeline/ner.pyx index 0f93b43ac..e748d95fd 100644 --- a/spacy/pipeline/ner.pyx +++ b/spacy/pipeline/ner.pyx @@ -82,6 +82,79 @@ def make_ner( multitasks=[], min_action_freq=1, learn_tokens=False, + beam_width=1, + beam_density=0.0, + beam_update_prob=0.0, + ) + +@Language.factory( + "beam_ner", + assigns=["doc.ents", "token.ent_iob", "token.ent_type"], + default_config={ + "moves": None, + "update_with_oracle_cut_size": 100, + "model": DEFAULT_NER_MODEL, + "beam_density": 0.01, + "beam_update_prob": 0.5, + "beam_width": 32 + }, + default_score_weights={"ents_f": 1.0, "ents_p": 0.0, "ents_r": 0.0, "ents_per_type": None}, +) +def make_beam_ner( + nlp: Language, + name: str, + model: Model, + moves: Optional[list], + update_with_oracle_cut_size: int, + beam_width: int, + beam_density: float, + beam_update_prob: float, +): + """Create a transition-based EntityRecognizer component that uses beam-search. + The entity recognizer identifies non-overlapping labelled spans of tokens. + + The transition-based algorithm used encodes certain assumptions that are + effective for "traditional" named entity recognition tasks, but may not be + a good fit for every span identification problem. Specifically, the loss + function optimizes for whole entity accuracy, so if your inter-annotator + agreement on boundary tokens is low, the component will likely perform poorly + on your problem. The transition-based algorithm also assumes that the most + decisive information about your entities will be close to their initial tokens. + If your entities are long and characterised by tokens in their middle, the + component will likely do poorly on your task. + + model (Model): The model for the transition-based parser. The model needs + to have a specific substructure of named components --- see the + spacy.ml.tb_framework.TransitionModel for details. + moves (list[str]): A list of transition names. Inferred from the data if not + provided. + update_with_oracle_cut_size (int): + During training, cut long sequences into shorter segments by creating + intermediate states based on the gold-standard history. The model is + not very sensitive to this parameter, so you usually won't need to change + it. 100 is a good default. + beam_width (int): The number of candidate analyses to maintain. + beam_density (float): The minimum ratio between the scores of the first and + last candidates in the beam. This allows the parser to avoid exploring + candidates that are too far behind. This is mostly intended to improve + efficiency, but it can also improve accuracy as deeper search is not + always better. + beam_update_prob (float): The chance of making a beam update, instead of a + greedy update. Greedy updates are an approximation for the beam updates, + and are faster to compute. + """ + return EntityRecognizer( + nlp.vocab, + model, + name, + moves=moves, + update_with_oracle_cut_size=update_with_oracle_cut_size, + multitasks=[], + min_action_freq=1, + learn_tokens=False, + beam_width=beam_width, + beam_density=beam_density, + beam_update_prob=beam_update_prob, ) diff --git a/spacy/pipeline/transition_parser.pyx b/spacy/pipeline/transition_parser.pyx index 63a8595cc..8aeacbafb 100644 --- a/spacy/pipeline/transition_parser.pyx +++ b/spacy/pipeline/transition_parser.pyx @@ -4,13 +4,14 @@ from cymem.cymem cimport Pool cimport numpy as np from itertools import islice from libcpp.vector cimport vector -from libc.string cimport memset +from libc.string cimport memset, memcpy from libc.stdlib cimport calloc, free import random from typing import Optional import srsly -from thinc.api import set_dropout_rate +from thinc.api import set_dropout_rate, CupyOps +from thinc.extra.search cimport Beam import numpy.random import numpy import warnings @@ -22,6 +23,8 @@ from ..ml.parser_model cimport WeightsC, ActivationsC, SizesC, cpu_log_loss from ..ml.parser_model cimport get_c_weights, get_c_sizes from ..tokens.doc cimport Doc from .trainable_pipe import TrainablePipe +from ._parser_internals cimport _beam_utils +from ._parser_internals import _beam_utils from ..training import validate_examples, validate_get_examples from ..errors import Errors, Warnings @@ -41,9 +44,12 @@ cdef class Parser(TrainablePipe): moves=None, *, update_with_oracle_cut_size, - multitasks=tuple(), min_action_freq, learn_tokens, + beam_width=1, + beam_density=0.0, + beam_update_prob=0.0, + multitasks=tuple(), ): """Create a Parser. @@ -61,7 +67,10 @@ cdef class Parser(TrainablePipe): "update_with_oracle_cut_size": update_with_oracle_cut_size, "multitasks": list(multitasks), "min_action_freq": min_action_freq, - "learn_tokens": learn_tokens + "learn_tokens": learn_tokens, + "beam_width": beam_width, + "beam_density": beam_density, + "beam_update_prob": beam_update_prob } if moves is None: # defined by EntityRecognizer as a BiluoPushDown @@ -183,7 +192,15 @@ cdef class Parser(TrainablePipe): result = self.moves.init_batch(docs) self._resize() return result - return self.greedy_parse(docs, drop=0.0) + if self.cfg["beam_width"] == 1: + return self.greedy_parse(docs, drop=0.0) + else: + return self.beam_parse( + docs, + drop=0.0, + beam_width=self.cfg["beam_width"], + beam_density=self.cfg["beam_density"] + ) def greedy_parse(self, docs, drop=0.): cdef vector[StateC*] states @@ -207,6 +224,31 @@ cdef class Parser(TrainablePipe): del model return batch + def beam_parse(self, docs, int beam_width, float drop=0., beam_density=0.): + cdef Beam beam + cdef Doc doc + batch = _beam_utils.BeamBatch( + self.moves, + self.moves.init_batch(docs), + None, + beam_width, + density=beam_density + ) + # This is pretty dirty, but the NER can resize itself in init_batch, + # if labels are missing. We therefore have to check whether we need to + # expand our model output. + self._resize() + model = self.model.predict(docs) + while not batch.is_done: + states = batch.get_unfinished_states() + if not states: + break + scores = model.predict(states) + batch.advance(scores) + model.clear_memory() + del model + return list(batch) + cdef void _parseC(self, StateC** states, WeightsC weights, SizesC sizes) nogil: cdef int i, j @@ -227,14 +269,13 @@ cdef class Parser(TrainablePipe): unfinished.clear() free_activations(&activations) - def set_annotations(self, docs, states): + def set_annotations(self, docs, states_or_beams): cdef StateClass state + cdef Beam beam cdef Doc doc + states = _beam_utils.collect_states(states_or_beams, docs) for i, (state, doc) in enumerate(zip(states, docs)): - self.moves.finalize_state(state.c) - for j in range(doc.length): - doc.c[j] = state.c._sent[j] - self.moves.finalize_doc(doc) + self.moves.set_annotations(state, doc) for hook in self.postprocesses: hook(doc) @@ -265,7 +306,6 @@ cdef class Parser(TrainablePipe): else: action = self.moves.c[guess] action.do(states[i], action.label) - states[i].push_hist(guess) free(is_valid) def update(self, examples, *, drop=0., set_annotations=False, sgd=None, losses=None): @@ -276,13 +316,23 @@ cdef class Parser(TrainablePipe): validate_examples(examples, "Parser.update") for multitask in self._multitasks: multitask.update(examples, drop=drop, sgd=sgd) + n_examples = len([eg for eg in examples if self.moves.has_gold(eg)]) if n_examples == 0: return losses set_dropout_rate(self.model, drop) - # Prepare the stepwise model, and get the callback for finishing the batch - model, backprop_tok2vec = self.model.begin_update( - [eg.predicted for eg in examples]) + # The probability we use beam update, instead of falling back to + # a greedy update + beam_update_prob = self.cfg["beam_update_prob"] + if self.cfg['beam_width'] >= 2 and numpy.random.random() < beam_update_prob: + return self.update_beam( + examples, + beam_width=self.cfg["beam_width"], + set_annotations=set_annotations, + sgd=sgd, + losses=losses, + beam_density=self.cfg["beam_density"] + ) max_moves = self.cfg["update_with_oracle_cut_size"] if max_moves >= 1: # Chop sequences into lengths of this many words, to make the @@ -296,6 +346,8 @@ cdef class Parser(TrainablePipe): states, golds, _ = self.moves.init_gold_batch(examples) if not states: return losses + model, backprop_tok2vec = self.model.begin_update([eg.x for eg in examples]) + all_states = list(states) states_golds = list(zip(states, golds)) n_moves = 0 @@ -379,6 +431,27 @@ cdef class Parser(TrainablePipe): del tutor return losses + def update_beam(self, examples, *, beam_width, + drop=0., sgd=None, losses=None, set_annotations=False, beam_density=0.0): + states, golds, _ = self.moves.init_gold_batch(examples) + if not states: + return losses + # Prepare the stepwise model, and get the callback for finishing the batch + model, backprop_tok2vec = self.model.begin_update( + [eg.predicted for eg in examples]) + loss = _beam_utils.update_beam( + self.moves, + states, + golds, + model, + beam_width, + beam_density=beam_density, + ) + losses[self.name] += loss + backprop_tok2vec(golds) + if sgd is not None: + self.finish_update(sgd) + def get_batch_loss(self, states, golds, float[:, ::1] scores, losses): cdef StateClass state cdef Pool mem = Pool() @@ -396,7 +469,7 @@ cdef class Parser(TrainablePipe): for i, (state, gold) in enumerate(zip(states, golds)): memset(is_valid, 0, self.moves.n_moves * sizeof(int)) memset(costs, 0, self.moves.n_moves * sizeof(float)) - self.moves.set_costs(is_valid, costs, state, gold) + self.moves.set_costs(is_valid, costs, state.c, gold) for j in range(self.moves.n_moves): if costs[j] <= 0.0 and j in unseen_classes: unseen_classes.remove(j) @@ -539,7 +612,6 @@ cdef class Parser(TrainablePipe): for clas in oracle_actions[i:i+max_length]: action = self.moves.c[clas] action.do(state.c, action.label) - state.c.push_hist(action.clas) if state.is_final(): break if self.moves.has_gold(eg, start_state.B(0), state.B(0)): diff --git a/spacy/tests/parser/test_arc_eager_oracle.py b/spacy/tests/parser/test_arc_eager_oracle.py index 84070db73..fa78301af 100644 --- a/spacy/tests/parser/test_arc_eager_oracle.py +++ b/spacy/tests/parser/test_arc_eager_oracle.py @@ -7,6 +7,7 @@ from spacy.tokens import Doc from spacy.pipeline._parser_internals.nonproj import projectivize from spacy.pipeline._parser_internals.arc_eager import ArcEager from spacy.pipeline.dep_parser import DEFAULT_PARSER_MODEL +from spacy.pipeline._parser_internals.stateclass import StateClass def get_sequence_costs(M, words, heads, deps, transitions): @@ -47,15 +48,24 @@ def test_oracle_four_words(arc_eager, vocab): for dep in deps: arc_eager.add_action(2, dep) # Left arc_eager.add_action(3, dep) # Right - actions = ["L-left", "B-ROOT", "L-left"] + actions = ["S", "L-left", "B-ROOT", "S", "D", "S", "L-left", "S", "D"] state, cost_history = get_sequence_costs(arc_eager, words, heads, deps, actions) + expected_gold = [ + ["S"], + ["B-ROOT", "L-left"], + ["B-ROOT"], + ["S"], + ["D"], + ["S"], + ["L-left"], + ["S"], + ["D"] + ] assert state.is_final() for i, state_costs in enumerate(cost_history): # Check gold moves is 0 cost - assert state_costs[actions[i]] == 0.0, actions[i] - for other_action, cost in state_costs.items(): - if other_action != actions[i]: - assert cost >= 1, (i, other_action) + golds = [act for act, cost in state_costs.items() if cost < 1] + assert golds == expected_gold[i], (i, golds, expected_gold[i]) annot_tuples = [ @@ -169,12 +179,15 @@ def test_oracle_dev_sentence(vocab, arc_eager): . punct said """ expected_transitions = [ + "S", # Shift "Rolls-Royce" "S", # Shift 'Motor' "S", # Shift 'Cars' "L-nn", # Attach 'Cars' to 'Inc.' "L-nn", # Attach 'Motor' to 'Inc.' - "L-nn", # Attach 'Rolls-Royce' to 'Inc.', force shift + "L-nn", # Attach 'Rolls-Royce' to 'Inc.' + "S", # Shift "Inc." "L-nsubj", # Attach 'Inc.' to 'said' + "S", # Shift 'said' "S", # Shift 'it' "L-nsubj", # Attach 'it.' to 'expects' "R-ccomp", # Attach 'expects' to 'said' @@ -204,6 +217,8 @@ def test_oracle_dev_sentence(vocab, arc_eager): "D", # Reduce "steady" "D", # Reduce "expects" "R-punct", # Attach "." to "said" + "D", # Reduce "." + "D", # Reduce "said" ] gold_words = [] @@ -221,10 +236,40 @@ def test_oracle_dev_sentence(vocab, arc_eager): for dep in gold_deps: arc_eager.add_action(2, dep) # Left arc_eager.add_action(3, dep) # Right - doc = Doc(Vocab(), words=gold_words) example = Example.from_dict(doc, {"heads": gold_heads, "deps": gold_deps}) - - ae_oracle_actions = arc_eager.get_oracle_sequence(example) + ae_oracle_actions = arc_eager.get_oracle_sequence(example, _debug=False) ae_oracle_actions = [arc_eager.get_class_name(i) for i in ae_oracle_actions] assert ae_oracle_actions == expected_transitions + + +def test_oracle_bad_tokenization(vocab, arc_eager): + words_deps_heads = """ + [catalase] dep is + : punct is + that nsubj is + is root is + bad comp is + """ + + gold_words = [] + gold_deps = [] + gold_heads = [] + for line in words_deps_heads.strip().split("\n"): + line = line.strip() + if not line: + continue + word, dep, head = line.split() + gold_words.append(word) + gold_deps.append(dep) + gold_heads.append(head) + gold_heads = [gold_words.index(head) for head in gold_heads] + for dep in gold_deps: + arc_eager.add_action(2, dep) # Left + arc_eager.add_action(3, dep) # Right + reference = Doc(Vocab(), words=gold_words, deps=gold_deps, heads=gold_heads) + predicted = Doc(reference.vocab, words=["[", "catalase", "]", ":", "that", "is", "bad"]) + example = Example(predicted=predicted, reference=reference) + ae_oracle_actions = arc_eager.get_oracle_sequence(example, _debug=False) + ae_oracle_actions = [arc_eager.get_class_name(i) for i in ae_oracle_actions] + assert ae_oracle_actions diff --git a/spacy/tests/parser/test_ner.py b/spacy/tests/parser/test_ner.py index b4c22b48d..9ed87329c 100644 --- a/spacy/tests/parser/test_ner.py +++ b/spacy/tests/parser/test_ner.py @@ -54,7 +54,7 @@ def tsys(vocab, entity_types): def test_get_oracle_moves(tsys, doc, entity_annots): example = Example.from_dict(doc, {"entities": entity_annots}) - act_classes = tsys.get_oracle_sequence(example) + act_classes = tsys.get_oracle_sequence(example, _debug=False) names = [tsys.get_class_name(act) for act in act_classes] assert names == ["U-PERSON", "O", "O", "B-GPE", "L-GPE", "O"] diff --git a/spacy/tests/parser/test_nn_beam.py b/spacy/tests/parser/test_nn_beam.py index e69de29bb..1f45b67c8 100644 --- a/spacy/tests/parser/test_nn_beam.py +++ b/spacy/tests/parser/test_nn_beam.py @@ -0,0 +1,144 @@ +# coding: utf8 +from __future__ import unicode_literals + +import pytest +import hypothesis +import hypothesis.strategies +import numpy +from spacy.vocab import Vocab +from spacy.language import Language +from spacy.pipeline import DependencyParser +from spacy.pipeline._parser_internals.arc_eager import ArcEager +from spacy.tokens import Doc +from spacy.pipeline._parser_internals._beam_utils import BeamBatch +from spacy.pipeline._parser_internals.stateclass import StateClass +from spacy.training import Example +from thinc.tests.strategies import ndarrays_of_shape + + +@pytest.fixture(scope="module") +def vocab(): + return Vocab() + + +@pytest.fixture(scope="module") +def moves(vocab): + aeager = ArcEager(vocab.strings, {}) + aeager.add_action(0, "") + aeager.add_action(1, "") + aeager.add_action(2, "nsubj") + aeager.add_action(2, "punct") + aeager.add_action(2, "aux") + aeager.add_action(2, "nsubjpass") + aeager.add_action(3, "dobj") + aeager.add_action(2, "aux") + aeager.add_action(4, "ROOT") + return aeager + + +@pytest.fixture(scope="module") +def docs(vocab): + return [ + Doc( + vocab, + words=["Rats", "bite", "things"], + heads=[1, 1, 1], + deps=["nsubj", "ROOT", "dobj"], + sent_starts=[True, False, False] + ) + ] + + +@pytest.fixture(scope="module") +def examples(docs): + return [Example(doc, doc.copy()) for doc in docs] + + +@pytest.fixture +def states(docs): + return [StateClass(doc) for doc in docs] + + +@pytest.fixture +def tokvecs(docs, vector_size): + output = [] + for doc in docs: + vec = numpy.random.uniform(-0.1, 0.1, (len(doc), vector_size)) + output.append(numpy.asarray(vec)) + return output + + +@pytest.fixture(scope="module") +def batch_size(docs): + return len(docs) + + +@pytest.fixture(scope="module") +def beam_width(): + return 4 + +@pytest.fixture(params=[0.0, 0.5, 1.0]) +def beam_density(request): + return request.param + +@pytest.fixture +def vector_size(): + return 6 + + +@pytest.fixture +def beam(moves, examples, beam_width): + states, golds, _ = moves.init_gold_batch(examples) + return BeamBatch(moves, states, golds, width=beam_width, density=0.0) + + +@pytest.fixture +def scores(moves, batch_size, beam_width): + return numpy.asarray( + numpy.concatenate( + [ + numpy.random.uniform(-0.1, 0.1, (beam_width, moves.n_moves)) + for _ in range(batch_size) + ] + ), dtype="float32") + + +def test_create_beam(beam): + pass + + +def test_beam_advance(beam, scores): + beam.advance(scores) + + +def test_beam_advance_too_few_scores(beam, scores): + n_state = sum(len(beam) for beam in beam) + scores = scores[:n_state] + with pytest.raises(IndexError): + beam.advance(scores[:-1]) + + +def test_beam_parse(examples, beam_width): + nlp = Language() + parser = nlp.add_pipe("beam_parser") + parser.cfg["beam_width"] = beam_width + parser.add_label("nsubj") + parser.initialize(lambda: examples) + doc = nlp.make_doc("Australia is a country") + parser(doc) + + + + +@hypothesis.given(hyp=hypothesis.strategies.data()) +def test_beam_density(moves, examples, beam_width, hyp): + beam_density = float(hyp.draw(hypothesis.strategies.floats(0.0, 1.0, width=32))) + states, golds, _ = moves.init_gold_batch(examples) + beam = BeamBatch(moves, states, golds, width=beam_width, density=beam_density) + n_state = sum(len(beam) for beam in beam) + scores = hyp.draw(ndarrays_of_shape((n_state, moves.n_moves))) + beam.advance(scores) + for b in beam: + beam_probs = b.probs + assert b.min_density == beam_density + assert beam_probs[-1] >= beam_probs[0] * beam_density diff --git a/spacy/tests/parser/test_preset_sbd.py b/spacy/tests/parser/test_preset_sbd.py index ab58ac17b..595bfa537 100644 --- a/spacy/tests/parser/test_preset_sbd.py +++ b/spacy/tests/parser/test_preset_sbd.py @@ -22,6 +22,7 @@ def _parser_example(parser): @pytest.fixture def parser(vocab): + vocab.strings.add("ROOT") config = { "learn_tokens": False, "min_action_freq": 30, @@ -76,13 +77,16 @@ def test_sents_1_2(parser): def test_sents_1_3(parser): doc = Doc(parser.vocab, words=["a", "b", "c", "d"]) - doc[1].sent_start = True - doc[3].sent_start = True + doc[0].is_sent_start = True + doc[1].is_sent_start = True + doc[2].is_sent_start = None + doc[3].is_sent_start = True doc = parser(doc) assert len(list(doc.sents)) >= 3 doc = Doc(parser.vocab, words=["a", "b", "c", "d"]) - doc[1].sent_start = True - doc[2].sent_start = False - doc[3].sent_start = True + doc[0].is_sent_start = True + doc[1].is_sent_start = True + doc[2].is_sent_start = False + doc[3].is_sent_start = True doc = parser(doc) assert len(list(doc.sents)) == 3 diff --git a/spacy/tests/parser/test_state.py b/spacy/tests/parser/test_state.py new file mode 100644 index 000000000..7cd4b98e1 --- /dev/null +++ b/spacy/tests/parser/test_state.py @@ -0,0 +1,74 @@ +import pytest + +from spacy.tokens.doc import Doc +from spacy.vocab import Vocab +from spacy.pipeline._parser_internals.stateclass import StateClass + +@pytest.fixture +def vocab(): + return Vocab() + +@pytest.fixture +def doc(vocab): + return Doc(vocab, words=["a", "b", "c", "d"]) + +def test_init_state(doc): + state = StateClass(doc) + assert state.stack == [] + assert state.queue == list(range(len(doc))) + assert not state.is_final() + assert state.buffer_length() == 4 + +def test_push_pop(doc): + state = StateClass(doc) + state.push() + assert state.buffer_length() == 3 + assert state.stack == [0] + assert 0 not in state.queue + state.push() + assert state.stack == [1, 0] + assert 1 not in state.queue + assert state.buffer_length() == 2 + state.pop() + assert state.stack == [0] + assert 1 not in state.queue + +def test_stack_depth(doc): + state = StateClass(doc) + assert state.stack_depth() == 0 + assert state.buffer_length() == len(doc) + state.push() + assert state.buffer_length() == 3 + assert state.stack_depth() == 1 + + +def test_H(doc): + state = StateClass(doc) + assert state.H(0) == -1 + state.add_arc(1, 0, 0) + assert state.arcs == [{"head": 1, "child": 0, "label": 0}] + assert state.H(0) == 1 + state.add_arc(3, 1, 0) + assert state.H(1) == 3 + + +def test_L(doc): + state = StateClass(doc) + assert state.L(2, 1) == -1 + state.add_arc(2, 1, 0) + assert state.arcs == [{"head": 2, "child": 1, "label": 0}] + assert state.L(2, 1) == 1 + state.add_arc(2, 0, 0) + assert state.L(2, 1) == 0 + assert state.n_L(2) == 2 + + +def test_R(doc): + state = StateClass(doc) + assert state.R(0, 1) == -1 + state.add_arc(0, 1, 0) + assert state.arcs == [{"head": 0, "child": 1, "label": 0}] + assert state.R(0, 1) == 1 + state.add_arc(0, 2, 0) + assert state.R(0, 1) == 2 + assert state.n_R(0) == 2 diff --git a/spacy/tests/regression/test_issue4001-4500.py b/spacy/tests/regression/test_issue4001-4500.py index 73aea5b4b..873ef9c1d 100644 --- a/spacy/tests/regression/test_issue4001-4500.py +++ b/spacy/tests/regression/test_issue4001-4500.py @@ -122,7 +122,8 @@ def test_issue4042_bug2(): assert "SOME_LABEL" in ner1.labels apple_ent = Span(doc1, 5, 6, label="MY_ORG") doc1.ents = list(doc1.ents) + [apple_ent] - # reapply the NER - at this point it should resize itself + # Add the label explicitly. Previously we didn't require this. + ner1.add_label("MY_ORG") ner1(doc1) assert len(ner1.labels) == 2 assert "SOME_LABEL" in ner1.labels diff --git a/spacy/tests/serialize/test_serialize_pipeline.py b/spacy/tests/serialize/test_serialize_pipeline.py index 951dd3035..2deaa180d 100644 --- a/spacy/tests/serialize/test_serialize_pipeline.py +++ b/spacy/tests/serialize/test_serialize_pipeline.py @@ -22,6 +22,9 @@ def parser(en_vocab): "learn_tokens": False, "min_action_freq": 30, "update_with_oracle_cut_size": 100, + "beam_width": 1, + "beam_update_prob": 1.0, + "beam_density": 0.0 } cfg = {"model": DEFAULT_PARSER_MODEL} model = registry.resolve(cfg, validate=True)["model"] @@ -36,6 +39,9 @@ def blank_parser(en_vocab): "learn_tokens": False, "min_action_freq": 30, "update_with_oracle_cut_size": 100, + "beam_width": 1, + "beam_update_prob": 1.0, + "beam_density": 0.0 } cfg = {"model": DEFAULT_PARSER_MODEL} model = registry.resolve(cfg, validate=True)["model"] @@ -58,6 +64,9 @@ def test_serialize_parser_roundtrip_bytes(en_vocab, Parser): "learn_tokens": False, "min_action_freq": 0, "update_with_oracle_cut_size": 100, + "beam_width": 1, + "beam_update_prob": 1.0, + "beam_density": 0.0 } cfg = {"model": DEFAULT_PARSER_MODEL} model = registry.resolve(cfg, validate=True)["model"] @@ -79,6 +88,9 @@ def test_serialize_parser_strings(Parser): "learn_tokens": False, "min_action_freq": 0, "update_with_oracle_cut_size": 100, + "beam_width": 1, + "beam_update_prob": 1.0, + "beam_density": 0.0 } cfg = {"model": DEFAULT_PARSER_MODEL} model = registry.resolve(cfg, validate=True)["model"] @@ -98,6 +110,9 @@ def test_serialize_parser_roundtrip_disk(en_vocab, Parser): "learn_tokens": False, "min_action_freq": 0, "update_with_oracle_cut_size": 100, + "beam_width": 1, + "beam_update_prob": 1.0, + "beam_density": 0.0 } cfg = {"model": DEFAULT_PARSER_MODEL} model = registry.resolve(cfg, validate=True)["model"] diff --git a/spacy/training/example.pyx b/spacy/training/example.pyx index 6a556b5e7..21907e7dd 100644 --- a/spacy/training/example.pyx +++ b/spacy/training/example.pyx @@ -191,6 +191,24 @@ cdef class Example: aligned_deps[cand_i] = deps[gold_i] return aligned_heads, aligned_deps + def get_aligned_sent_starts(self): + """Get list of SENT_START attributes aligned to the predicted tokenization. + If the reference has not sentence starts, return a list of None values. + + The aligned sentence starts use the get_aligned_spans method, rather + than aligning the list of tags, so that it handles cases where a mistaken + tokenization starts the sentence. + """ + if self.y.has_annotation("SENT_START"): + align = self.alignment.y2x + sent_starts = [False] * len(self.x) + for y_sent in self.y.sents: + x_start = int(align[y_sent.start].dataXd[0]) + sent_starts[x_start] = True + return sent_starts + else: + return [None] * len(self.x) + def get_aligned_spans_x2y(self, x_spans): return self._get_aligned_spans(self.y, x_spans, self.alignment.x2y)