mirror of https://github.com/explosion/spaCy.git
💫 Add token match pattern validation via JSON schemas (#3244)
* Add custom MatchPatternError * Improve validators and add validation option to Matcher * Adjust formatting * Never validate in Matcher within PhraseMatcher If we do decide to make validate default to True, the PhraseMatcher's Matcher shouldn't ever validate. Here, we create the patterns automatically anyways (and it's currently unclear whether the validation has performance impacts at a very large scale).
This commit is contained in:
parent
ad2a514cdf
commit
483dddc9bc
|
@ -74,8 +74,8 @@ def debug_data(
|
|||
|
||||
# Validate data format using the JSON schema
|
||||
# TODO: update once the new format is ready
|
||||
train_data_errors = [] # TODO: validate_json(train_data, schema)
|
||||
dev_data_errors = [] # TODO: validate_json(dev_data, schema)
|
||||
train_data_errors = [] # TODO: validate_json
|
||||
dev_data_errors = [] # TODO: validate_json
|
||||
if not train_data_errors:
|
||||
msg.good("Training data JSON format is valid")
|
||||
if not dev_data_errors:
|
||||
|
|
|
@ -325,6 +325,21 @@ class TempErrors(object):
|
|||
# fmt: on
|
||||
|
||||
|
||||
class MatchPatternError(ValueError):
|
||||
def __init__(self, key, errors):
|
||||
"""Custom error for validating match patterns.
|
||||
|
||||
key (unicode): The name of the matcher rule.
|
||||
errors (dict): Validation errors (sequence of strings) mapped to pattern
|
||||
ID, i.e. the index of the added pattern.
|
||||
"""
|
||||
msg = "Invalid token patterns for matcher rule '{}'\n".format(key)
|
||||
for pattern_idx, error_msgs in errors.items():
|
||||
pattern_errors = "\n".join(["- {}".format(e) for e in error_msgs])
|
||||
msg += "\nPattern {}:\n{}\n".format(pattern_idx, pattern_errors)
|
||||
ValueError.__init__(self, msg)
|
||||
|
||||
|
||||
class ModelsWarning(UserWarning):
|
||||
pass
|
||||
|
||||
|
|
|
@ -0,0 +1,172 @@
|
|||
# coding: utf8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
|
||||
TOKEN_PATTERN_SCHEMA = {
|
||||
"$schema": "http://json-schema.org/draft-06/schema",
|
||||
"definitions": {
|
||||
"string_value": {
|
||||
"anyOf": [
|
||||
{"type": "string"},
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"REGEX": {"type": "string"},
|
||||
"IN": {"type": "array", "items": {"type": "string"}},
|
||||
"NOT_IN": {"type": "array", "items": {"type": "string"}},
|
||||
},
|
||||
"additionalProperties": False,
|
||||
},
|
||||
]
|
||||
},
|
||||
"integer_value": {
|
||||
"anyOf": [
|
||||
{"type": "integer"},
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"REGEX": {"type": "string"},
|
||||
"IN": {"type": "array", "items": {"type": "integer"}},
|
||||
"NOT_IN": {"type": "array", "items": {"type": "integer"}},
|
||||
"==": {"type": "integer"},
|
||||
">=": {"type": "integer"},
|
||||
"<=": {"type": "integer"},
|
||||
">": {"type": "integer"},
|
||||
"<": {"type": "integer"},
|
||||
},
|
||||
"additionalProperties": False,
|
||||
},
|
||||
]
|
||||
},
|
||||
"boolean_value": {"type": "boolean"},
|
||||
"underscore_value": {
|
||||
"anyOf": [
|
||||
{"type": ["string", "integer", "number", "array", "boolean", "null"]},
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"REGEX": {"type": "string"},
|
||||
"IN": {
|
||||
"type": "array",
|
||||
"items": {"type": ["string", "integer"]},
|
||||
},
|
||||
"NOT_IN": {
|
||||
"type": "array",
|
||||
"items": {"type": ["string", "integer"]},
|
||||
},
|
||||
"==": {"type": "integer"},
|
||||
">=": {"type": "integer"},
|
||||
"<=": {"type": "integer"},
|
||||
">": {"type": "integer"},
|
||||
"<": {"type": "integer"},
|
||||
},
|
||||
"additionalProperties": False,
|
||||
},
|
||||
]
|
||||
},
|
||||
},
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"ORTH": {
|
||||
"title": "Verbatim token text",
|
||||
"$ref": "#/definitions/string_value",
|
||||
},
|
||||
"TEXT": {
|
||||
"title": "Verbatim token text (spaCy v2.1+)",
|
||||
"$ref": "#/definitions/string_value",
|
||||
},
|
||||
"LOWER": {
|
||||
"title": "Lowercase form of token text",
|
||||
"$ref": "#/definitions/string_value",
|
||||
},
|
||||
"POS": {
|
||||
"title": "Coarse-grained part-of-speech tag",
|
||||
"$ref": "#/definitions/string_value",
|
||||
},
|
||||
"TAG": {
|
||||
"title": "Fine-grained part-of-speech tag",
|
||||
"$ref": "#/definitions/string_value",
|
||||
},
|
||||
"DEP": {"title": "Dependency label", "$ref": "#/definitions/string_value"},
|
||||
"LEMMA": {
|
||||
"title": "Lemma (base form)",
|
||||
"$ref": "#/definitions/string_value",
|
||||
},
|
||||
"SHAPE": {
|
||||
"title": "Abstract token shape",
|
||||
"$ref": "#/definitions/string_value",
|
||||
},
|
||||
"ENT_TYPE": {
|
||||
"title": "Entity label of single token",
|
||||
"$ref": "#/definitions/string_value",
|
||||
},
|
||||
"LENGTH": {
|
||||
"title": "Token character length",
|
||||
"$ref": "#/definitions/integer_value",
|
||||
},
|
||||
"IS_ALPHA": {
|
||||
"title": "Token consists of alphanumeric characters",
|
||||
"$ref": "#/definitions/boolean_value",
|
||||
},
|
||||
"IS_ASCII": {
|
||||
"title": "Token consists of ASCII characters",
|
||||
"$ref": "#/definitions/boolean_value",
|
||||
},
|
||||
"IS_DIGIT": {
|
||||
"title": "Token consists of digits",
|
||||
"$ref": "#/definitions/boolean_value",
|
||||
},
|
||||
"IS_LOWER": {
|
||||
"title": "Token is lowercase",
|
||||
"$ref": "#/definitions/boolean_value",
|
||||
},
|
||||
"IS_UPPER": {
|
||||
"title": "Token is uppercase",
|
||||
"$ref": "#/definitions/boolean_value",
|
||||
},
|
||||
"IS_TITLE": {
|
||||
"title": "Token is titlecase",
|
||||
"$ref": "#/definitions/boolean_value",
|
||||
},
|
||||
"IS_PUNCT": {
|
||||
"title": "Token is punctuation",
|
||||
"$ref": "#/definitions/boolean_value",
|
||||
},
|
||||
"IS_SPACE": {
|
||||
"title": "Token is whitespace",
|
||||
"$ref": "#/definitions/boolean_value",
|
||||
},
|
||||
"IS_STOP": {
|
||||
"title": "Token is stop word",
|
||||
"$ref": "#/definitions/boolean_value",
|
||||
},
|
||||
"LIKE_NUM": {
|
||||
"title": "Token resembles a number",
|
||||
"$ref": "#/definitions/boolean_value",
|
||||
},
|
||||
"LIKE_URL": {
|
||||
"title": "Token resembles a URL",
|
||||
"$ref": "#/definitions/boolean_value",
|
||||
},
|
||||
"LIKE_EMAIL": {
|
||||
"title": "Token resembles an email address",
|
||||
"$ref": "#/definitions/boolean_value",
|
||||
},
|
||||
"_": {
|
||||
"title": "Custom extension token attributes (token._.)",
|
||||
"type": "object",
|
||||
"patternProperties": {
|
||||
"^.*$": {"$ref": "#/definitions/underscore_value"}
|
||||
},
|
||||
},
|
||||
"OP": {
|
||||
"title": "Operators / quantifiers",
|
||||
"type": "string",
|
||||
"enum": ["+", "*", "?", "!"],
|
||||
},
|
||||
},
|
||||
"additionalProperties": False,
|
||||
},
|
||||
}
|
|
@ -62,6 +62,7 @@ cdef class Matcher:
|
|||
cdef Pool mem
|
||||
cdef vector[TokenPatternC*] patterns
|
||||
cdef readonly Vocab vocab
|
||||
cdef public object validator
|
||||
cdef public object _patterns
|
||||
cdef public object _callbacks
|
||||
cdef public object _extensions
|
||||
|
|
|
@ -17,7 +17,9 @@ from ..tokens.doc cimport Doc, get_token_attr
|
|||
from ..tokens.token cimport Token
|
||||
from ..attrs cimport ID, attr_id_t, NULL_ATTR, ORTH
|
||||
|
||||
from ..errors import Errors
|
||||
from ._schemas import TOKEN_PATTERN_SCHEMA
|
||||
from ..util import get_json_validator, validate_json
|
||||
from ..errors import Errors, MatchPatternError
|
||||
from ..strings import get_string_id
|
||||
from ..attrs import IDS
|
||||
|
||||
|
@ -579,7 +581,7 @@ def _get_extensions(spec, string_store, name2index):
|
|||
cdef class Matcher:
|
||||
"""Match sequences of tokens, based on pattern rules."""
|
||||
|
||||
def __init__(self, vocab):
|
||||
def __init__(self, vocab, validate=False):
|
||||
"""Create the Matcher.
|
||||
|
||||
vocab (Vocab): The vocabulary object, which must be shared with the
|
||||
|
@ -593,6 +595,7 @@ cdef class Matcher:
|
|||
self._extra_predicates = []
|
||||
self.vocab = vocab
|
||||
self.mem = Pool()
|
||||
self.validator = get_json_validator(TOKEN_PATTERN_SCHEMA) if validate else None
|
||||
|
||||
def __reduce__(self):
|
||||
data = (self.vocab, self._patterns, self._callbacks)
|
||||
|
@ -643,9 +646,14 @@ cdef class Matcher:
|
|||
on_match (callable): Callback executed on match.
|
||||
*patterns (list): List of token descriptions.
|
||||
"""
|
||||
for pattern in patterns:
|
||||
errors = {}
|
||||
for i, pattern in enumerate(patterns):
|
||||
if len(pattern) == 0:
|
||||
raise ValueError(Errors.E012.format(key=key))
|
||||
if self.validator:
|
||||
errors[i] = validate_json(pattern, self.validator)
|
||||
if errors:
|
||||
raise MatchPatternError(key, errors)
|
||||
key = self._normalize_key(key)
|
||||
for pattern in patterns:
|
||||
specs = _preprocess_pattern(pattern, self.vocab.strings,
|
||||
|
|
|
@ -41,7 +41,7 @@ cdef class PhraseMatcher:
|
|||
self.mem = Pool()
|
||||
self.max_length = max_length
|
||||
self.vocab = vocab
|
||||
self.matcher = Matcher(self.vocab)
|
||||
self.matcher = Matcher(self.vocab, validate=False)
|
||||
if isinstance(attr, long):
|
||||
self.attr = attr
|
||||
else:
|
||||
|
|
|
@ -3,7 +3,7 @@ from __future__ import unicode_literals
|
|||
|
||||
import pytest
|
||||
from spacy.cli._schemas import TRAINING_SCHEMA
|
||||
from spacy.util import validate_json
|
||||
from spacy.util import get_json_validator, validate_json
|
||||
from spacy.tokens import Doc
|
||||
from ..util import get_doc
|
||||
|
||||
|
@ -62,5 +62,6 @@ def test_doc_to_json_underscore_error_serialize(doc):
|
|||
|
||||
def test_doc_to_json_valid_training(doc):
|
||||
json_doc = doc.to_json()
|
||||
errors = validate_json([json_doc], TRAINING_SCHEMA)
|
||||
validator = get_json_validator(TRAINING_SCHEMA)
|
||||
errors = validate_json([json_doc], validator)
|
||||
assert not errors
|
||||
|
|
|
@ -0,0 +1,48 @@
|
|||
# coding: utf-8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import pytest
|
||||
from spacy.matcher import Matcher
|
||||
from spacy.matcher._schemas import TOKEN_PATTERN_SCHEMA
|
||||
from spacy.errors import MatchPatternError
|
||||
from spacy.util import get_json_validator, validate_json
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def validator():
|
||||
return get_json_validator(TOKEN_PATTERN_SCHEMA)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"pattern", [[{"XX": "y"}, {"LENGTH": "2"}, {"TEXT": {"IN": 5}}]]
|
||||
)
|
||||
def test_matcher_pattern_validation(en_vocab, pattern):
|
||||
matcher = Matcher(en_vocab, validate=True)
|
||||
with pytest.raises(MatchPatternError):
|
||||
matcher.add("TEST", None, pattern)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"pattern,n_errors",
|
||||
[
|
||||
# Bad patterns
|
||||
([{"XX": "foo"}], 1),
|
||||
([{"LENGTH": "2", "TEXT": 2}, {"LOWER": "test"}], 2),
|
||||
([{"LENGTH": {"IN": [1, 2, "3"]}}, {"POS": {"IN": "VERB"}}], 2),
|
||||
([{"IS_ALPHA": {"==": True}}, {"LIKE_NUM": None}], 2),
|
||||
([{"TEXT": {"VALUE": "foo"}}], 1),
|
||||
([{"LENGTH": {"VALUE": 5}}], 1),
|
||||
([{"_": "foo"}], 1),
|
||||
([{"_": {"foo": "bar", "baz": {"IN": "foo"}}}], 1),
|
||||
([{"IS_PUNCT": True, "OP": "$"}], 1),
|
||||
# Good patterns
|
||||
([{"TEXT": "foo"}, {"LOWER": "bar"}], 0),
|
||||
([{"LEMMA": {"IN": ["love", "like"]}}, {"POS": "DET", "OP": "?"}], 0),
|
||||
([{"LIKE_NUM": True, "LENGTH": {">=": 5}}], 0),
|
||||
([{"LOWER": {"REGEX": "^X", "NOT_IN": ["XXX", "XY"]}}], 0),
|
||||
([{"_": {"foo": {"NOT_IN": ["bar", "baz"]}, "a": 5, "b": {">": 10}}}], 0),
|
||||
],
|
||||
)
|
||||
def test_pattern_validation(validator, pattern, n_errors):
|
||||
errors = validate_json(pattern, validator)
|
||||
assert len(errors) == n_errors
|
|
@ -1,18 +1,24 @@
|
|||
# coding: utf-8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from spacy.util import validate_json, validate_schema
|
||||
from spacy.util import get_json_validator, validate_json, validate_schema
|
||||
from spacy.cli._schemas import META_SCHEMA, TRAINING_SCHEMA
|
||||
from spacy.matcher._schemas import TOKEN_PATTERN_SCHEMA
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def training_schema_validator():
|
||||
return get_json_validator(TRAINING_SCHEMA)
|
||||
|
||||
|
||||
def test_validate_schema():
|
||||
validate_schema({"type": "object"})
|
||||
with pytest.raises(Exception):
|
||||
validate_schema({"type": lambda x: x})
|
||||
|
||||
|
||||
@pytest.mark.parametrize("schema", [TRAINING_SCHEMA, META_SCHEMA])
|
||||
@pytest.mark.parametrize("schema", [TRAINING_SCHEMA, META_SCHEMA, TOKEN_PATTERN_SCHEMA])
|
||||
def test_schemas(schema):
|
||||
validate_schema(schema)
|
||||
|
||||
|
@ -24,8 +30,8 @@ def test_schemas(schema):
|
|||
{"text": "Hello", "ents": [{"start": 0, "end": 5, "label": "TEST"}]},
|
||||
],
|
||||
)
|
||||
def test_json_schema_training_valid(data):
|
||||
errors = validate_json([data], TRAINING_SCHEMA)
|
||||
def test_json_schema_training_valid(data, training_schema_validator):
|
||||
errors = validate_json([data], training_schema_validator)
|
||||
assert not errors
|
||||
|
||||
|
||||
|
@ -39,6 +45,6 @@ def test_json_schema_training_valid(data):
|
|||
({"text": "spaCy", "tokens": [{"pos": "PROPN"}]}, 2),
|
||||
],
|
||||
)
|
||||
def test_json_schema_training_invalid(data, n_errors):
|
||||
errors = validate_json([data], TRAINING_SCHEMA)
|
||||
def test_json_schema_training_invalid(data, n_errors, training_schema_validator):
|
||||
errors = validate_json([data], training_schema_validator)
|
||||
assert len(errors) == n_errors
|
||||
|
|
|
@ -627,28 +627,38 @@ def fix_random_seed(seed=0):
|
|||
cupy.random.seed(seed)
|
||||
|
||||
|
||||
def validate_schema(schema):
|
||||
def get_json_validator(schema):
|
||||
# We're using a helper function here to make it easier to change the
|
||||
# validator that's used (e.g. different draft implementation), without
|
||||
# having to change it all across the codebase.
|
||||
# TODO: replace with (stable) Draft6Validator, if available
|
||||
validator = Draft4Validator(schema)
|
||||
return Draft4Validator(schema)
|
||||
|
||||
|
||||
def validate_schema(schema):
|
||||
"""Validate a given schema. This just checks if the schema itself is valid."""
|
||||
validator = get_json_validator(schema)
|
||||
validator.check_schema(schema)
|
||||
|
||||
|
||||
def validate_json(data, schema):
|
||||
def validate_json(data, validator):
|
||||
"""Validate data against a given JSON schema (see https://json-schema.org).
|
||||
|
||||
data: JSON-serializable data to validate.
|
||||
schema (dict): The JSON schema.
|
||||
validator (jsonschema.DraftXValidator): The validator.
|
||||
RETURNS (list): A list of error messages, if available.
|
||||
"""
|
||||
# TODO: replace with (stable) Draft6Validator, if available
|
||||
validator = Draft4Validator(schema)
|
||||
errors = []
|
||||
for err in sorted(validator.iter_errors(data), key=lambda e: e.path):
|
||||
if err.path:
|
||||
err_path = "[{}]".format(" -> ".join([str(p) for p in err.path]))
|
||||
else:
|
||||
err_path = ""
|
||||
errors.append(err.message + " " + err_path)
|
||||
msg = err.message + " " + err_path
|
||||
if err.context: # Error has suberrors, e.g. if schema uses anyOf
|
||||
suberrs = [" - {}".format(suberr.message) for suberr in err.context]
|
||||
msg += ":\n{}".format("".join(suberrs))
|
||||
errors.append(msg)
|
||||
return errors
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue