mirror of https://github.com/explosion/spaCy.git
Minor bugfixes for train CLI (#5186)
* Omit per_type scores from model-best calculations The addition of per_type scores to the included metrics (#4911) causes errors when they're compared while determining the best model, so omit them for this `max()` comparison. * Add default speed data for interrupted train CLI Add better speed meta defaults so that an interrupted iteration still produces a best model. Co-authored-by: Ines Montani <ines@ines.io>
This commit is contained in:
parent
a04f802099
commit
8d3563f1c4
|
@ -566,6 +566,9 @@ def train(
|
||||||
final_meta.setdefault("speed", {})
|
final_meta.setdefault("speed", {})
|
||||||
final_meta["speed"].setdefault("cpu", None)
|
final_meta["speed"].setdefault("cpu", None)
|
||||||
final_meta["speed"].setdefault("gpu", None)
|
final_meta["speed"].setdefault("gpu", None)
|
||||||
|
meta.setdefault("speed", {})
|
||||||
|
meta["speed"].setdefault("cpu", None)
|
||||||
|
meta["speed"].setdefault("gpu", None)
|
||||||
# combine cpu and gpu speeds with the base model speeds
|
# combine cpu and gpu speeds with the base model speeds
|
||||||
if final_meta["speed"]["cpu"] and meta["speed"]["cpu"]:
|
if final_meta["speed"]["cpu"] and meta["speed"]["cpu"]:
|
||||||
speed = _get_total_speed(
|
speed = _get_total_speed(
|
||||||
|
@ -673,6 +676,8 @@ def _find_best(experiment_dir, component):
|
||||||
if epoch_model.is_dir() and epoch_model.parts[-1] != "model-final":
|
if epoch_model.is_dir() and epoch_model.parts[-1] != "model-final":
|
||||||
accs = srsly.read_json(epoch_model / "accuracy.json")
|
accs = srsly.read_json(epoch_model / "accuracy.json")
|
||||||
scores = [accs.get(metric, 0.0) for metric in _get_metrics(component)]
|
scores = [accs.get(metric, 0.0) for metric in _get_metrics(component)]
|
||||||
|
# remove per_type dicts from score list for max() comparison
|
||||||
|
scores = [score for score in scores if isinstance(score, float)]
|
||||||
accuracies.append((scores, epoch_model))
|
accuracies.append((scores, epoch_model))
|
||||||
if accuracies:
|
if accuracies:
|
||||||
return max(accuracies)[1]
|
return max(accuracies)[1]
|
||||||
|
|
Loading…
Reference in New Issue