mirror of https://github.com/explosion/spaCy.git
Fix offsets in Span.get_lca_matrix (#8116)
* Fix range in Span.get_lca_matrix Fix the adjusted token index / lca matrix index ranges for `_get_lca_matrix` for spans. * The range for `k` should correspond to the adjusted indices in `lca_matrix` with the `start` indexed at `0` * Update test for v3.x
This commit is contained in:
parent
0dffc5d9e2
commit
2c545c4c5b
|
@ -1,4 +1,6 @@
|
||||||
import pytest
|
import pytest
|
||||||
|
import numpy
|
||||||
|
from numpy.testing import assert_array_equal
|
||||||
from spacy.attrs import ORTH, LENGTH
|
from spacy.attrs import ORTH, LENGTH
|
||||||
from spacy.tokens import Doc, Span, Token
|
from spacy.tokens import Doc, Span, Token
|
||||||
from spacy.vocab import Vocab
|
from spacy.vocab import Vocab
|
||||||
|
@ -120,6 +122,17 @@ def test_spans_lca_matrix(en_tokenizer):
|
||||||
assert lca[1, 0] == 1 # slept & dog -> slept
|
assert lca[1, 0] == 1 # slept & dog -> slept
|
||||||
assert lca[1, 1] == 1 # slept & slept -> slept
|
assert lca[1, 1] == 1 # slept & slept -> slept
|
||||||
|
|
||||||
|
# example from Span API docs
|
||||||
|
tokens = en_tokenizer("I like New York in Autumn")
|
||||||
|
doc = Doc(
|
||||||
|
tokens.vocab,
|
||||||
|
words=[t.text for t in tokens],
|
||||||
|
heads=[1, 1, 3, 1, 3, 4],
|
||||||
|
deps=["dep"] * len(tokens),
|
||||||
|
)
|
||||||
|
lca = doc[1:4].get_lca_matrix()
|
||||||
|
assert_array_equal(lca, numpy.asarray([[0, 0, 0], [0, 1, 2], [0, 2, 2]]))
|
||||||
|
|
||||||
|
|
||||||
def test_span_similarity_match():
|
def test_span_similarity_match():
|
||||||
doc = Doc(Vocab(), words=["a", "b", "a", "b"])
|
doc = Doc(Vocab(), words=["a", "b", "a", "b"])
|
||||||
|
|
|
@ -1673,7 +1673,7 @@ cdef int [:,:] _get_lca_matrix(Doc doc, int start, int end):
|
||||||
j_idx_in_sent = start + j - sent_start
|
j_idx_in_sent = start + j - sent_start
|
||||||
n_missing_tokens_in_sent = len(sent) - j_idx_in_sent
|
n_missing_tokens_in_sent = len(sent) - j_idx_in_sent
|
||||||
# make sure we do not go past `end`, in cases where `end` < sent.end
|
# make sure we do not go past `end`, in cases where `end` < sent.end
|
||||||
max_range = min(j + n_missing_tokens_in_sent, end)
|
max_range = min(j + n_missing_tokens_in_sent, end - start)
|
||||||
for k in range(j + 1, max_range):
|
for k in range(j + 1, max_range):
|
||||||
lca = _get_tokens_lca(token_j, doc[start + k])
|
lca = _get_tokens_lca(token_j, doc[start + k])
|
||||||
# if lca is outside of span, we set it to -1
|
# if lca is outside of span, we set it to -1
|
||||||
|
|
Loading…
Reference in New Issue