mirror of https://github.com/explosion/spaCy.git
state_type as Literal
This commit is contained in:
parent
35dbc63578
commit
5a9fdbc8ad
|
@ -2,7 +2,8 @@ from typing import Optional, List
|
|||
from thinc.api import Model, chain, list2array, Linear, zero_init, use_ops
|
||||
from thinc.types import Floats2d
|
||||
|
||||
from ... import Errors
|
||||
from ...errors import Errors
|
||||
from ...compat import Literal
|
||||
from ...util import registry
|
||||
from .._precomputable_affine import PrecomputableAffine
|
||||
from ..tb_framework import TransitionModel
|
||||
|
@ -12,7 +13,7 @@ from ...tokens import Doc
|
|||
@registry.architectures.register("spacy.TransitionBasedParser.v1")
|
||||
def build_tb_parser_model(
|
||||
tok2vec: Model[List[Doc], List[Floats2d]],
|
||||
state_type: str,
|
||||
state_type: Literal["parser", "ner"],
|
||||
extra_state_tokens: bool,
|
||||
hidden_width: int,
|
||||
maxout_pieces: int,
|
||||
|
|
|
@ -345,3 +345,13 @@ def test_config_auto_fill_extra_fields():
|
|||
assert "extra" not in nlp.config["training"]
|
||||
# Make sure the config generated is valid
|
||||
load_model_from_config(nlp.config)
|
||||
|
||||
|
||||
def test_config_validate_literal():
|
||||
nlp = English()
|
||||
config = Config().from_str(parser_config_string)
|
||||
config["model"]["state_type"] = "nonsense"
|
||||
with pytest.raises(ConfigValidationError):
|
||||
nlp.add_pipe("parser", config=config)
|
||||
config["model"]["state_type"] = "ner"
|
||||
nlp.add_pipe("parser", config=config)
|
Loading…
Reference in New Issue