Merge pull request #10048 from danieldk/index-arcs-by-head

Use constant-time head lookups in StateC::{L,R}
This commit is contained in:
Daniël de Kok 2022-01-20 13:06:14 +01:00 committed by GitHub
commit 6984f55277
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 70 additions and 50 deletions

View File

@ -3,6 +3,7 @@ from libc.string cimport memcpy, memset
from libc.stdlib cimport calloc, free from libc.stdlib cimport calloc, free
from libc.stdint cimport uint32_t, uint64_t from libc.stdint cimport uint32_t, uint64_t
cimport libcpp cimport libcpp
from libcpp.unordered_map cimport unordered_map
from libcpp.vector cimport vector from libcpp.vector cimport vector
from libcpp.set cimport set from libcpp.set cimport set
from cpython.exc cimport PyErr_CheckSignals, PyErr_SetFromErrno from cpython.exc cimport PyErr_CheckSignals, PyErr_SetFromErrno
@ -30,8 +31,8 @@ cdef cppclass StateC:
vector[int] _stack vector[int] _stack
vector[int] _rebuffer vector[int] _rebuffer
vector[SpanC] _ents vector[SpanC] _ents
vector[ArcC] _left_arcs unordered_map[int, vector[ArcC]] _left_arcs
vector[ArcC] _right_arcs unordered_map[int, vector[ArcC]] _right_arcs
vector[libcpp.bool] _unshiftable vector[libcpp.bool] _unshiftable
set[int] _sent_starts set[int] _sent_starts
TokenC _empty_token TokenC _empty_token
@ -160,15 +161,22 @@ cdef cppclass StateC:
else: else:
return &this._sent[i] return &this._sent[i]
void get_arcs(vector[ArcC]* arcs) nogil const: void map_get_arcs(const unordered_map[int, vector[ArcC]] &heads_arcs, vector[ArcC]* out) nogil const:
for i in range(this._left_arcs.size()): cdef const vector[ArcC]* arcs
arc = this._left_arcs.at(i) head_arcs_it = heads_arcs.const_begin()
if arc.head != -1 and arc.child != -1: while head_arcs_it != heads_arcs.const_end():
arcs.push_back(arc) arcs = &deref(head_arcs_it).second
for i in range(this._right_arcs.size()): arcs_it = arcs.const_begin()
arc = this._right_arcs.at(i) while arcs_it != arcs.const_end():
if arc.head != -1 and arc.child != -1: arc = deref(arcs_it)
arcs.push_back(arc) if arc.head != -1 and arc.child != -1:
out.push_back(arc)
incr(arcs_it)
incr(head_arcs_it)
void get_arcs(vector[ArcC]* out) nogil const:
this.map_get_arcs(this._left_arcs, out)
this.map_get_arcs(this._right_arcs, out)
int H(int child) nogil const: int H(int child) nogil const:
if child >= this.length or child < 0: if child >= this.length or child < 0:
@ -182,37 +190,35 @@ cdef cppclass StateC:
else: else:
return this._ents.back().start return this._ents.back().start
int L(int head, int idx) nogil const: int nth_child(const unordered_map[int, vector[ArcC]]& heads_arcs, int head, int idx) nogil const:
if idx < 1 or this._left_arcs.size() == 0: if idx < 1:
return -1 return -1
# Work backwards through left-arcs to find the arc at the head_arcs_it = heads_arcs.const_find(head)
if head_arcs_it == heads_arcs.const_end():
return -1
cdef const vector[ArcC]* arcs = &deref(head_arcs_it).second
# Work backwards through arcs to find the arc at the
# requested index more quickly. # requested index more quickly.
cdef size_t child_index = 0 cdef size_t child_index = 0
it = this._left_arcs.const_rbegin() arcs_it = arcs.const_rbegin()
while it != this._left_arcs.rend(): while arcs_it != arcs.const_rend() and child_index != idx:
arc = deref(it) arc = deref(arcs_it)
if arc.head == head and arc.child != -1 and arc.child < head: if arc.child != -1:
child_index += 1 child_index += 1
if child_index == idx: if child_index == idx:
return arc.child return arc.child
incr(it) incr(arcs_it)
return -1 return -1
int L(int head, int idx) nogil const:
return this.nth_child(this._left_arcs, head, idx)
int R(int head, int idx) nogil const: int R(int head, int idx) nogil const:
if idx < 1 or this._right_arcs.size() == 0: return this.nth_child(this._right_arcs, head, idx)
return -1
cdef vector[int] rights
for i in range(this._right_arcs.size()):
arc = this._right_arcs.at(i)
if arc.head == head and arc.child != -1 and arc.child > head:
rights.push_back(arc.child)
idx = (<int>rights.size()) - idx
if idx < 0:
return -1
else:
return rights.at(idx)
bint empty() nogil const: bint empty() nogil const:
return this._stack.size() == 0 return this._stack.size() == 0
@ -253,22 +259,29 @@ cdef cppclass StateC:
int r_edge(int word) nogil const: int r_edge(int word) nogil const:
return word return word
int n_L(int head) nogil const: int n_arcs(const unordered_map[int, vector[ArcC]] &heads_arcs, int head) nogil const:
cdef int n = 0 cdef int n = 0
for i in range(this._left_arcs.size()): head_arcs_it = heads_arcs.const_find(head)
arc = this._left_arcs.at(i) if head_arcs_it == heads_arcs.const_end():
if arc.head == head and arc.child != -1 and arc.child < arc.head: return n
cdef const vector[ArcC]* arcs = &deref(head_arcs_it).second
arcs_it = arcs.const_begin()
while arcs_it != arcs.end():
arc = deref(arcs_it)
if arc.child != -1:
n += 1 n += 1
incr(arcs_it)
return n return n
int n_L(int head) nogil const:
return n_arcs(this._left_arcs, head)
int n_R(int head) nogil const: int n_R(int head) nogil const:
cdef int n = 0 return n_arcs(this._right_arcs, head)
for i in range(this._right_arcs.size()):
arc = this._right_arcs.at(i)
if arc.head == head and arc.child != -1 and arc.child > arc.head:
n += 1
return n
bint stack_is_connected() nogil const: bint stack_is_connected() nogil const:
return False return False
@ -328,19 +341,20 @@ cdef cppclass StateC:
arc.child = child arc.child = child
arc.label = label arc.label = label
if head > child: if head > child:
this._left_arcs.push_back(arc) this._left_arcs[arc.head].push_back(arc)
else: else:
this._right_arcs.push_back(arc) this._right_arcs[arc.head].push_back(arc)
this._heads[child] = head this._heads[child] = head
void del_arc(int h_i, int c_i) nogil: void map_del_arc(unordered_map[int, vector[ArcC]]* heads_arcs, int h_i, int c_i) nogil:
cdef vector[ArcC]* arcs arcs_it = heads_arcs.find(h_i)
if h_i > c_i: if arcs_it == heads_arcs.end():
arcs = &this._left_arcs return
else:
arcs = &this._right_arcs arcs = &deref(arcs_it).second
if arcs.size() == 0: if arcs.size() == 0:
return return
arc = arcs.back() arc = arcs.back()
if arc.head == h_i and arc.child == c_i: if arc.head == h_i and arc.child == c_i:
arcs.pop_back() arcs.pop_back()
@ -353,6 +367,12 @@ cdef cppclass StateC:
arc.label = 0 arc.label = 0
break break
void del_arc(int h_i, int c_i) nogil:
if h_i > c_i:
this.map_del_arc(&this._left_arcs, h_i, c_i)
else:
this.map_del_arc(&this._right_arcs, h_i, c_i)
SpanC get_ent() nogil const: SpanC get_ent() nogil const:
cdef SpanC ent cdef SpanC ent
if this._ents.size() == 0: if this._ents.size() == 0: