From 5003a9c3c7299830bbdf73f77bed6e4076428a81 Mon Sep 17 00:00:00 2001 From: Ines Montani Date: Mon, 11 Oct 2021 10:56:14 +0200 Subject: [PATCH] Move core training logic in CLI into standalone function (#9398) --- spacy/cli/train.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/spacy/cli/train.py b/spacy/cli/train.py index 9fd87dbc1..664fc2aaf 100644 --- a/spacy/cli/train.py +++ b/spacy/cli/train.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Dict, Any from pathlib import Path from wasabi import msg import typer @@ -7,7 +7,7 @@ import sys from ._util import app, Arg, Opt, parse_config_overrides, show_validation_error from ._util import import_code, setup_gpu -from ..training.loop import train +from ..training.loop import train as train_nlp from ..training.initialize import init_nlp from .. import util @@ -40,6 +40,18 @@ def train_cli( DOCS: https://spacy.io/api/cli#train """ util.logger.setLevel(logging.DEBUG if verbose else logging.INFO) + overrides = parse_config_overrides(ctx.args) + import_code(code_path) + train(config_path, output_path, use_gpu=use_gpu, overrides=overrides) + + +def train( + config_path: Path, + output_path: Optional[Path] = None, + *, + use_gpu: int = -1, + overrides: Dict[str, Any] = util.SimpleFrozenDict(), +): # Make sure all files and paths exists if they are needed if not config_path or (str(config_path) != "-" and not config_path.exists()): msg.fail("Config file not found", config_path, exits=1) @@ -50,8 +62,6 @@ def train_cli( output_path.mkdir(parents=True) msg.good(f"Created output directory: {output_path}") msg.info(f"Saving to output directory: {output_path}") - overrides = parse_config_overrides(ctx.args) - import_code(code_path) setup_gpu(use_gpu) with show_validation_error(config_path): config = util.load_config(config_path, overrides=overrides, interpolate=False) @@ -60,4 +70,4 @@ def train_cli( nlp = init_nlp(config, use_gpu=use_gpu) msg.good("Initialized pipeline") msg.divider("Training pipeline") - train(nlp, output_path, use_gpu=use_gpu, stdout=sys.stdout, stderr=sys.stderr) + train_nlp(nlp, output_path, use_gpu=use_gpu, stdout=sys.stdout, stderr=sys.stderr)