Dont hard-code for 'corpora' name

This commit is contained in:
Matthew Honnibal 2020-09-28 03:06:33 +02:00
parent a023cf3ecc
commit 3a0a3b8db6
1 changed files with 3 additions and 5 deletions

View File

@ -77,12 +77,10 @@ def train(nlp: Language, output_path: Optional[Path]=None) -> None:
# Create iterator, which yields out info after each optimization step. # Create iterator, which yields out info after each optimization step.
config = nlp.config.interpolate() config = nlp.config.interpolate()
T = registry.resolve(config["training"], schema=ConfigSchemaTraining) T = registry.resolve(config["training"], schema=ConfigSchemaTraining)
dot_names = [T["train_corpus"], T["dev_corpus"], T["raw_text"]]
train_corpus, dev_corpus, raw_text = resolve_dot_names(config, dot_names)
optimizer T["optimizer"] optimizer T["optimizer"]
score_weights = T["score_weights"] score_weights = T["score_weights"]
# TODO: This might not be called corpora
corpora = registry.resolve(config["corpora"], schema=ConfigSchemaCorpora)
train_corpus = dot_to_object({"corpora": corpora}, T["train_corpus"])
dev_corpus = dot_to_object({"corpora": corpora}, T["dev_corpus"])
batcher = T["batcher"] batcher = T["batcher"]
train_logger = T["logger"] train_logger = T["logger"]
before_to_disk = create_before_to_disk_callback(T["before_to_disk"]) before_to_disk = create_before_to_disk_callback(T["before_to_disk"])
@ -101,7 +99,7 @@ def train(nlp: Language, output_path: Optional[Path]=None) -> None:
patience=T["patience"], patience=T["patience"],
max_steps=T["max_steps"], max_steps=T["max_steps"],
eval_frequency=T["eval_frequency"], eval_frequency=T["eval_frequency"],
raw_text=None, raw_text=raw_text,
exclude=frozen_components, exclude=frozen_components,
) )
msg.info(f"Training. Initial learn rate: {optimizer.learn_rate}") msg.info(f"Training. Initial learn rate: {optimizer.learn_rate}")