mirror of https://github.com/explosion/spaCy.git
prevent infinite loop, custom warning
parent 6504b7f161
commit 03c58b488c
@@ -13,6 +13,7 @@ import random
 from ..gold import GoldCorpus
 from .. import util
 from ..errors import Errors
+from ..ml import models  # don't remove - required to load the built-in architectures
 
 registry = util.registry
 
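The import of ..ml.models above is needed only for its side effect: importing the module executes the registration decorators for the built-in architectures, so they become resolvable by name. A minimal standalone sketch of that pattern (the registry and names here are illustrative, not spaCy's actual machinery):

ARCHITECTURES = {}

def register(name):
    # Records the decorated function in a module-level registry as a side
    # effect of executing the module at import time.
    def add(func):
        ARCHITECTURES[name] = func
        return func
    return add

@register("example.Tok2Vec.v1")  # hypothetical name, for illustration only
def build_tok2vec():
    return "model stub"

# Resolving a config reference then becomes a dictionary lookup, which fails
# if the module defining the architecture was never imported.
model = ARCHITECTURES["example.Tok2Vec.v1"]()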
@@ -75,7 +76,7 @@ maxout_pieces = 3
 subword_features = true
 """
 
+# TODO: REMOVE ?
 class PipelineComponent(BaseModel):
     factory: str
     model: Model
@@ -83,7 +84,7 @@ class PipelineComponent(BaseModel):
     class Config:
         arbitrary_types_allowed = True
 
+# TODO: REMOVE ?
 class ConfigSchema(BaseModel):
     optimizer: Optional["Optimizer"]
 
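For context on the two schemas above: pydantic only validates types it has validators for, and thinc's Model is an arbitrary external class, so arbitrary_types_allowed = True is what lets the model field exist at all (validation falls back to an isinstance check). A self-contained sketch, with a hypothetical stand-in for Model:

from pydantic import BaseModel

class ExternalModel:  # stand-in for thinc's Model, unknown to pydantic
    pass

class Component(BaseModel):
    factory: str
    model: ExternalModel  # defining this field raises without the flag below

    class Config:
        arbitrary_types_allowed = True  # validate with isinstance instead

Component(factory="tok2vec", model=ExternalModel())  # validates fine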
@@ -123,7 +124,7 @@ class ConfigSchema(BaseModel):
     use_gpu=("Use GPU", "option", "g", int),
     # fmt: on
 )
-def train_from_config_cli(
+def train_cli(
     train_path,
     dev_path,
     config_path,
@@ -132,7 +133,7 @@ def train_from_config_cli(
     raw_text=None,
     debug=False,
     verbose=False,
-    use_gpu=-1
+    use_gpu=-1,
 ):
     """
     Train or update a spaCy model. Requires data to be formatted in spaCy's
@@ -156,7 +157,7 @@ def train_from_config_cli(
     else:
         msg.info("Using CPU")
 
-    train_from_config(
+    train(
         config_path,
         {"train": train_path, "dev": dev_path},
         output_path=output_path,
@@ -165,10 +166,11 @@ def train_from_config_cli(
     )
 
 
-def train_from_config(
+def train(
    config_path, data_paths, raw_text=None, meta_path=None, output_path=None,
 ):
     msg.info(f"Loading config from: {config_path}")
+    # Read the config first without creating objects, to get to the original nlp_config
     config = util.load_config(config_path, create_objects=False)
     util.fix_random_seed(config["training"]["seed"])
     if config["training"]["use_pytorch_for_gpu_memory"]:
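The comment added above documents a two-pass load: the config is first parsed as plain data (create_objects=False) so values like the random seed and the raw nlp_config section can be read before any registered functions run, and only afterwards materialized into objects. A rough, self-contained sketch of the idea; the toy resolver below is hypothetical and much simpler than spaCy's util.load_config:

RESOLVERS = {"@optimizers": {"Adam.v1": lambda cfg: ("adam-object", cfg)}}

def load_config_sketch(raw, create_objects=False):
    # Pass 1: return the parsed data untouched, safe to inspect.
    if not create_objects:
        return raw
    # Pass 2: replace registered references with the objects they describe.
    resolved = {}
    for key, value in raw.items():
        if isinstance(value, dict) and "@optimizers" in value:
            resolved[key] = RESOLVERS["@optimizers"][value["@optimizers"]](value)
        elif isinstance(value, dict):
            resolved[key] = load_config_sketch(value, create_objects=True)
        else:
            resolved[key] = value
    return resolved

raw = {"training": {"seed": 0, "optimizer": {"@optimizers": "Adam.v1"}}}
seed = load_config_sketch(raw)["training"]["seed"]      # pass 1: plain data
config = load_config_sketch(raw, create_objects=True)   # pass 2: objects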
@@ -177,8 +179,8 @@ def train_from_config(
     config = util.load_config(config_path, create_objects=True)
     msg.info("Creating nlp from config")
     nlp = util.load_model_from_config(nlp_config)
-    optimizer = config["optimizer"]
     training = config["training"]
+    optimizer = training["optimizer"]
     limit = training["limit"]
     msg.info("Loading training corpus")
     corpus = GoldCorpus(data_paths["train"], data_paths["dev"], limit=limit)
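The other change in this hunk moves the optimizer lookup from the config root into the training block. Illustrative shape of the resolved config after the change (values are placeholders):

config = {
    "nlp": {"lang": "en"},
    "training": {
        "seed": 0,
        "limit": 0,
        "optimizer": "<resolved Optimizer object>",
    },
}
optimizer = config["training"]["optimizer"]  # new lookup path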
@@ -246,13 +248,19 @@ def create_train_batches(nlp, corpus, cfg):
         if len(train_examples) == 0:
             raise ValueError(Errors.E988)
         random.shuffle(train_examples)
-        batches = util.minibatch_by_words(train_examples, size=cfg["batch_size"])
+        batches = util.minibatch_by_words(train_examples, size=cfg["batch_size"], discard_oversize=cfg["discard_oversize"])
+        # make sure the minibatch_by_words result is not empty, or we'll have an infinite training loop
+        try:
+            first = next(batches)
+            yield first
+        except StopIteration:
+            raise ValueError(Errors.E986)
         for batch in batches:
             yield batch
         epochs_todo -= 1
         # We intentionally compare exactly to 0 here, so that max_epochs < 1
         # will not break.
         if epochs_todo == 0:
             break
 
 
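This hunk is the fix named in the commit message. The batching above runs inside create_train_batches' endless epoch loop, so if minibatch_by_words yields nothing (for example, because discard_oversize dropped every example), the old code would shuffle and re-batch forever without ever yielding a batch. Eagerly pulling the first batch turns that silent empty generator into an immediate E986. The same guard in isolation, as a generic sketch (function name and message are illustrative):

def require_non_empty(gen, error_msg):
    # Pull the first item eagerly so an empty generator fails fast instead of
    # letting the surrounding loop spin forever producing nothing.
    try:
        first = next(gen)
    except StopIteration:
        raise ValueError(error_msg) from None
    yield first
    yield from gen

try:
    next(require_non_empty(iter([]), "no batches produced"))
except ValueError as err:
    print(err)  # "no batches produced", instead of an infinite loop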
@@ -366,6 +374,7 @@ def train_while_improving(
         # Stop if we've exhausted our max steps (if specified)
         if max_steps and (step * accumulate_gradient) >= max_steps:
             break
+        step += 1
 
 
 def subdivide_batch(batch, accumulate_gradient):
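The added step += 1 makes the step counter explicit, and the check above it compares step * accumulate_gradient against max_steps, so one loop iteration accounts for accumulate_gradient accumulated sub-batches. A worked example of the stopping arithmetic (values are illustrative):

accumulate_gradient = 3
max_steps = 300
step = 0
while True:
    if max_steps and (step * accumulate_gradient) >= max_steps:
        break
    step += 1
assert step == 100  # stops after 100 iterations, since 100 * 3 >= 300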
@@ -554,6 +554,8 @@ class Errors(object):
 
     # TODO: fix numbering after merging develop into master
 
+    E986 = ("Could not create any training batches: check your input. "
+            "Perhaps discard_oversize should be set to False ?")
     E987 = ("The text of an example training instance is either a Doc or "
             "a string, but found {type} instead.")
     E988 = ("Could not parse any training examples. Ensure the data is "
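E986 fires exactly when the guard in create_train_batches finds an empty batch iterator. The condition is easy to reproduce with a simplified stand-in for util.minibatch_by_words (this sketch is not spaCy's implementation):

def minibatch_by_words_sketch(examples, size, discard_oversize=False):
    # examples: iterable of (example, word_count) pairs
    batch, n_words = [], 0
    for eg, eg_words in examples:
        if eg_words > size and discard_oversize:
            continue  # oversize example is dropped entirely
        if batch and n_words + eg_words > size:
            yield batch
            batch, n_words = [], 0
        batch.append(eg)
        n_words += eg_words
    if batch:
        yield batch

# Every example exceeds the word budget, so nothing is ever yielded; this is
# the case E986 now reports. Setting discard_oversize=False avoids it.
empty = minibatch_by_words_sketch([("a", 50), ("b", 60)], size=10, discard_oversize=True)
assert list(empty) == []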