# Mirror of https://github.com/explosion/spaCy.git (Python source; 234 lines, 9.2 KiB).
import functools
|
|
import logging
|
|
import operator
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
import numpy
|
|
import wasabi.tables
|
|
|
|
from .. import util
|
|
from ..errors import Errors
|
|
from ..pipeline import MultiLabel_TextCategorizer, TextCategorizer
|
|
from ..training import Corpus
|
|
from ._util import Arg, Opt, app, import_code, setup_gpu
|
|
|
|
_DEFAULTS = {
|
|
"n_trials": 11,
|
|
"use_gpu": -1,
|
|
"gold_preproc": False,
|
|
}
|
|
|
|
|
|
@app.command(
    "find-threshold",
    context_settings={"allow_extra_args": False, "ignore_unknown_options": True},
)
def find_threshold_cli(
    # fmt: off
    model: str = Arg(..., help="Model name or path"),
    data_path: Path = Arg(..., help="Location of binary evaluation data in .spacy format", exists=True),
    pipe_name: str = Arg(..., help="Name of pipe to examine thresholds for"),
    threshold_key: str = Arg(..., help="Key of threshold attribute in component's configuration"),
    scores_key: str = Arg(..., help="Metric to optimize"),
    n_trials: int = Opt(_DEFAULTS["n_trials"], "--n_trials", "-n", help="Number of trials to determine optimal thresholds"),
    code_path: Optional[Path] = Opt(None, "--code", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
    use_gpu: int = Opt(_DEFAULTS["use_gpu"], "--gpu-id", "-g", help="GPU ID or -1 for CPU"),
    gold_preproc: bool = Opt(_DEFAULTS["gold_preproc"], "--gold-preproc", "-G", help="Use gold preprocessing"),
    verbose: bool = Opt(False, "--verbose", "-V", "-VV", help="Display more information for debugging purposes"),
    # fmt: on
):
    """
    Runs prediction trials for a trained model with varying thresholds to maximize
    the specified metric. The search space for the threshold is traversed linearly
    from 0 to 1 in `n_trials` steps. Results are displayed in a table on `stdout`
    (the corresponding API call to `spacy.cli.find_threshold.find_threshold()`
    returns all results).

    This is applicable only for components whose predictions are influenced by
    thresholds - e.g. `textcat_multilabel` and `spancat`, but not `textcat`. Note
    that the full path to the corresponding threshold attribute in the config has to
    be provided.

    DOCS: https://spacy.io/api/cli#find-threshold
    """
    # NOTE: the docstring above is left untouched on purpose; it doubles as the
    # user-facing description of the command.
    if verbose:
        util.logger.setLevel(logging.DEBUG)
    # Import the optional user code before `find_threshold()` loads the model,
    # so that any functions it registers are available by then.
    import_code(code_path)
    # The CLI always reports progress and results, hence `silent=False`.
    find_threshold(
        model=model,
        data_path=data_path,
        pipe_name=pipe_name,
        threshold_key=threshold_key,
        scores_key=scores_key,
        n_trials=n_trials,
        use_gpu=use_gpu,
        gold_preproc=gold_preproc,
        silent=False,
    )
|
|
|
|
|
|
def find_threshold(
    model: str,
    data_path: Path,
    pipe_name: str,
    threshold_key: str,
    scores_key: str,
    *,
    n_trials: int = _DEFAULTS["n_trials"],  # type: ignore
    use_gpu: int = _DEFAULTS["use_gpu"],  # type: ignore
    gold_preproc: bool = _DEFAULTS["gold_preproc"],  # type: ignore
    silent: bool = True,
) -> Tuple[float, float, Dict[float, float]]:
    """
    Runs prediction trials for models with varying thresholds to maximize the specified metric.
    model (Union[str, Path]): Pipeline to evaluate. Can be a package or a path to a data directory.
    data_path (Path): Path to file with DocBin with docs to use for threshold search.
    pipe_name (str): Name of pipe to examine thresholds for.
    threshold_key (str): Key of threshold attribute in component's configuration. Nested keys are separated by
        dots.
    scores_key (str): Name of score metric to optimize.
    n_trials (int): Number of trials to determine optimal thresholds. Has to be at least 1; the thresholds are
        spread evenly over [0, 1].
    use_gpu (int): GPU ID or -1 for CPU.
    gold_preproc (bool): Whether to use gold preprocessing. Gold preprocessing helps the annotations align to the
        tokenization, and may result in sequences of more consistent length. However, it may reduce runtime accuracy due
        to train/test skew.
    silent (bool): Whether to print non-error-related output to stdout. If True, neither the info message, the
        results table nor the summary line are printed; warnings and failures are still reported.
    RETURNS (Tuple[float, float, Dict[float, float]]): Best found threshold, the corresponding score, scores for all
        evaluated thresholds.
    RAISES (ValueError): If `n_trials` is smaller than 1.
    RAISES (AttributeError): If the pipeline has no component `pipe_name` or the component has no `scorer`.
    """

    # Without at least one trial there is nothing to maximize over. Fail early
    # and explicitly instead of with an opaque error after the model and the
    # corpus have already been loaded.
    if n_trials < 1:
        raise ValueError(f"`n_trials` must be at least 1, but got {n_trials}.")

    setup_gpu(use_gpu, silent=silent)
    data_path = util.ensure_path(data_path)
    if not data_path.exists():
        wasabi.msg.fail("Evaluation data not found", data_path, exits=1)
    nlp = util.load_model(model)

    if pipe_name not in nlp.component_names:
        raise AttributeError(
            Errors.E001.format(name=pipe_name, opts=nlp.component_names)
        )
    pipe = nlp.get_pipe(pipe_name)
    if not hasattr(pipe, "scorer"):
        raise AttributeError(Errors.E1045)

    # Exact type comparison (rather than `isinstance`) so that subclasses of
    # `TextCategorizer` do not trigger this warning.
    if type(pipe) is TextCategorizer:
        wasabi.msg.warn(
            "The `textcat` component doesn't use a threshold as it's not applicable to the concept of "
            "exclusive classes. All thresholds will yield the same results."
        )

    if not silent:
        wasabi.msg.info(
            title=f"Optimizing for {scores_key} for component '{pipe_name}' with {n_trials} "
            f"trials."
        )

    # Load evaluation corpus.
    corpus = Corpus(data_path, gold_preproc=gold_preproc)
    dev_dataset = list(corpus(nlp))
    config_keys = threshold_key.split(".")

    def set_nested_item(
        config: Dict[str, Any], keys: List[str], value: float
    ) -> Dict[str, Any]:
        """Set item in nested dictionary. Adapted from https://stackoverflow.com/a/54138200.
        The dictionary is modified in place; all keys but the last have to exist already.
        config (Dict[str, Any]): Configuration dictionary.
        keys (List[str]): Path to value to set.
        value (float): Value to set.
        RETURNS (Dict[str, Any]): Updated dictionary (the same object as `config`).
        """
        # Walk down to the parent of the last key, then assign the value there.
        functools.reduce(operator.getitem, keys[:-1], config)[keys[-1]] = value
        return config

    def filter_config(
        config: Dict[str, Any], keys: List[str], full_key: str
    ) -> Dict[str, Any]:
        """Filters provided config dictionary so that only the specified keys path remains.
        Exits with an error message if a key on the path cannot be found.
        config (Dict[str, Any]): Configuration dictionary.
        keys (List[str]): Path to value to keep.
        full_key (str): Full user-specified key, only used for the error message.
        RETURNS (Dict[str, Any]): Filtered dictionary.
        """
        if keys[0] not in config:
            wasabi.msg.fail(
                title=f"Failed to look up `{full_key}` in config: sub-key {[keys[0]]} not found.",
                text=f"Make sure you specified {[keys[0]]} correctly. The following sub-keys are available instead: "
                f"{list(config.keys())}",
                exits=1,
            )
        # Recurse along the key path, creating a new single-entry dictionary per
        # level; the value of the last key is taken over as is.
        return {
            keys[0]: filter_config(config[keys[0]], keys[1:], full_key)
            if len(keys) > 1
            else config[keys[0]]
        }

    # Evaluate with varying threshold values.
    scores: Dict[float, float] = {}
    config_keys_full = ["components", pipe_name, *config_keys]
    full_key = ".".join(config_keys_full)
    table_col_widths = (10, 10)
    thresholds = numpy.linspace(0, 1, n_trials)
    # The results table is non-error-related output and hence subject to `silent`.
    if not silent:
        print(
            wasabi.tables.row(
                ["Threshold", f"{scores_key}"], widths=table_col_widths
            )
        )
    for threshold in thresholds:
        # Reload pipeline with overrides specifying the new threshold.
        nlp = util.load_model(
            model,
            config=set_nested_item(
                filter_config(nlp.config, config_keys_full, full_key).copy(),
                config_keys_full,
                threshold,
            ),
        )
        # Additionally update the threshold in the component's `cfg`, if it has one.
        if hasattr(pipe, "cfg"):
            setattr(
                nlp.get_pipe(pipe_name),
                "cfg",
                set_nested_item(getattr(pipe, "cfg"), config_keys, threshold),
            )

        eval_scores = nlp.evaluate(dev_dataset)
        if scores_key not in eval_scores:
            wasabi.msg.fail(
                title=f"Failed to look up score `{scores_key}` in evaluation results.",
                text=f"Make sure you specified the correct value for `scores_key`. The following scores are "
                f"available: {list(eval_scores.keys())}",
                exits=1,
            )
        scores[threshold] = eval_scores[scores_key]

        if not isinstance(scores[threshold], (float, int)):
            wasabi.msg.fail(
                f"Returned score for key '{scores_key}' is not numeric. Threshold optimization only works for numeric "
                f"scores.",
                exits=1,
            )
        if not silent:
            print(
                wasabi.tables.row(
                    [round(threshold, 3), round(scores[threshold], 3)],
                    widths=table_col_widths,
                )
            )

    # In case of ties the first (i.e. lowest) threshold with the best score wins.
    best_threshold = max(scores, key=lambda key: scores[key])

    # If all scores are identical, emit warning.
    if len(set(scores.values())) == 1:
        wasabi.msg.warn(
            title="All scores are identical. Verify that all settings are correct.",
            text=""
            if (
                not isinstance(pipe, MultiLabel_TextCategorizer)
                or scores_key in ("cats_macro_f", "cats_micro_f")
            )
            else "Use `cats_macro_f` or `cats_micro_f` when optimizing the threshold for `textcat_multilabel`.",
        )
    elif not silent:
        print(
            f"\nBest threshold: {round(best_threshold, ndigits=4)} with {scores_key} value of {scores[best_threshold]}."
        )

    return best_threshold, scores[best_threshold], scores
|