From 5a9fdbc8ad8e6e03968b78e026b8ee75e4c4a3e1 Mon Sep 17 00:00:00 2001 From: svlandeg Date: Wed, 23 Sep 2020 17:32:14 +0200 Subject: [PATCH] state_type as Literal --- spacy/ml/models/parser.py | 5 +++-- spacy/tests/serialize/test_serialize_config.py | 10 ++++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/spacy/ml/models/parser.py b/spacy/ml/models/parser.py index dbea6b507..2c40bb3ab 100644 --- a/spacy/ml/models/parser.py +++ b/spacy/ml/models/parser.py @@ -2,7 +2,8 @@ from typing import Optional, List from thinc.api import Model, chain, list2array, Linear, zero_init, use_ops from thinc.types import Floats2d -from ... import Errors +from ...errors import Errors +from ...compat import Literal from ...util import registry from .._precomputable_affine import PrecomputableAffine from ..tb_framework import TransitionModel @@ -12,7 +13,7 @@ from ...tokens import Doc @registry.architectures.register("spacy.TransitionBasedParser.v1") def build_tb_parser_model( tok2vec: Model[List[Doc], List[Floats2d]], - state_type: str, + state_type: Literal["parser", "ner"], extra_state_tokens: bool, hidden_width: int, maxout_pieces: int, diff --git a/spacy/tests/serialize/test_serialize_config.py b/spacy/tests/serialize/test_serialize_config.py index 10e0e132b..6aad59272 100644 --- a/spacy/tests/serialize/test_serialize_config.py +++ b/spacy/tests/serialize/test_serialize_config.py @@ -345,3 +345,13 @@ def test_config_auto_fill_extra_fields(): assert "extra" not in nlp.config["training"] # Make sure the config generated is valid load_model_from_config(nlp.config) + + +def test_config_validate_literal(): + nlp = English() + config = Config().from_str(parser_config_string) + config["model"]["state_type"] = "nonsense" + with pytest.raises(ConfigValidationError): + nlp.add_pipe("parser", config=config) + config["model"]["state_type"] = "ner" + nlp.add_pipe("parser", config=config) \ No newline at end of file