mirror of https://github.com/explosion/spaCy.git
Small fixes to train defaults
This commit is contained in:
parent
c0f4a1e43b
commit
a1c5b694be
|
@ -156,7 +156,8 @@ def train_cli(
|
||||||
msg.fail("Training data not found", train_path, exits=1)
|
msg.fail("Training data not found", train_path, exits=1)
|
||||||
if not dev_path or not dev_path.exists():
|
if not dev_path or not dev_path.exists():
|
||||||
msg.fail("Development data not found", dev_path, exits=1)
|
msg.fail("Development data not found", dev_path, exits=1)
|
||||||
if output_path is not None and not output_path.exists():
|
if output_path is not None:
|
||||||
|
if not output_path.exists():
|
||||||
output_path.mkdir()
|
output_path.mkdir()
|
||||||
msg.good(f"Created output directory: {output_path}")
|
msg.good(f"Created output directory: {output_path}")
|
||||||
elif output_path.exists() and [p for p in output_path.iterdir() if p.is_dir()]:
|
elif output_path.exists() and [p for p in output_path.iterdir() if p.is_dir()]:
|
||||||
|
@ -210,7 +211,8 @@ def train(
|
||||||
# Read the config first without creating objects, to get to the original nlp_config
|
# Read the config first without creating objects, to get to the original nlp_config
|
||||||
config = util.load_config(config_path, create_objects=False)
|
config = util.load_config(config_path, create_objects=False)
|
||||||
util.fix_random_seed(config["training"]["seed"])
|
util.fix_random_seed(config["training"]["seed"])
|
||||||
if config["training"]["use_pytorch_for_gpu_memory"]:
|
if config["training"].get("use_pytorch_for_gpu_memory"):
|
||||||
|
# It feels kind of weird to not have a default for this.
|
||||||
use_pytorch_for_gpu_memory()
|
use_pytorch_for_gpu_memory()
|
||||||
nlp_config = config["nlp"]
|
nlp_config = config["nlp"]
|
||||||
config = util.load_config(config_path, create_objects=True)
|
config = util.load_config(config_path, create_objects=True)
|
||||||
|
@ -374,7 +376,7 @@ def create_train_batches(nlp, corpus, cfg):
|
||||||
train_examples = list(
|
train_examples = list(
|
||||||
corpus.train_dataset(
|
corpus.train_dataset(
|
||||||
nlp,
|
nlp,
|
||||||
noise_level=cfg["noise_level"],
|
noise_level=0.0, # I think this is deprecated?
|
||||||
orth_variant_level=cfg["orth_variant_level"],
|
orth_variant_level=cfg["orth_variant_level"],
|
||||||
gold_preproc=cfg["gold_preproc"],
|
gold_preproc=cfg["gold_preproc"],
|
||||||
max_length=cfg["max_length"],
|
max_length=cfg["max_length"],
|
||||||
|
|
Loading…
Reference in New Issue