mirror of https://github.com/explosion/spaCy.git
Fix many-to-one alignment
This commit is contained in:
parent
4890ee1732
commit
6138439469
|
@@ -90,7 +90,7 @@ from .compat import unicode_
|
||||||
from murmurhash.mrmr cimport hash32
|
from murmurhash.mrmr cimport hash32
|
||||||
|
|
||||||
|
|
||||||
def align(S, T, many_to_one=False, one_to_many=False):
|
def align(S, T):
|
||||||
cdef int m = len(S)
|
cdef int m = len(S)
|
||||||
cdef int n = len(T)
|
cdef int n = len(T)
|
||||||
cdef np.ndarray matrix = numpy.zeros((m+1, n+1), dtype='int32')
|
cdef np.ndarray matrix = numpy.zeros((m+1, n+1), dtype='int32')
|
||||||
|
@@ -126,39 +126,58 @@ def multi_align(np.ndarray i2j, np.ndarray j2i, i_lengths, j_lengths):
|
||||||
i2j_multi: {1: 1, 2: 1}
|
i2j_multi: {1: 1, 2: 1}
|
||||||
j2i_multi: {}
|
j2i_multi: {}
|
||||||
'''
|
'''
|
||||||
i_starts = numpy.cumsum([0] + i_lengths[:-1])
|
i2j_miss = _get_regions(i2j, i_lengths)
|
||||||
j_starts = numpy.cumsum([0] + j_lengths[:-1])
|
j2i_miss = _get_regions(j2i, j_lengths)
|
||||||
i2j_miss = _get_regions(i2j, i_starts)
|
|
||||||
j2i_miss = _get_regions(j2i, j_starts)
|
|
||||||
|
|
||||||
i2j_multi = _get_mapping(i2j_miss, j2i_miss, i_lengths, j_lengths)
|
i2j_multi, j2i_multi = _get_mapping(i2j_miss, j2i_miss, i_lengths, j_lengths)
|
||||||
j2i_multi = _get_mapping(j2i_miss, i2j_miss, j_lengths, i_lengths)
|
|
||||||
return i2j_multi, j2i_multi
|
return i2j_multi, j2i_multi
|
||||||
|
|
||||||
|
|
||||||
def _get_regions(alignment, starts):
|
def _get_regions(alignment, lengths):
|
||||||
regions = {}
|
regions = {}
|
||||||
start = None
|
start = None
|
||||||
|
offset = 0
|
||||||
for i in range(len(alignment)):
|
for i in range(len(alignment)):
|
||||||
if alignment[i] < 0:
|
if alignment[i] < 0:
|
||||||
if start is None:
|
if start is None:
|
||||||
start = starts[i]
|
start = offset
|
||||||
regions.setdefault(start, [])
|
regions.setdefault(start, [])
|
||||||
regions[start].append(i)
|
regions[start].append(i)
|
||||||
else:
|
else:
|
||||||
start = None
|
start = None
|
||||||
|
offset += lengths[i]
|
||||||
return regions
|
return regions
|
||||||
|
|
||||||
|
|
||||||
def _get_mapping(miss1, miss2, lengths1, lengths2):
|
def _get_mapping(miss1, miss2, lengths1, lengths2):
|
||||||
output = {}
|
i2j = {}
|
||||||
|
j2i = {}
|
||||||
for start, region1 in miss1.items():
|
for start, region1 in miss1.items():
|
||||||
region2 = miss2.get(start, [])
|
if not region1 or start not in miss2:
|
||||||
if len(region2) == 1:
|
continue
|
||||||
if sum(lengths1[i] for i in region1):
|
region2 = miss2[start]
|
||||||
for i in region1:
|
if sum(lengths1[i] for i in region1) == sum(lengths2[i] for i in region2):
|
||||||
output[i] = region2[0]
|
j = region2.pop(0)
|
||||||
return output
|
buff = []
|
||||||
|
# Consume tokens from region 1, until we meet the length of the
|
||||||
|
# first token in region2. If we do, align the tokens. If
|
||||||
|
# we exceed the length, break.
|
||||||
|
while region1:
|
||||||
|
buff.append(region1.pop(0))
|
||||||
|
if sum(lengths1[i] for i in buff) == lengths2[j]:
|
||||||
|
for i in buff:
|
||||||
|
i2j[i] = j
|
||||||
|
j2i[j] = buff[-1]
|
||||||
|
j += 1
|
||||||
|
buff = []
|
||||||
|
elif sum(lengths1[i] for i in buff) > lengths2[j]:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
if buff and sum(lengths1[i] for i in buff) == lengths2[j]:
|
||||||
|
for i in buff:
|
||||||
|
i2j[i] = j
|
||||||
|
j2i[j] = buff[-1]
|
||||||
|
return i2j, j2i
|
||||||
|
|
||||||
|
|
||||||
def _convert_sequence(seq):
|
def _convert_sequence(seq):
|
||||||
|
|
|
@@ -63,8 +63,6 @@ def merge_sents(sents):
|
||||||
|
|
||||||
punct_re = re.compile(r'\W')
|
punct_re = re.compile(r'\W')
|
||||||
def align(cand_words, gold_words):
|
def align(cand_words, gold_words):
|
||||||
cand_words = [punct_re.sub('', w).lower() for w in cand_words]
|
|
||||||
gold_words = [punct_re.sub('', w).lower() for w in gold_words]
|
|
||||||
if cand_words == gold_words:
|
if cand_words == gold_words:
|
||||||
alignment = numpy.arange(len(cand_words))
|
alignment = numpy.arange(len(cand_words))
|
||||||
return 0, alignment, alignment, {}, {}
|
return 0, alignment, alignment, {}, {}
|
||||||
|
@@ -389,7 +387,7 @@ cdef class GoldParse:
|
||||||
for i, gold_i in enumerate(self.cand_to_gold):
|
for i, gold_i in enumerate(self.cand_to_gold):
|
||||||
if doc[i].text.isspace():
|
if doc[i].text.isspace():
|
||||||
self.words[i] = doc[i].text
|
self.words[i] = doc[i].text
|
||||||
self.tags[i] = 'SP'
|
self.tags[i] = '_SP'
|
||||||
self.heads[i] = None
|
self.heads[i] = None
|
||||||
self.labels[i] = None
|
self.labels[i] = None
|
||||||
self.ner[i] = 'O'
|
self.ner[i] = 'O'
|
||||||
|
|
Loading…
Reference in New Issue