Fix many-to-one alignment

This commit is contained in:
Matthew Honnibal 2018-02-24 16:03:50 +01:00
parent 4890ee1732
commit 6138439469
2 changed files with 36 additions and 19 deletions

View File

@ -90,7 +90,7 @@ from .compat import unicode_
from murmurhash.mrmr cimport hash32
def align(S, T, many_to_one=False, one_to_many=False):
def align(S, T):
cdef int m = len(S)
cdef int n = len(T)
cdef np.ndarray matrix = numpy.zeros((m+1, n+1), dtype='int32')
@ -126,39 +126,58 @@ def multi_align(np.ndarray i2j, np.ndarray j2i, i_lengths, j_lengths):
i2j_multi: {1: 1, 2: 1}
j2i_multi: {}
'''
i_starts = numpy.cumsum([0] + i_lengths[:-1])
j_starts = numpy.cumsum([0] + j_lengths[:-1])
i2j_miss = _get_regions(i2j, i_starts)
j2i_miss = _get_regions(j2i, j_starts)
i2j_miss = _get_regions(i2j, i_lengths)
j2i_miss = _get_regions(j2i, j_lengths)
i2j_multi = _get_mapping(i2j_miss, j2i_miss, i_lengths, j_lengths)
j2i_multi = _get_mapping(j2i_miss, i2j_miss, j_lengths, i_lengths)
i2j_multi, j2i_multi = _get_mapping(i2j_miss, j2i_miss, i_lengths, j_lengths)
return i2j_multi, j2i_multi
def _get_regions(alignment, starts):
def _get_regions(alignment, lengths):
regions = {}
start = None
offset = 0
for i in range(len(alignment)):
if alignment[i] < 0:
if start is None:
start = starts[i]
start = offset
regions.setdefault(start, [])
regions[start].append(i)
else:
start = None
offset += lengths[i]
return regions
def _get_mapping(miss1, miss2, lengths1, lengths2):
output = {}
i2j = {}
j2i = {}
for start, region1 in miss1.items():
region2 = miss2.get(start, [])
if len(region2) == 1:
if sum(lengths1[i] for i in region1):
for i in region1:
output[i] = region2[0]
return output
if not region1 or start not in miss2:
continue
region2 = miss2[start]
if sum(lengths1[i] for i in region1) == sum(lengths2[i] for i in region2):
j = region2.pop(0)
buff = []
# Consume tokens from region 1, until we meet the length of the
# first token in region2. If we do, align the tokens. If
# we exceed the length, break.
while region1:
buff.append(region1.pop(0))
if sum(lengths1[i] for i in buff) == lengths2[j]:
for i in buff:
i2j[i] = j
j2i[j] = buff[-1]
j += 1
buff = []
elif sum(lengths1[i] for i in buff) > lengths2[j]:
break
else:
if buff and sum(lengths1[i] for i in buff) == lengths2[j]:
for i in buff:
i2j[i] = j
j2i[j] = buff[-1]
return i2j, j2i
def _convert_sequence(seq):

View File

@ -63,8 +63,6 @@ def merge_sents(sents):
punct_re = re.compile(r'\W')
def align(cand_words, gold_words):
cand_words = [punct_re.sub('', w).lower() for w in cand_words]
gold_words = [punct_re.sub('', w).lower() for w in gold_words]
if cand_words == gold_words:
alignment = numpy.arange(len(cand_words))
return 0, alignment, alignment, {}, {}
@ -389,7 +387,7 @@ cdef class GoldParse:
for i, gold_i in enumerate(self.cand_to_gold):
if doc[i].text.isspace():
self.words[i] = doc[i].text
self.tags[i] = 'SP'
self.tags[i] = '_SP'
self.heads[i] = None
self.labels[i] = None
self.ner[i] = 'O'