Small fixes to train defaults

This commit is contained in:
Matthew Honnibal 2020-06-12 02:22:13 +02:00
parent c0f4a1e43b
commit a1c5b694be
1 changed files with 15 additions and 13 deletions

View File

@ -156,7 +156,8 @@ def train_cli(
msg.fail("Training data not found", train_path, exits=1) msg.fail("Training data not found", train_path, exits=1)
if not dev_path or not dev_path.exists(): if not dev_path or not dev_path.exists():
msg.fail("Development data not found", dev_path, exits=1) msg.fail("Development data not found", dev_path, exits=1)
if output_path is not None and not output_path.exists(): if output_path is not None:
if not output_path.exists():
output_path.mkdir() output_path.mkdir()
msg.good(f"Created output directory: {output_path}") msg.good(f"Created output directory: {output_path}")
elif output_path.exists() and [p for p in output_path.iterdir() if p.is_dir()]: elif output_path.exists() and [p for p in output_path.iterdir() if p.is_dir()]:
@ -210,7 +211,8 @@ def train(
# Read the config first without creating objects, to get to the original nlp_config # Read the config first without creating objects, to get to the original nlp_config
config = util.load_config(config_path, create_objects=False) config = util.load_config(config_path, create_objects=False)
util.fix_random_seed(config["training"]["seed"]) util.fix_random_seed(config["training"]["seed"])
if config["training"]["use_pytorch_for_gpu_memory"]: if config["training"].get("use_pytorch_for_gpu_memory"):
# It feels kind of weird to not have a default for this.
use_pytorch_for_gpu_memory() use_pytorch_for_gpu_memory()
nlp_config = config["nlp"] nlp_config = config["nlp"]
config = util.load_config(config_path, create_objects=True) config = util.load_config(config_path, create_objects=True)
@ -374,7 +376,7 @@ def create_train_batches(nlp, corpus, cfg):
train_examples = list( train_examples = list(
corpus.train_dataset( corpus.train_dataset(
nlp, nlp,
noise_level=cfg["noise_level"], noise_level=0.0, # I think this is deprecated?
orth_variant_level=cfg["orth_variant_level"], orth_variant_level=cfg["orth_variant_level"],
gold_preproc=cfg["gold_preproc"], gold_preproc=cfg["gold_preproc"],
max_length=cfg["max_length"], max_length=cfg["max_length"],