Allow score_weights to list extra scores

This commit is contained in:
Matthew Honnibal 2020-08-23 18:31:30 +02:00
parent c356e62908
commit fe1cf7e124
1 changed files with 8 additions and 7 deletions

View File

@ -162,12 +162,13 @@ def train(
progress = tqdm.tqdm(total=T_cfg["eval_frequency"], leave=False) progress = tqdm.tqdm(total=T_cfg["eval_frequency"], leave=False)
except Exception as e: except Exception as e:
if output_path is not None: if output_path is not None:
# We don't want to swallow the traceback if we don't have a
# specific error.
msg.warn( msg.warn(
f"Aborting and saving the final best model. " f"Aborting and saving the final best model. "
f"Encountered exception: {str(e)}", f"Encountered exception: {str(e)}"
exits=1,
) )
else: nlp.to_disk(output_path / "model-final")
raise e raise e
finally: finally:
if output_path is not None: if output_path is not None:
@ -207,7 +208,7 @@ def create_evaluation_callback(
scores = nlp.evaluate(dev_examples) scores = nlp.evaluate(dev_examples)
# Calculate a weighted sum based on score_weights for the main score # Calculate a weighted sum based on score_weights for the main score
try: try:
weighted_score = sum(scores[s] * weights.get(s, 0.0) for s in weights) weighted_score = sum(scores.get(s, 0.0) * weights.get(s, 0.0) for s in weights)
except KeyError as e: except KeyError as e:
keys = list(scores.keys()) keys = list(scores.keys())
err = Errors.E983.format(dict="score_weights", key=str(e), keys=keys) err = Errors.E983.format(dict="score_weights", key=str(e), keys=keys)
@ -377,7 +378,7 @@ def setup_printer(
try: try:
scores = [ scores = [
"{0:.2f}".format(float(info["other_scores"][col])) for col in score_cols "{0:.2f}".format(float(info["other_scores"].get(col, 0.0))) for col in score_cols
] ]
except KeyError as e: except KeyError as e:
raise KeyError( raise KeyError(
@ -403,7 +404,7 @@ def update_meta(
) -> None: ) -> None:
nlp.meta["performance"] = {} nlp.meta["performance"] = {}
for metric in training["score_weights"]: for metric in training["score_weights"]:
nlp.meta["performance"][metric] = info["other_scores"][metric] nlp.meta["performance"][metric] = info["other_scores"].get(metric, 0.0)
for pipe_name in nlp.pipe_names: for pipe_name in nlp.pipe_names:
nlp.meta["performance"][f"{pipe_name}_loss"] = info["losses"][pipe_name] nlp.meta["performance"][f"{pipe_name}_loss"] = info["losses"][pipe_name]