diff --git a/spacy/cli/train.py b/spacy/cli/train.py index 968a009f6..59b0f2225 100644 --- a/spacy/cli/train.py +++ b/spacy/cli/train.py @@ -554,7 +554,30 @@ def train( with nlp.use_params(optimizer.averages): final_model_path = output_path / "model-final" nlp.to_disk(final_model_path) - final_meta = srsly.read_json(output_path / "model-final" / "meta.json") + meta_loc = output_path / "model-final" / "meta.json" + final_meta = srsly.read_json(meta_loc) + final_meta.setdefault("accuracy", {}) + final_meta["accuracy"].update(meta.get("accuracy", {})) + final_meta.setdefault("speed", {}) + final_meta["speed"].setdefault("cpu", None) + final_meta["speed"].setdefault("gpu", None) + # combine cpu and gpu speeds with the base model speeds + if final_meta["speed"]["cpu"] and meta["speed"]["cpu"]: + speed = _get_total_speed([final_meta["speed"]["cpu"], meta["speed"]["cpu"]]) + final_meta["speed"]["cpu"] = speed + if final_meta["speed"]["gpu"] and meta["speed"]["gpu"]: + speed = _get_total_speed([final_meta["speed"]["gpu"], meta["speed"]["gpu"]]) + final_meta["speed"]["gpu"] = speed + # if there were no speeds to update, overwrite with meta + if final_meta["speed"]["cpu"] is None and final_meta["speed"]["gpu"] is None: + final_meta["speed"].update(meta["speed"]) + # note: beam speeds are not combined with the base model + if has_beam_widths: + final_meta.setdefault("beam_accuracy", {}) + final_meta["beam_accuracy"].update(meta.get("beam_accuracy", {})) + final_meta.setdefault("beam_speed", {}) + final_meta["beam_speed"].update(meta.get("beam_speed", {})) + srsly.write_json(meta_loc, final_meta) msg.good("Saved model to output directory", final_model_path) with msg.loading("Creating best model..."): best_model_path = _collate_best_model(final_meta, output_path, best_pipes) @@ -649,11 +672,11 @@ def _get_metrics(component): if component == "parser": return ("las", "uas", "las_per_type", "token_acc") elif component == "tagger": - return ("tags_acc",) + return ("tags_acc", "token_acc") elif component == "ner": - return ("ents_f", "ents_p", "ents_r", "ents_per_type") + return ("ents_f", "ents_p", "ents_r", "ents_per_type", "token_acc") elif component == "textcat": - return ("textcat_score",) + return ("textcat_score", "token_acc") return ("token_acc",) @@ -709,3 +732,12 @@ def _get_progress( if beam_width is not None: result.insert(1, beam_width) return result + + +def _get_total_speed(speeds): + seconds_per_word = 0.0 + for words_per_second in speeds: + if words_per_second is None: + return None + seconds_per_word += 1.0 / words_per_second + return 1.0 / seconds_per_word