Move core training logic in CLI into standalone function (#9398)

This commit is contained in:
Ines Montani 2021-10-11 10:56:14 +02:00 committed by GitHub
parent 2a7e327310
commit 5003a9c3c7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 15 additions and 5 deletions

View File

@ -1,4 +1,4 @@
from typing import Optional from typing import Optional, Dict, Any
from pathlib import Path from pathlib import Path
from wasabi import msg from wasabi import msg
import typer import typer
@ -7,7 +7,7 @@ import sys
from ._util import app, Arg, Opt, parse_config_overrides, show_validation_error from ._util import app, Arg, Opt, parse_config_overrides, show_validation_error
from ._util import import_code, setup_gpu from ._util import import_code, setup_gpu
from ..training.loop import train from ..training.loop import train as train_nlp
from ..training.initialize import init_nlp from ..training.initialize import init_nlp
from .. import util from .. import util
@ -40,6 +40,18 @@ def train_cli(
DOCS: https://spacy.io/api/cli#train DOCS: https://spacy.io/api/cli#train
""" """
util.logger.setLevel(logging.DEBUG if verbose else logging.INFO) util.logger.setLevel(logging.DEBUG if verbose else logging.INFO)
overrides = parse_config_overrides(ctx.args)
import_code(code_path)
train(config_path, output_path, use_gpu=use_gpu, overrides=overrides)
def train(
config_path: Path,
output_path: Optional[Path] = None,
*,
use_gpu: int = -1,
overrides: Dict[str, Any] = util.SimpleFrozenDict(),
):
# Make sure all files and paths exists if they are needed # Make sure all files and paths exists if they are needed
if not config_path or (str(config_path) != "-" and not config_path.exists()): if not config_path or (str(config_path) != "-" and not config_path.exists()):
msg.fail("Config file not found", config_path, exits=1) msg.fail("Config file not found", config_path, exits=1)
@ -50,8 +62,6 @@ def train_cli(
output_path.mkdir(parents=True) output_path.mkdir(parents=True)
msg.good(f"Created output directory: {output_path}") msg.good(f"Created output directory: {output_path}")
msg.info(f"Saving to output directory: {output_path}") msg.info(f"Saving to output directory: {output_path}")
overrides = parse_config_overrides(ctx.args)
import_code(code_path)
setup_gpu(use_gpu) setup_gpu(use_gpu)
with show_validation_error(config_path): with show_validation_error(config_path):
config = util.load_config(config_path, overrides=overrides, interpolate=False) config = util.load_config(config_path, overrides=overrides, interpolate=False)
@ -60,4 +70,4 @@ def train_cli(
nlp = init_nlp(config, use_gpu=use_gpu) nlp = init_nlp(config, use_gpu=use_gpu)
msg.good("Initialized pipeline") msg.good("Initialized pipeline")
msg.divider("Training pipeline") msg.divider("Training pipeline")
train(nlp, output_path, use_gpu=use_gpu, stdout=sys.stdout, stderr=sys.stderr) train_nlp(nlp, output_path, use_gpu=use_gpu, stdout=sys.stdout, stderr=sys.stderr)