Fix combined scores and update test

This commit is contained in:
Ines Montani 2020-09-24 10:42:47 +02:00
parent ae51f580c1
commit 4bbe41f017
3 changed files with 12 additions and 9 deletions

View File

@ -251,11 +251,8 @@ class Language:
# We're merging the existing score weights back into the combined
# weights to make sure we're preserving custom settings in the config
# but also reflect updates (e.g. new components added)
prev_score_weights = self._config["training"].get("score_weights", {})
combined_score_weights = combine_score_weights(score_weights)
combined_score_weights.update(prev_score_weights)
# Combine the scores a second time to normalize them
combined_score_weights = combine_score_weights([combined_score_weights])
prev_weights = self._config["training"].get("score_weights", {})
combined_score_weights = combine_score_weights(score_weights, prev_weights)
self._config["training"]["score_weights"] = combined_score_weights
if not srsly.is_json_serializable(self._config):
raise ValueError(Errors.E961.format(config=self._config))

View File

@ -378,14 +378,14 @@ def test_language_factories_scores():
config["training"]["score_weights"]["b3"] = 1.0
nlp = English.from_config(config)
score_weights = nlp.config["training"]["score_weights"]
expected = {"a1": 0.0, "a2": 0.15, "b1": 0.06, "b2": 0.21, "b3": 0.59}
expected = {"a1": 0.0, "a2": 0.5, "b1": 0.03, "b2": 0.12, "b3": 0.34}
assert score_weights == expected
# Test with null values
config = nlp.config.copy()
config["training"]["score_weights"]["a1"] = None
nlp = English.from_config(config)
score_weights = nlp.config["training"]["score_weights"]
expected = {"a1": None, "a2": 0.15, "b1": 0.06, "b2": 0.21, "b3": 0.58} # rounding :(
expected = {"a1": None, "a2": 0.5, "b1": 0.03, "b2": 0.12, "b3": 0.35}
assert score_weights == expected

View File

@ -1202,11 +1202,16 @@ def get_arg_names(func: Callable) -> List[str]:
return list(set([*argspec.args, *argspec.kwonlyargs]))
def combine_score_weights(weights: List[Dict[str, float]]) -> Dict[str, float]:
def combine_score_weights(
weights: List[Dict[str, float]],
overrides: Dict[str, Optional[Union[float, int]]] = SimpleFrozenDict(),
) -> Dict[str, float]:
"""Combine and normalize score weights defined by components, e.g.
{"ents_r": 0.2, "ents_p": 0.3, "ents_f": 0.5} and {"some_other_score": 1.0}.
weights (List[dict]): The weights defined by the components.
overrides (Dict[str, Optional[Union[float, int]]]): Existing scores that
should be preserved.
RETURNS (Dict[str, float]): The combined and normalized weights.
"""
# We first need to extract all None/null values for score weights that
@ -1216,6 +1221,7 @@ def combine_score_weights(weights: List[Dict[str, float]]) -> Dict[str, float]:
for w_dict in weights:
filtered_weights = {}
for key, value in w_dict.items():
value = overrides.get(key, value)
if value is None:
result[key] = None
else:
@ -1227,7 +1233,7 @@ def combine_score_weights(weights: List[Dict[str, float]]) -> Dict[str, float]:
# components.
total = sum(w_dict.values())
for key, value in w_dict.items():
weight = round(value / total / len(weights), 2)
weight = round(value / total / len(all_weights), 2)
result[key] = result.get(key, 0.0) + weight
return result