diff --git a/spacy/cli/_util.py b/spacy/cli/_util.py index b72ee83ec..4596a9d32 100644 --- a/spacy/cli/_util.py +++ b/spacy/cli/_util.py @@ -5,6 +5,10 @@ import srsly import hashlib import typer from typer.main import get_command +from contextlib import contextmanager +from thinc.config import ConfigValidationError +from configparser import InterpolationError +import sys from ..schemas import ProjectConfigSchema, validate @@ -154,3 +158,17 @@ def get_checksum(path: Union[Path, str]) -> str: dir_checksum.update(sub_file.read_bytes()) return dir_checksum.hexdigest() raise ValueError(f"Can't get checksum for {path}: not a file or directory") + + +@contextmanager +def show_validation_error(title: str = "Config validation error"): + """Helper to show custom config validation errors on the CLI. + + title (str): Title of the custom formatted error. + """ + try: + yield + except (ConfigValidationError, InterpolationError) as e: + msg.fail(title, spaced=True) + print(str(e).replace("Config validation error", "").strip()) + sys.exit(1) diff --git a/spacy/cli/train.py b/spacy/cli/train.py index 4c339ae36..4707a8e98 100644 --- a/spacy/cli/train.py +++ b/spacy/cli/train.py @@ -6,13 +6,11 @@ from pathlib import Path from wasabi import msg import thinc import thinc.schedules -from thinc.config import ConfigValidationError from thinc.api import use_pytorch_for_gpu_memory, require_gpu, fix_random_seed import random import typer -import sys -from ._util import app, Arg, Opt, parse_config_overrides +from ._util import app, Arg, Opt, parse_config_overrides, show_validation_error from ..gold import Corpus, Example from ..lookups import Lookups from .. import util @@ -83,17 +81,13 @@ def train( ) -> None: msg.info(f"Loading config from: {config_path}") # Read the config first without creating objects, to get to the original nlp_config - try: + with show_validation_error(): config = util.load_config( config_path, create_objects=False, schema=ConfigSchema, overrides=config_overrides, ) - except ConfigValidationError as e: - msg.fail("Config validation error") - print(str(e).replace("Config validation error", "").strip()) - sys.exit(1) use_gpu = config["training"]["use_gpu"] if use_gpu >= 0: msg.info(f"Using GPU: {use_gpu}")