mirror of https://github.com/explosion/spaCy.git
Allow score_weights to list extra scores
This commit is contained in:
parent
c356e62908
commit
fe1cf7e124
|
@ -162,12 +162,13 @@ def train(
|
|||
progress = tqdm.tqdm(total=T_cfg["eval_frequency"], leave=False)
|
||||
except Exception as e:
|
||||
if output_path is not None:
|
||||
# We don't want to swallow the traceback if we don't have a
|
||||
# specific error.
|
||||
msg.warn(
|
||||
f"Aborting and saving the final best model. "
|
||||
f"Encountered exception: {str(e)}",
|
||||
exits=1,
|
||||
f"Encountered exception: {str(e)}"
|
||||
)
|
||||
else:
|
||||
nlp.to_disk(output_path / "model-final")
|
||||
raise e
|
||||
finally:
|
||||
if output_path is not None:
|
||||
|
@ -207,7 +208,7 @@ def create_evaluation_callback(
|
|||
scores = nlp.evaluate(dev_examples)
|
||||
# Calculate a weighted sum based on score_weights for the main score
|
||||
try:
|
||||
weighted_score = sum(scores[s] * weights.get(s, 0.0) for s in weights)
|
||||
weighted_score = sum(scores.get(s, 0.0) * weights.get(s, 0.0) for s in weights)
|
||||
except KeyError as e:
|
||||
keys = list(scores.keys())
|
||||
err = Errors.E983.format(dict="score_weights", key=str(e), keys=keys)
|
||||
|
@ -377,7 +378,7 @@ def setup_printer(
|
|||
|
||||
try:
|
||||
scores = [
|
||||
"{0:.2f}".format(float(info["other_scores"][col])) for col in score_cols
|
||||
"{0:.2f}".format(float(info["other_scores"].get(col, 0.0))) for col in score_cols
|
||||
]
|
||||
except KeyError as e:
|
||||
raise KeyError(
|
||||
|
@ -403,7 +404,7 @@ def update_meta(
|
|||
) -> None:
|
||||
nlp.meta["performance"] = {}
|
||||
for metric in training["score_weights"]:
|
||||
nlp.meta["performance"][metric] = info["other_scores"][metric]
|
||||
nlp.meta["performance"][metric] = info["other_scores"].get(metric, 0.0)
|
||||
for pipe_name in nlp.pipe_names:
|
||||
nlp.meta["performance"][f"{pipe_name}_loss"] = info["losses"][pipe_name]
|
||||
|
||||
|
|
Loading…
Reference in New Issue