mirror of https://github.com/explosion/spaCy.git
prevent infinite loop, custom warning
This commit is contained in:
parent
6504b7f161
commit
03c58b488c
|
@ -13,6 +13,7 @@ import random
|
|||
from ..gold import GoldCorpus
|
||||
from .. import util
|
||||
from ..errors import Errors
|
||||
from ..ml import models # don't remove - required to load the built-in architectures
|
||||
|
||||
registry = util.registry
|
||||
|
||||
|
@ -75,7 +76,7 @@ maxout_pieces = 3
|
|||
subword_features = true
|
||||
"""
|
||||
|
||||
|
||||
# TODO: REMOVE ?
|
||||
class PipelineComponent(BaseModel):
|
||||
factory: str
|
||||
model: Model
|
||||
|
@ -83,7 +84,7 @@ class PipelineComponent(BaseModel):
|
|||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
# TODO: REMOVE ?
|
||||
class ConfigSchema(BaseModel):
|
||||
optimizer: Optional["Optimizer"]
|
||||
|
||||
|
@ -123,7 +124,7 @@ class ConfigSchema(BaseModel):
|
|||
use_gpu=("Use GPU", "option", "g", int),
|
||||
# fmt: on
|
||||
)
|
||||
def train_from_config_cli(
|
||||
def train_cli(
|
||||
train_path,
|
||||
dev_path,
|
||||
config_path,
|
||||
|
@ -132,7 +133,7 @@ def train_from_config_cli(
|
|||
raw_text=None,
|
||||
debug=False,
|
||||
verbose=False,
|
||||
use_gpu=-1
|
||||
use_gpu=-1,
|
||||
):
|
||||
"""
|
||||
Train or update a spaCy model. Requires data to be formatted in spaCy's
|
||||
|
@ -156,7 +157,7 @@ def train_from_config_cli(
|
|||
else:
|
||||
msg.info("Using CPU")
|
||||
|
||||
train_from_config(
|
||||
train(
|
||||
config_path,
|
||||
{"train": train_path, "dev": dev_path},
|
||||
output_path=output_path,
|
||||
|
@ -165,10 +166,11 @@ def train_from_config_cli(
|
|||
)
|
||||
|
||||
|
||||
def train_from_config(
|
||||
def train(
|
||||
config_path, data_paths, raw_text=None, meta_path=None, output_path=None,
|
||||
):
|
||||
msg.info(f"Loading config from: {config_path}")
|
||||
# Read the config first without creating objects, to get to the original nlp_config
|
||||
config = util.load_config(config_path, create_objects=False)
|
||||
util.fix_random_seed(config["training"]["seed"])
|
||||
if config["training"]["use_pytorch_for_gpu_memory"]:
|
||||
|
@ -177,8 +179,8 @@ def train_from_config(
|
|||
config = util.load_config(config_path, create_objects=True)
|
||||
msg.info("Creating nlp from config")
|
||||
nlp = util.load_model_from_config(nlp_config)
|
||||
optimizer = config["optimizer"]
|
||||
training = config["training"]
|
||||
optimizer = training["optimizer"]
|
||||
limit = training["limit"]
|
||||
msg.info("Loading training corpus")
|
||||
corpus = GoldCorpus(data_paths["train"], data_paths["dev"], limit=limit)
|
||||
|
@ -246,13 +248,19 @@ def create_train_batches(nlp, corpus, cfg):
|
|||
if len(train_examples) == 0:
|
||||
raise ValueError(Errors.E988)
|
||||
random.shuffle(train_examples)
|
||||
batches = util.minibatch_by_words(train_examples, size=cfg["batch_size"])
|
||||
batches = util.minibatch_by_words(train_examples, size=cfg["batch_size"], discard_oversize=cfg["discard_oversize"])
|
||||
# make sure the minibatch_by_words result is not empty, or we'll have an infinite training loop
|
||||
try:
|
||||
first = next(batches)
|
||||
yield first
|
||||
except StopIteration:
|
||||
raise ValueError(Errors.E986)
|
||||
for batch in batches:
|
||||
yield batch
|
||||
epochs_todo -= 1
|
||||
# We intentionally compare exactly to 0 here, so that max_epochs < 1
|
||||
# will not break.
|
||||
if epochs_todo == 0:
|
||||
if epochs_todo == 0:
|
||||
break
|
||||
|
||||
|
||||
|
@ -366,6 +374,7 @@ def train_while_improving(
|
|||
# Stop if we've exhausted our max steps (if specified)
|
||||
if max_steps and (step * accumulate_gradient) >= max_steps:
|
||||
break
|
||||
step += 1
|
||||
|
||||
|
||||
def subdivide_batch(batch, accumulate_gradient):
|
||||
|
|
|
@ -554,6 +554,8 @@ class Errors(object):
|
|||
|
||||
# TODO: fix numbering after merging develop into master
|
||||
|
||||
E986 = ("Could not create any training batches: check your input. "
|
||||
"Perhaps discard_oversize should be set to False ?")
|
||||
E987 = ("The text of an example training instance is either a Doc or "
|
||||
"a string, but found {type} instead.")
|
||||
E988 = ("Could not parse any training examples. Ensure the data is "
|
||||
|
|
Loading…
Reference in New Issue