prevent infinite loop, custom warning

This commit is contained in:
svlandeg 2020-06-03 10:00:21 +02:00
parent 6504b7f161
commit 03c58b488c
2 changed files with 20 additions and 9 deletions

View File

@ -13,6 +13,7 @@ import random
from ..gold import GoldCorpus
from .. import util
from ..errors import Errors
from ..ml import models # don't remove - required to load the built-in architectures
registry = util.registry
@ -75,7 +76,7 @@ maxout_pieces = 3
subword_features = true
"""
# TODO: REMOVE ?
class PipelineComponent(BaseModel):
factory: str
model: Model
@ -83,7 +84,7 @@ class PipelineComponent(BaseModel):
class Config:
arbitrary_types_allowed = True
# TODO: REMOVE ?
class ConfigSchema(BaseModel):
optimizer: Optional["Optimizer"]
@ -123,7 +124,7 @@ class ConfigSchema(BaseModel):
use_gpu=("Use GPU", "option", "g", int),
# fmt: on
)
def train_from_config_cli(
def train_cli(
train_path,
dev_path,
config_path,
@ -132,7 +133,7 @@ def train_from_config_cli(
raw_text=None,
debug=False,
verbose=False,
use_gpu=-1
use_gpu=-1,
):
"""
Train or update a spaCy model. Requires data to be formatted in spaCy's
@ -156,7 +157,7 @@ def train_from_config_cli(
else:
msg.info("Using CPU")
train_from_config(
train(
config_path,
{"train": train_path, "dev": dev_path},
output_path=output_path,
@ -165,10 +166,11 @@ def train_from_config_cli(
)
def train_from_config(
def train(
config_path, data_paths, raw_text=None, meta_path=None, output_path=None,
):
msg.info(f"Loading config from: {config_path}")
# Read the config first without creating objects, to get to the original nlp_config
config = util.load_config(config_path, create_objects=False)
util.fix_random_seed(config["training"]["seed"])
if config["training"]["use_pytorch_for_gpu_memory"]:
@ -177,8 +179,8 @@ def train_from_config(
config = util.load_config(config_path, create_objects=True)
msg.info("Creating nlp from config")
nlp = util.load_model_from_config(nlp_config)
optimizer = config["optimizer"]
training = config["training"]
optimizer = training["optimizer"]
limit = training["limit"]
msg.info("Loading training corpus")
corpus = GoldCorpus(data_paths["train"], data_paths["dev"], limit=limit)
@ -246,13 +248,19 @@ def create_train_batches(nlp, corpus, cfg):
if len(train_examples) == 0:
raise ValueError(Errors.E988)
random.shuffle(train_examples)
batches = util.minibatch_by_words(train_examples, size=cfg["batch_size"])
batches = util.minibatch_by_words(train_examples, size=cfg["batch_size"], discard_oversize=cfg["discard_oversize"])
# make sure the minibatch_by_words result is not empty, or we'll have an infinite training loop
try:
first = next(batches)
yield first
except StopIteration:
raise ValueError(Errors.E986)
for batch in batches:
yield batch
epochs_todo -= 1
# We intentionally compare exactly to 0 here, so that max_epochs < 1
# will not break.
if epochs_todo == 0:
if epochs_todo == 0:
break
@ -366,6 +374,7 @@ def train_while_improving(
# Stop if we've exhausted our max steps (if specified)
if max_steps and (step * accumulate_gradient) >= max_steps:
break
step += 1
def subdivide_batch(batch, accumulate_gradient):

View File

@ -554,6 +554,8 @@ class Errors(object):
# TODO: fix numbering after merging develop into master
E986 = ("Could not create any training batches: check your input. "
"Perhaps discard_oversize should be set to False ?")
E987 = ("The text of an example training instance is either a Doc or "
"a string, but found {type} instead.")
E988 = ("Could not parse any training examples. Ensure the data is "