mirror of https://github.com/explosion/spaCy.git
Move core training logic in CLI into standalone function (#9398)
This commit is contained in:
parent
2a7e327310
commit
5003a9c3c7
|
@ -1,4 +1,4 @@
|
||||||
from typing import Optional
|
from typing import Optional, Dict, Any
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from wasabi import msg
|
from wasabi import msg
|
||||||
import typer
|
import typer
|
||||||
|
@ -7,7 +7,7 @@ import sys
|
||||||
|
|
||||||
from ._util import app, Arg, Opt, parse_config_overrides, show_validation_error
|
from ._util import app, Arg, Opt, parse_config_overrides, show_validation_error
|
||||||
from ._util import import_code, setup_gpu
|
from ._util import import_code, setup_gpu
|
||||||
from ..training.loop import train
|
from ..training.loop import train as train_nlp
|
||||||
from ..training.initialize import init_nlp
|
from ..training.initialize import init_nlp
|
||||||
from .. import util
|
from .. import util
|
||||||
|
|
||||||
|
@ -40,6 +40,18 @@ def train_cli(
|
||||||
DOCS: https://spacy.io/api/cli#train
|
DOCS: https://spacy.io/api/cli#train
|
||||||
"""
|
"""
|
||||||
util.logger.setLevel(logging.DEBUG if verbose else logging.INFO)
|
util.logger.setLevel(logging.DEBUG if verbose else logging.INFO)
|
||||||
|
overrides = parse_config_overrides(ctx.args)
|
||||||
|
import_code(code_path)
|
||||||
|
train(config_path, output_path, use_gpu=use_gpu, overrides=overrides)
|
||||||
|
|
||||||
|
|
||||||
|
def train(
|
||||||
|
config_path: Path,
|
||||||
|
output_path: Optional[Path] = None,
|
||||||
|
*,
|
||||||
|
use_gpu: int = -1,
|
||||||
|
overrides: Dict[str, Any] = util.SimpleFrozenDict(),
|
||||||
|
):
|
||||||
# Make sure all files and paths exists if they are needed
|
# Make sure all files and paths exists if they are needed
|
||||||
if not config_path or (str(config_path) != "-" and not config_path.exists()):
|
if not config_path or (str(config_path) != "-" and not config_path.exists()):
|
||||||
msg.fail("Config file not found", config_path, exits=1)
|
msg.fail("Config file not found", config_path, exits=1)
|
||||||
|
@ -50,8 +62,6 @@ def train_cli(
|
||||||
output_path.mkdir(parents=True)
|
output_path.mkdir(parents=True)
|
||||||
msg.good(f"Created output directory: {output_path}")
|
msg.good(f"Created output directory: {output_path}")
|
||||||
msg.info(f"Saving to output directory: {output_path}")
|
msg.info(f"Saving to output directory: {output_path}")
|
||||||
overrides = parse_config_overrides(ctx.args)
|
|
||||||
import_code(code_path)
|
|
||||||
setup_gpu(use_gpu)
|
setup_gpu(use_gpu)
|
||||||
with show_validation_error(config_path):
|
with show_validation_error(config_path):
|
||||||
config = util.load_config(config_path, overrides=overrides, interpolate=False)
|
config = util.load_config(config_path, overrides=overrides, interpolate=False)
|
||||||
|
@ -60,4 +70,4 @@ def train_cli(
|
||||||
nlp = init_nlp(config, use_gpu=use_gpu)
|
nlp = init_nlp(config, use_gpu=use_gpu)
|
||||||
msg.good("Initialized pipeline")
|
msg.good("Initialized pipeline")
|
||||||
msg.divider("Training pipeline")
|
msg.divider("Training pipeline")
|
||||||
train(nlp, output_path, use_gpu=use_gpu, stdout=sys.stdout, stderr=sys.stderr)
|
train_nlp(nlp, output_path, use_gpu=use_gpu, stdout=sys.stdout, stderr=sys.stderr)
|
||||||
|
|
Loading…
Reference in New Issue