diff --git a/spacy/cli/train.py b/spacy/cli/train.py index 5cf0f5f6f..ebac339ae 100644 --- a/spacy/cli/train.py +++ b/spacy/cli/train.py @@ -226,7 +226,7 @@ def train( msg.row(["-" * width for width in row_settings["widths"]], **row_settings) try: iter_since_best = 0 - best_score = 0. + best_score = 0.0 for i in range(n_iter): train_docs = corpus.train_docs( nlp, noise_level=noise_level, gold_preproc=gold_preproc, max_length=0 @@ -335,8 +335,8 @@ def train( gpu_wps=gpu_wps, ) msg.row(progress, **row_settings) - # early stopping if early_stopping_iter is not None: + # Early stopping current_score = _score_for_model(meta) if current_score < best_score: iter_since_best += 1 @@ -344,8 +344,14 @@ def train( iter_since_best = 0 best_score = current_score if iter_since_best >= early_stopping_iter: - msg.text("Early stopping, best iteration is: {}".format(i-iter_since_best)) - msg.text("Best score = {}; Final iteration score = {}".format(best_score, current_score)) + msg.text( + "Early stopping, best iteration " + "is: {}".format(i - iter_since_best) + ) + msg.text( + "Best score = {}; Final iteration " + "score = {}".format(best_score, current_score) + ) break finally: with nlp.use_params(optimizer.averages): @@ -356,19 +362,21 @@ def train( best_model_path = _collate_best_model(meta, output_path, nlp.pipe_names) msg.good("Created best model", best_model_path) + def _score_for_model(meta): """ Returns mean score between tasks in pipeline that can be used for early stopping. """ mean_acc = list() - pipes = meta['pipeline'] - acc = meta['accuracy'] - if 'tagger' in pipes: - mean_acc.append(acc['tags_acc']) - if 'parser' in pipes: - mean_acc.append((acc['uas']+acc['las']) / 2) - if 'ner' in pipes: - mean_acc.append((acc['ents_p']+acc['ents_r']+acc['ents_f']) / 3) + pipes = meta["pipeline"] + acc = meta["accuracy"] + if "tagger" in pipes: + mean_acc.append(acc["tags_acc"]) + if "parser" in pipes: + mean_acc.append((acc["uas"] + acc["las"]) / 2) + if "ner" in pipes: + mean_acc.append((acc["ents_p"] + acc["ents_r"] + acc["ents_f"]) / 3) return sum(mean_acc) / len(mean_acc) + @contextlib.contextmanager def _create_progress_bar(total): if int(os.environ.get("LOG_FRIENDLY", 0)):