mirror of https://github.com/explosion/spaCy.git
Start updating train script
This commit is contained in:
parent
39b178999c
commit
b5556093e2
|
@ -16,6 +16,7 @@ from ._util import import_code, get_sourced_components
|
|||
from ..language import Language
|
||||
from .. import util
|
||||
from ..training.example import Example
|
||||
from ..training.initialize import must_initialize, init_pipeline
|
||||
from ..errors import Errors
|
||||
from ..util import dot_to_object
|
||||
|
||||
|
@ -31,8 +32,6 @@ def train_cli(
|
|||
code_path: Optional[Path] = Opt(None, "--code", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
|
||||
verbose: bool = Opt(False, "--verbose", "-V", "-VV", help="Display more information for debugging purposes"),
|
||||
use_gpu: int = Opt(-1, "--gpu-id", "-g", help="GPU ID or -1 for CPU"),
|
||||
resume: bool = Opt(False, "--resume", "-R", help="Resume training"),
|
||||
dave_path: Optional[Path] = Opt(None, "--dave", "-D", help="etc etc"),
|
||||
# fmt: on
|
||||
):
|
||||
"""
|
||||
|
@ -53,38 +52,37 @@ def train_cli(
|
|||
verify_cli_args(config_path, output_path)
|
||||
overrides = parse_config_overrides(ctx.args)
|
||||
import_code(code_path)
|
||||
if prepared is None:
|
||||
prepare(config_path, output_path / "prepared", config_overrides=overrides)
|
||||
train(
|
||||
config_path,
|
||||
output_path=output_path,
|
||||
dave_path=dave_path,
|
||||
config_overrides=overrides,
|
||||
use_gpu=use_gpu,
|
||||
resume_training=resume,
|
||||
)
|
||||
|
||||
|
||||
def train(
|
||||
output_path: Path,
|
||||
config_overrides: Dict[str, Any] = {},
|
||||
use_gpu: int = -1,
|
||||
resume_training: bool = False,
|
||||
) -> None:
|
||||
if use_gpu >= 0:
|
||||
msg.info(f"Using GPU: {use_gpu}")
|
||||
require_gpu(use_gpu)
|
||||
else:
|
||||
msg.info("Using CPU")
|
||||
msg.info(f"Loading config and nlp from: {config_path}")
|
||||
# TODO: The details of this will change
|
||||
dave_path = output_path / "dave"
|
||||
config_path = dave_path / "config.cfg"
|
||||
with show_validation_error(config_path):
|
||||
config = fill_config_etc_etc(config_path)
|
||||
nlp = make_and_load_nlp_etc_etc(config, dave_path)
|
||||
optimizer, train_corpus, dev_corpus, score_weights, T_cfg = resolve_more_things_etc_etc(config)
|
||||
config = util.load_config(
|
||||
config_path, overrides=config_overrides, interpolate=True
|
||||
)
|
||||
if output_path is None:
|
||||
nlp = init_pipeline(config)
|
||||
else:
|
||||
init_path = output_path / "model-initial"
|
||||
if must_reinitialize(config, init_path):
|
||||
nlp = init_pipeline(config)
|
||||
nlp.to_disk(init_path)
|
||||
else:
|
||||
nlp = spacy.load(output_path / "model-initial")
|
||||
msg.info("Start training")
|
||||
train(nlp, config, output_path)
|
||||
|
||||
|
||||
def train(nlp: Language, output_path: Optional[Path]=None) -> None:
|
||||
# Create iterator, which yields out info after each optimization step.
|
||||
config = nlp.config
|
||||
T_cfg = config["training"]
|
||||
score_weights = T_cfg["score_weights"]
|
||||
optimizer = T_cfg["optimizer"]
|
||||
train_corpus = dot_to_object(config, T_cfg["train_corpus"])
|
||||
dev_corpus = dot_to_object(config, T_cfg["dev_corpus"])
|
||||
batcher = T_cfg["batcher"]
|
||||
|
||||
training_step_iterator = train_while_improving(
|
||||
nlp,
|
||||
optimizer,
|
||||
|
@ -142,6 +140,7 @@ def train(
|
|||
msg.good(f"Saved pipeline to output directory {final_model_path}")
|
||||
|
||||
|
||||
|
||||
def add_vectors(nlp: Language, vectors: str) -> None:
|
||||
title = f"Config validation error for vectors {vectors}"
|
||||
desc = (
|
||||
|
|
Loading…
Reference in New Issue