mirror of https://github.com/explosion/spaCy.git
Expand initialize/training config validation
Validate both `[initialize]` and `[training]` in `debug data` and `nlp.initialize()` with separate config validation error blocks that indicate which block of the config is being validated.
This commit is contained in:
parent
ad43cbb042
commit
5fb8b7037a
|
@ -7,7 +7,7 @@ import typer
|
|||
|
||||
from ._util import Arg, Opt, show_validation_error, parse_config_overrides
|
||||
from ._util import import_code, debug_cli
|
||||
from ..schemas import ConfigSchemaTraining
|
||||
from ..schemas import ConfigSchemaInit, ConfigSchemaTraining
|
||||
from ..util import registry
|
||||
from .. import util
|
||||
|
||||
|
@ -55,6 +55,11 @@ def debug_config(
|
|||
config = util.load_config(config_path, overrides=overrides)
|
||||
nlp = util.load_model_from_config(config)
|
||||
config = nlp.config.interpolate()
|
||||
msg.divider("Config validation for [initialize]")
|
||||
with show_validation_error(config_path):
|
||||
T = registry.resolve(config["initialize"], schema=ConfigSchemaInit)
|
||||
msg.divider("Config validation for [training]")
|
||||
with show_validation_error(config_path):
|
||||
T = registry.resolve(config["training"], schema=ConfigSchemaTraining)
|
||||
dot_names = [T["train_corpus"], T["dev_corpus"]]
|
||||
util.resolve_dot_names(config, dot_names)
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from typing import Union, Dict, Optional, Any, List, IO, TYPE_CHECKING
|
||||
from pydantic import BaseModel
|
||||
from thinc.api import Config, fix_random_seed, set_gpu_allocator
|
||||
from thinc.api import ConfigValidationError
|
||||
from pathlib import Path
|
||||
|
@ -12,7 +13,7 @@ import tqdm
|
|||
from ..lookups import Lookups
|
||||
from ..vectors import Vectors
|
||||
from ..errors import Errors
|
||||
from ..schemas import ConfigSchemaTraining
|
||||
from ..schemas import ConfigSchemaInit, ConfigSchemaTraining
|
||||
from ..util import registry, load_model_from_config, resolve_dot_names, logger
|
||||
from ..util import load_model, ensure_path, OOV_RANK, DEFAULT_OOV_PROB
|
||||
|
||||
|
@ -23,6 +24,9 @@ if TYPE_CHECKING:
|
|||
def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
|
||||
raw_config = config
|
||||
config = raw_config.interpolate()
|
||||
# Validate config before accessing values
|
||||
_validate_config_block(config, "initialize", ConfigSchemaInit)
|
||||
_validate_config_block(config, "training", ConfigSchemaTraining)
|
||||
if config["training"]["seed"] is not None:
|
||||
fix_random_seed(config["training"]["seed"])
|
||||
allocator = config["training"]["gpu_allocator"]
|
||||
|
@ -269,3 +273,13 @@ def ensure_shape(lines):
|
|||
length = len(captured)
|
||||
yield f"{length} {width}"
|
||||
yield from captured
|
||||
|
||||
|
||||
def _validate_config_block(config: Config, block: str, schema: BaseModel):
|
||||
try:
|
||||
registry.resolve(config[block], validate=True, schema=schema)
|
||||
except ConfigValidationError as e:
|
||||
title = f"Config validation error for [{block}]"
|
||||
desc = "For more information run: python -m spacy debug config config.cfg"
|
||||
err = ConfigValidationError.from_error(e, title=title, desc=desc)
|
||||
raise err from None
|
||||
|
|
Loading…
Reference in New Issue