Minor bugfixes for train CLI (#5186)

* Omit per_type scores from model-best calculations

The addition of per_type scores to the included metrics (#4911) causes
errors when they're compared while determining the best model, so omit
them for this `max()` comparison.

* Add default speed data for interrupted train CLI

Add better speed meta defaults so that an interrupted iteration still
produces a best model.

Co-authored-by: Ines Montani <ines@ines.io>
This commit is contained in:
adrianeboyd 2020-03-26 10:46:50 +01:00 committed by GitHub
parent a04f802099
commit 8d3563f1c4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 5 additions and 0 deletions

View File

@ -566,6 +566,9 @@ def train(
final_meta.setdefault("speed", {}) final_meta.setdefault("speed", {})
final_meta["speed"].setdefault("cpu", None) final_meta["speed"].setdefault("cpu", None)
final_meta["speed"].setdefault("gpu", None) final_meta["speed"].setdefault("gpu", None)
meta.setdefault("speed", {})
meta["speed"].setdefault("cpu", None)
meta["speed"].setdefault("gpu", None)
# combine cpu and gpu speeds with the base model speeds # combine cpu and gpu speeds with the base model speeds
if final_meta["speed"]["cpu"] and meta["speed"]["cpu"]: if final_meta["speed"]["cpu"] and meta["speed"]["cpu"]:
speed = _get_total_speed( speed = _get_total_speed(
@ -673,6 +676,8 @@ def _find_best(experiment_dir, component):
if epoch_model.is_dir() and epoch_model.parts[-1] != "model-final": if epoch_model.is_dir() and epoch_model.parts[-1] != "model-final":
accs = srsly.read_json(epoch_model / "accuracy.json") accs = srsly.read_json(epoch_model / "accuracy.json")
scores = [accs.get(metric, 0.0) for metric in _get_metrics(component)] scores = [accs.get(metric, 0.0) for metric in _get_metrics(component)]
# remove per_type dicts from score list for max() comparison
scores = [score for score in scores if isinstance(score, float)]
accuracies.append((scores, epoch_model)) accuracies.append((scores, epoch_model))
if accuracies: if accuracies:
return max(accuracies)[1] return max(accuracies)[1]