Show warning if phrase pattern Doc was overprocessed (#3255)

In most cases, the PhraseMatcher will match on the verbatim token text or as of v2.1, sometimes the lowercase text. This means that we only need a tokenized Doc, without any other attributes.

If phrase patterns are created by processing large terminology lists with the full `nlp` object, this easily can make things a lot slower, because all components will be applied, even if we don't actually need the attributes they set (like part-of-speech tags, dependency labels).

The warning message also includes a suggestion to use nlp.make_doc or nlp.tokenizer.pipe for even faster processing. For now, the validation has to be enabled explicitly by setting validate=True.
This commit is contained in:
Ines Montani 2019-02-12 15:45:31 +01:00 committed by Matthew Honnibal
parent 6ec834dc72
commit ad2a514cdf
3 changed files with 37 additions and 3 deletions

View File

@ -60,6 +60,13 @@ class Warnings(object):
"make displaCy start another one. Instead, you should be able to " "make displaCy start another one. Instead, you should be able to "
"replace displacy.serve with displacy.render to show the " "replace displacy.serve with displacy.render to show the "
"visualization.") "visualization.")
W012 = ("A Doc object you're adding to the PhraseMatcher for pattern "
"'{key}' is parsed and/or tagged, but to match on '{attr}', you "
"don't actually need this information. This means that creating "
"the patterns is potentially much slower, because all pipeline "
"components are applied. To only create tokenized Doc objects, "
"try using `nlp.make_doc(text)` or process all texts as a stream "
"using `list(nlp.tokenizer.pipe(all_texts))`.")
@add_codes @add_codes

View File

@ -7,12 +7,12 @@ from murmurhash.mrmr cimport hash64
from preshed.maps cimport PreshMap from preshed.maps cimport PreshMap
from .matcher cimport Matcher from .matcher cimport Matcher
from ..attrs cimport ORTH, attr_id_t from ..attrs cimport ORTH, POS, TAG, DEP, LEMMA, attr_id_t
from ..vocab cimport Vocab from ..vocab cimport Vocab
from ..tokens.doc cimport Doc, get_token_attr from ..tokens.doc cimport Doc, get_token_attr
from ..typedefs cimport attr_t, hash_t from ..typedefs cimport attr_t, hash_t
from ..errors import Warnings, deprecation_warning from ..errors import Warnings, deprecation_warning, user_warning
from ..attrs import FLAG61 as U_ENT from ..attrs import FLAG61 as U_ENT
from ..attrs import FLAG60 as B2_ENT from ..attrs import FLAG60 as B2_ENT
from ..attrs import FLAG59 as B3_ENT from ..attrs import FLAG59 as B3_ENT
@ -33,8 +33,9 @@ cdef class PhraseMatcher:
cdef attr_id_t attr cdef attr_id_t attr
cdef public object _callbacks cdef public object _callbacks
cdef public object _patterns cdef public object _patterns
cdef public object _validate
def __init__(self, Vocab vocab, max_length=0, attr='ORTH'): def __init__(self, Vocab vocab, max_length=0, attr='ORTH', validate=False):
if max_length != 0: if max_length != 0:
deprecation_warning(Warnings.W010) deprecation_warning(Warnings.W010)
self.mem = Pool() self.mem = Pool()
@ -54,6 +55,7 @@ cdef class PhraseMatcher:
] ]
self.matcher.add('Candidate', None, *abstract_patterns) self.matcher.add('Candidate', None, *abstract_patterns)
self._callbacks = {} self._callbacks = {}
self._validate = validate
def __len__(self): def __len__(self):
"""Get the number of rules added to the matcher. Note that this only """Get the number of rules added to the matcher. Note that this only
@ -95,6 +97,10 @@ cdef class PhraseMatcher:
length = doc.length length = doc.length
if length == 0: if length == 0:
continue continue
if self._validate and (doc.is_tagged or doc.is_parsed) \
and self.attr not in (DEP, POS, TAG, LEMMA):
string_attr = self.vocab.strings[self.attr]
user_warning(Warnings.W012.format(key=key, attr=string_attr))
tags = get_bilou(length) tags = get_bilou(length)
phrase_key = <attr_t*>mem.alloc(length, sizeof(attr_t)) phrase_key = <attr_t*>mem.alloc(length, sizeof(attr_t))
for i, tag in enumerate(tags): for i, tag in enumerate(tags):

View File

@ -1,6 +1,7 @@
# coding: utf-8 # coding: utf-8
from __future__ import unicode_literals from __future__ import unicode_literals
import pytest
from spacy.matcher import PhraseMatcher from spacy.matcher import PhraseMatcher
from spacy.tokens import Doc from spacy.tokens import Doc
from ..util import get_doc from ..util import get_doc
@ -78,3 +79,23 @@ def test_phrase_matcher_bool_attrs(en_vocab):
assert end1 == 3 assert end1 == 3
assert start2 == 3 assert start2 == 3
assert end2 == 6 assert end2 == 6
def test_phrase_matcher_validation(en_vocab):
doc1 = Doc(en_vocab, words=["Test"])
doc1.is_parsed = True
doc2 = Doc(en_vocab, words=["Test"])
doc2.is_tagged = True
doc3 = Doc(en_vocab, words=["Test"])
matcher = PhraseMatcher(en_vocab, validate=True)
with pytest.warns(UserWarning):
matcher.add("TEST1", None, doc1)
with pytest.warns(UserWarning):
matcher.add("TEST2", None, doc2)
with pytest.warns(None) as record:
matcher.add("TEST3", None, doc3)
assert not record.list
matcher = PhraseMatcher(en_vocab, attr="POS", validate=True)
with pytest.warns(None) as record:
matcher.add("TEST4", None, doc2)
assert not record.list