From 8f8a7f1733513b868272718006a579b251a7bdce Mon Sep 17 00:00:00 2001 From: svlandeg Date: Tue, 8 Dec 2020 17:37:20 +0100 Subject: [PATCH] returning config in init_config --- spacy/cli/init_config.py | 8 +++++--- spacy/tests/test_cli.py | 5 +++-- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/spacy/cli/init_config.py b/spacy/cli/init_config.py index ff11f97f6..615132deb 100644 --- a/spacy/cli/init_config.py +++ b/spacy/cli/init_config.py @@ -45,7 +45,7 @@ def init_config_cli( if isinstance(optimize, Optimizations): # instance of enum from the CLI optimize = optimize.value pipeline = string_to_list(pipeline) - init_config( + config = init_config( output_file, lang=lang, pipeline=pipeline, @@ -53,6 +53,8 @@ def init_config_cli( cpu=cpu, pretraining=pretraining, ) + is_stdout = str(output_file) == "-" + save_config(config, output_file, is_stdout=is_stdout) @init_cli.command("fill-config") @@ -125,7 +127,7 @@ def init_config( optimize: str, cpu: bool, pretraining: bool = False, -) -> None: +) -> Config: is_stdout = str(output_file) == "-" msg = Printer(no_print=is_stdout) with TEMPLATE_PATH.open("r") as f: @@ -173,7 +175,7 @@ def init_config( pretrain_config = util.load_config(DEFAULT_CONFIG_PRETRAIN_PATH) config = pretrain_config.merge(config) msg.good("Auto-filled config with all values") - save_config(config, output_file, is_stdout=is_stdout) + return config def save_config( diff --git a/spacy/tests/test_cli.py b/spacy/tests/test_cli.py index 62584d0ce..f7462d88e 100644 --- a/spacy/tests/test_cli.py +++ b/spacy/tests/test_cli.py @@ -8,7 +8,7 @@ from spacy.cli.init_config import init_config, RECOMMENDATIONS from spacy.cli._util import validate_project_commands, parse_config_overrides from spacy.cli._util import load_project_config, substitute_project_variables from spacy.cli._util import string_to_list -from thinc.api import ConfigValidationError +from thinc.api import ConfigValidationError, Config import srsly import os @@ -368,7 +368,8 @@ def test_parse_cli_overrides(): @pytest.mark.parametrize("optimize", ["efficiency", "accuracy"]) def test_init_config(lang, pipeline, optimize): # TODO: add more tests and also check for GPU with transformers - init_config("-", lang=lang, pipeline=pipeline, optimize=optimize, cpu=True) + config = init_config("-", lang=lang, pipeline=pipeline, optimize=optimize, cpu=True) + assert isinstance(config, Config) def test_model_recommendations():