diff --git a/spacy/_align.pyx b/spacy/_align.pyx index 718c0ff90..07b6efbd4 100644 --- a/spacy/_align.pyx +++ b/spacy/_align.pyx @@ -90,7 +90,7 @@ from .compat import unicode_ from murmurhash.mrmr cimport hash32 -def align(S, T, many_to_one=False, one_to_many=False): +def align(S, T): cdef int m = len(S) cdef int n = len(T) cdef np.ndarray matrix = numpy.zeros((m+1, n+1), dtype='int32') @@ -126,39 +126,58 @@ def multi_align(np.ndarray i2j, np.ndarray j2i, i_lengths, j_lengths): i2j_multi: {1: 1, 2: 1} j2i_multi: {} ''' - i_starts = numpy.cumsum([0] + i_lengths[:-1]) - j_starts = numpy.cumsum([0] + j_lengths[:-1]) - i2j_miss = _get_regions(i2j, i_starts) - j2i_miss = _get_regions(j2i, j_starts) + i2j_miss = _get_regions(i2j, i_lengths) + j2i_miss = _get_regions(j2i, j_lengths) - i2j_multi = _get_mapping(i2j_miss, j2i_miss, i_lengths, j_lengths) - j2i_multi = _get_mapping(j2i_miss, i2j_miss, j_lengths, i_lengths) + i2j_multi, j2i_multi = _get_mapping(i2j_miss, j2i_miss, i_lengths, j_lengths) return i2j_multi, j2i_multi -def _get_regions(alignment, starts): +def _get_regions(alignment, lengths): regions = {} start = None + offset = 0 for i in range(len(alignment)): if alignment[i] < 0: if start is None: - start = starts[i] + start = offset regions.setdefault(start, []) regions[start].append(i) else: start = None + offset += lengths[i] return regions def _get_mapping(miss1, miss2, lengths1, lengths2): - output = {} + i2j = {} + j2i = {} for start, region1 in miss1.items(): - region2 = miss2.get(start, []) - if len(region2) == 1: - if sum(lengths1[i] for i in region1): - for i in region1: - output[i] = region2[0] - return output + if not region1 or start not in miss2: + continue + region2 = miss2[start] + if sum(lengths1[i] for i in region1) == sum(lengths2[i] for i in region2): + j = region2.pop(0) + buff = [] + # Consume tokens from region 1, until we meet the length of the + # first token in region2. If we do, align the tokens. If + # we exceed the length, break. + while region1: + buff.append(region1.pop(0)) + if sum(lengths1[i] for i in buff) == lengths2[j]: + for i in buff: + i2j[i] = j + j2i[j] = buff[-1] + j += 1 + buff = [] + elif sum(lengths1[i] for i in buff) > lengths2[j]: + break + else: + if buff and sum(lengths1[i] for i in buff) == lengths2[j]: + for i in buff: + i2j[i] = j + j2i[j] = buff[-1] + return i2j, j2i def _convert_sequence(seq): diff --git a/spacy/gold.pyx b/spacy/gold.pyx index 56a4f971b..f6bf38700 100644 --- a/spacy/gold.pyx +++ b/spacy/gold.pyx @@ -63,8 +63,6 @@ def merge_sents(sents): punct_re = re.compile(r'\W') def align(cand_words, gold_words): - cand_words = [punct_re.sub('', w).lower() for w in cand_words] - gold_words = [punct_re.sub('', w).lower() for w in gold_words] if cand_words == gold_words: alignment = numpy.arange(len(cand_words)) return 0, alignment, alignment, {}, {} @@ -389,7 +387,7 @@ cdef class GoldParse: for i, gold_i in enumerate(self.cand_to_gold): if doc[i].text.isspace(): self.words[i] = doc[i].text - self.tags[i] = 'SP' + self.tags[i] = '_SP' self.heads[i] = None self.labels[i] = None self.ner[i] = 'O'