returning config in init_config

This commit is contained in:
svlandeg 2020-12-08 17:37:20 +01:00
parent 8921364579
commit 8f8a7f1733
2 changed files with 8 additions and 5 deletions

View File

@ -45,7 +45,7 @@ def init_config_cli(
if isinstance(optimize, Optimizations): # instance of enum from the CLI if isinstance(optimize, Optimizations): # instance of enum from the CLI
optimize = optimize.value optimize = optimize.value
pipeline = string_to_list(pipeline) pipeline = string_to_list(pipeline)
init_config( config = init_config(
output_file, output_file,
lang=lang, lang=lang,
pipeline=pipeline, pipeline=pipeline,
@ -53,6 +53,8 @@ def init_config_cli(
cpu=cpu, cpu=cpu,
pretraining=pretraining, pretraining=pretraining,
) )
is_stdout = str(output_file) == "-"
save_config(config, output_file, is_stdout=is_stdout)
@init_cli.command("fill-config") @init_cli.command("fill-config")
@ -125,7 +127,7 @@ def init_config(
optimize: str, optimize: str,
cpu: bool, cpu: bool,
pretraining: bool = False, pretraining: bool = False,
) -> None: ) -> Config:
is_stdout = str(output_file) == "-" is_stdout = str(output_file) == "-"
msg = Printer(no_print=is_stdout) msg = Printer(no_print=is_stdout)
with TEMPLATE_PATH.open("r") as f: with TEMPLATE_PATH.open("r") as f:
@ -173,7 +175,7 @@ def init_config(
pretrain_config = util.load_config(DEFAULT_CONFIG_PRETRAIN_PATH) pretrain_config = util.load_config(DEFAULT_CONFIG_PRETRAIN_PATH)
config = pretrain_config.merge(config) config = pretrain_config.merge(config)
msg.good("Auto-filled config with all values") msg.good("Auto-filled config with all values")
save_config(config, output_file, is_stdout=is_stdout) return config
def save_config( def save_config(

View File

@ -8,7 +8,7 @@ from spacy.cli.init_config import init_config, RECOMMENDATIONS
from spacy.cli._util import validate_project_commands, parse_config_overrides from spacy.cli._util import validate_project_commands, parse_config_overrides
from spacy.cli._util import load_project_config, substitute_project_variables from spacy.cli._util import load_project_config, substitute_project_variables
from spacy.cli._util import string_to_list from spacy.cli._util import string_to_list
from thinc.api import ConfigValidationError from thinc.api import ConfigValidationError, Config
import srsly import srsly
import os import os
@ -368,7 +368,8 @@ def test_parse_cli_overrides():
@pytest.mark.parametrize("optimize", ["efficiency", "accuracy"]) @pytest.mark.parametrize("optimize", ["efficiency", "accuracy"])
def test_init_config(lang, pipeline, optimize): def test_init_config(lang, pipeline, optimize):
# TODO: add more tests and also check for GPU with transformers # TODO: add more tests and also check for GPU with transformers
init_config("-", lang=lang, pipeline=pipeline, optimize=optimize, cpu=True) config = init_config("-", lang=lang, pipeline=pipeline, optimize=optimize, cpu=True)
assert isinstance(config, Config)
def test_model_recommendations(): def test_model_recommendations():