Small fixes to train defaults

This commit is contained in:
Matthew Honnibal 2020-06-12 02:22:13 +02:00
parent c0f4a1e43b
commit a1c5b694be
1 changed files with 15 additions and 13 deletions

View File

@ -156,17 +156,18 @@ def train_cli(
msg.fail("Training data not found", train_path, exits=1)
if not dev_path or not dev_path.exists():
msg.fail("Development data not found", dev_path, exits=1)
if output_path is not None and not output_path.exists():
output_path.mkdir()
msg.good(f"Created output directory: {output_path}")
elif output_path.exists() and [p for p in output_path.iterdir() if p.is_dir()]:
msg.warn(
"Output directory is not empty.",
"This can lead to unintended side effects when saving the model. "
"Please use an empty directory or a different path instead. If "
"the specified output path doesn't exist, the directory will be "
"created for you.",
)
if output_path is not None:
if not output_path.exists():
output_path.mkdir()
msg.good(f"Created output directory: {output_path}")
elif output_path.exists() and [p for p in output_path.iterdir() if p.is_dir()]:
msg.warn(
"Output directory is not empty.",
"This can lead to unintended side effects when saving the model. "
"Please use an empty directory or a different path instead. If "
"the specified output path doesn't exist, the directory will be "
"created for you.",
)
if raw_text is not None:
raw_text = list(srsly.read_jsonl(raw_text))
tag_map = {}
@ -210,7 +211,8 @@ def train(
# Read the config first without creating objects, to get to the original nlp_config
config = util.load_config(config_path, create_objects=False)
util.fix_random_seed(config["training"]["seed"])
if config["training"]["use_pytorch_for_gpu_memory"]:
if config["training"].get("use_pytorch_for_gpu_memory"):
# It feels kind of weird to not have a default for this.
use_pytorch_for_gpu_memory()
nlp_config = config["nlp"]
config = util.load_config(config_path, create_objects=True)
@ -374,7 +376,7 @@ def create_train_batches(nlp, corpus, cfg):
train_examples = list(
corpus.train_dataset(
nlp,
noise_level=cfg["noise_level"],
noise_level=0.0, # I think this is deprecated?
orth_variant_level=cfg["orth_variant_level"],
gold_preproc=cfg["gold_preproc"],
max_length=cfg["max_length"],