support IS_SENT_START in PhraseMatcher (#6771)

* support IS_SENT_START in PhraseMatcher

* add unit test and friendlier error

* use IDS.get instead
This commit is contained in:
Sofie Van Landeghem 2021-01-21 09:59:17 +01:00 committed by GitHub
parent bc7d83d4be
commit fdf8c77630
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 9 additions and 1 deletions

View File

@ -8,6 +8,7 @@ from preshed.maps cimport map_init, map_set, map_get, map_clear, map_iter
import warnings import warnings
from ..attrs import IDS
from ..attrs cimport ORTH, POS, TAG, DEP, LEMMA from ..attrs cimport ORTH, POS, TAG, DEP, LEMMA
from ..structs cimport TokenC from ..structs cimport TokenC
from ..tokens.token cimport Token from ..tokens.token cimport Token
@ -58,9 +59,11 @@ cdef class PhraseMatcher:
attr = attr.upper() attr = attr.upper()
if attr == "TEXT": if attr == "TEXT":
attr = "ORTH" attr = "ORTH"
if attr == "IS_SENT_START":
attr = "SENT_START"
if attr not in TOKEN_PATTERN_SCHEMA["items"]["properties"]: if attr not in TOKEN_PATTERN_SCHEMA["items"]["properties"]:
raise ValueError(Errors.E152.format(attr=attr)) raise ValueError(Errors.E152.format(attr=attr))
self.attr = self.vocab.strings[attr] self.attr = IDS.get(attr)
def __len__(self): def __len__(self):
"""Get the number of match IDs added to the matcher. """Get the number of match IDs added to the matcher.

View File

@ -290,3 +290,8 @@ def test_phrase_matcher_pickle(en_vocab):
# clunky way to vaguely check that callback is unpickled # clunky way to vaguely check that callback is unpickled
(vocab, docs, callbacks, attr) = matcher_unpickled.__reduce__()[1] (vocab, docs, callbacks, attr) = matcher_unpickled.__reduce__()[1]
assert isinstance(callbacks.get("TEST2"), Mock) assert isinstance(callbacks.get("TEST2"), Mock)
@pytest.mark.parametrize("attr", ["SENT_START", "IS_SENT_START"])
def test_phrase_matcher_sent_start(en_vocab, attr):
matcher = PhraseMatcher(en_vocab, attr=attr)