support IS_SENT_START in PhraseMatcher (#6771)

* support IS_SENT_START in PhraseMatcher

* add unit test and friendlier error

* use IDS.get instead
This commit is contained in:
Sofie Van Landeghem 2021-01-21 09:59:17 +01:00 committed by GitHub
parent bc7d83d4be
commit fdf8c77630
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 9 additions and 1 deletions

View File

@ -8,6 +8,7 @@ from preshed.maps cimport map_init, map_set, map_get, map_clear, map_iter
import warnings
from ..attrs import IDS
from ..attrs cimport ORTH, POS, TAG, DEP, LEMMA
from ..structs cimport TokenC
from ..tokens.token cimport Token
@ -58,9 +59,11 @@ cdef class PhraseMatcher:
attr = attr.upper()
if attr == "TEXT":
attr = "ORTH"
if attr == "IS_SENT_START":
attr = "SENT_START"
if attr not in TOKEN_PATTERN_SCHEMA["items"]["properties"]:
raise ValueError(Errors.E152.format(attr=attr))
self.attr = self.vocab.strings[attr]
self.attr = IDS.get(attr)
def __len__(self):
"""Get the number of match IDs added to the matcher.

View File

@ -290,3 +290,8 @@ def test_phrase_matcher_pickle(en_vocab):
# clunky way to vaguely check that callback is unpickled
(vocab, docs, callbacks, attr) = matcher_unpickled.__reduce__()[1]
assert isinstance(callbacks.get("TEST2"), Mock)
@pytest.mark.parametrize("attr", ["SENT_START", "IS_SENT_START"])
def test_phrase_matcher_sent_start(en_vocab, attr):
matcher = PhraseMatcher(en_vocab, attr=attr)