mirror of https://github.com/explosion/spaCy.git
Add a `spacy benchmark speed` subcommand (#11902)
* Add a `spacy evaluate speed` subcommand This subcommand reports the mean batch performance of a model on a data set with a 95% confidence interval. For reliability, it first performs some warmup rounds. Then it will measure performance on batches with randomly shuffled documents. To avoid having too many spaCy commands, `speed` is a subcommand of `evaluate` and accuracy evaluation is moved to its own `evaluate accuracy` subcommand. * Fix import cycle * Restore `spacy evaluate`, make `spacy benchmark speed` an alias * Add documentation for `spacy benchmark` * CREATES -> PRINTS * WPS -> words/s * Disable formatting of benchmark speed arguments * Fail with an error message when trying to speed bench empty corpus * Make it clearer that `benchmark accuracy` is a replacement for `evaluate` * Fix docstring webpage reference * tests: check `evaluate` output against `benchmark accuracy`
This commit is contained in:
parent
8e558095a1
commit
319eb508b5
|
@ -4,6 +4,7 @@ from ._util import app, setup_cli # noqa: F401
|
||||||
|
|
||||||
# These are the actual functions, NOT the wrapped CLI commands. The CLI commands
|
# These are the actual functions, NOT the wrapped CLI commands. The CLI commands
|
||||||
# are registered automatically and won't have to be imported here.
|
# are registered automatically and won't have to be imported here.
|
||||||
|
from .benchmark_speed import benchmark_speed_cli # noqa: F401
|
||||||
from .download import download # noqa: F401
|
from .download import download # noqa: F401
|
||||||
from .info import info # noqa: F401
|
from .info import info # noqa: F401
|
||||||
from .package import package # noqa: F401
|
from .package import package # noqa: F401
|
||||||
|
|
|
@ -46,6 +46,7 @@ DEBUG_HELP = """Suite of helpful commands for debugging and profiling. Includes
|
||||||
commands to check and validate your config files, training and evaluation data,
|
commands to check and validate your config files, training and evaluation data,
|
||||||
and custom model implementations.
|
and custom model implementations.
|
||||||
"""
|
"""
|
||||||
|
BENCHMARK_HELP = """Commands for benchmarking pipelines."""
|
||||||
INIT_HELP = """Commands for initializing configs and pipeline packages."""
|
INIT_HELP = """Commands for initializing configs and pipeline packages."""
|
||||||
|
|
||||||
# Wrappers for Typer's annotations. Initially created to set defaults and to
|
# Wrappers for Typer's annotations. Initially created to set defaults and to
|
||||||
|
@ -54,12 +55,14 @@ Arg = typer.Argument
|
||||||
Opt = typer.Option
|
Opt = typer.Option
|
||||||
|
|
||||||
app = typer.Typer(name=NAME, help=HELP)
|
app = typer.Typer(name=NAME, help=HELP)
|
||||||
|
benchmark_cli = typer.Typer(name="benchmark", help=BENCHMARK_HELP, no_args_is_help=True)
|
||||||
project_cli = typer.Typer(name="project", help=PROJECT_HELP, no_args_is_help=True)
|
project_cli = typer.Typer(name="project", help=PROJECT_HELP, no_args_is_help=True)
|
||||||
debug_cli = typer.Typer(name="debug", help=DEBUG_HELP, no_args_is_help=True)
|
debug_cli = typer.Typer(name="debug", help=DEBUG_HELP, no_args_is_help=True)
|
||||||
init_cli = typer.Typer(name="init", help=INIT_HELP, no_args_is_help=True)
|
init_cli = typer.Typer(name="init", help=INIT_HELP, no_args_is_help=True)
|
||||||
|
|
||||||
app.add_typer(project_cli)
|
app.add_typer(project_cli)
|
||||||
app.add_typer(debug_cli)
|
app.add_typer(debug_cli)
|
||||||
|
app.add_typer(benchmark_cli)
|
||||||
app.add_typer(init_cli)
|
app.add_typer(init_cli)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,174 @@
|
||||||
|
from typing import Iterable, List, Optional
|
||||||
|
import random
|
||||||
|
from itertools import islice
|
||||||
|
import numpy
|
||||||
|
from pathlib import Path
|
||||||
|
import time
|
||||||
|
from tqdm import tqdm
|
||||||
|
import typer
|
||||||
|
from wasabi import msg
|
||||||
|
|
||||||
|
from .. import util
|
||||||
|
from ..language import Language
|
||||||
|
from ..tokens import Doc
|
||||||
|
from ..training import Corpus
|
||||||
|
from ._util import Arg, Opt, benchmark_cli, setup_gpu
|
||||||
|
|
||||||
|
|
||||||
|
@benchmark_cli.command(
|
||||||
|
"speed",
|
||||||
|
context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
|
||||||
|
)
|
||||||
|
def benchmark_speed_cli(
|
||||||
|
# fmt: off
|
||||||
|
ctx: typer.Context,
|
||||||
|
model: str = Arg(..., help="Model name or path"),
|
||||||
|
data_path: Path = Arg(..., help="Location of binary evaluation data in .spacy format", exists=True),
|
||||||
|
batch_size: Optional[int] = Opt(None, "--batch-size", "-b", min=1, help="Override the pipeline batch size"),
|
||||||
|
no_shuffle: bool = Opt(False, "--no-shuffle", help="Do not shuffle benchmark data"),
|
||||||
|
use_gpu: int = Opt(-1, "--gpu-id", "-g", help="GPU ID or -1 for CPU"),
|
||||||
|
n_batches: int = Opt(50, "--batches", help="Minimum number of batches to benchmark", min=30,),
|
||||||
|
warmup_epochs: int = Opt(3, "--warmup", "-w", min=0, help="Number of iterations over the data for warmup"),
|
||||||
|
# fmt: on
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Benchmark a pipeline. Expects a loadable spaCy pipeline and benchmark
|
||||||
|
data in the binary .spacy format.
|
||||||
|
"""
|
||||||
|
setup_gpu(use_gpu=use_gpu, silent=False)
|
||||||
|
|
||||||
|
nlp = util.load_model(model)
|
||||||
|
batch_size = batch_size if batch_size is not None else nlp.batch_size
|
||||||
|
corpus = Corpus(data_path)
|
||||||
|
docs = [eg.predicted for eg in corpus(nlp)]
|
||||||
|
|
||||||
|
if len(docs) == 0:
|
||||||
|
msg.fail("Cannot benchmark speed using an empty corpus.", exits=1)
|
||||||
|
|
||||||
|
print(f"Warming up for {warmup_epochs} epochs...")
|
||||||
|
warmup(nlp, docs, warmup_epochs, batch_size)
|
||||||
|
|
||||||
|
print()
|
||||||
|
print(f"Benchmarking {n_batches} batches...")
|
||||||
|
wps = benchmark(nlp, docs, n_batches, batch_size, not no_shuffle)
|
||||||
|
|
||||||
|
print()
|
||||||
|
print_outliers(wps)
|
||||||
|
print_mean_with_ci(wps)
|
||||||
|
|
||||||
|
|
||||||
|
# Lowercased, behaves as a context manager function.
|
||||||
|
class time_context:
|
||||||
|
"""Register the running time of a context."""
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self.start = time.perf_counter()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, type, value, traceback):
|
||||||
|
self.elapsed = time.perf_counter() - self.start
|
||||||
|
|
||||||
|
|
||||||
|
class Quartiles:
|
||||||
|
"""Calculate the q1, q2, q3 quartiles and the inter-quartile range (iqr)
|
||||||
|
of a sample."""
|
||||||
|
|
||||||
|
q1: float
|
||||||
|
q2: float
|
||||||
|
q3: float
|
||||||
|
iqr: float
|
||||||
|
|
||||||
|
def __init__(self, sample: numpy.ndarray) -> None:
|
||||||
|
self.q1 = numpy.quantile(sample, 0.25)
|
||||||
|
self.q2 = numpy.quantile(sample, 0.5)
|
||||||
|
self.q3 = numpy.quantile(sample, 0.75)
|
||||||
|
self.iqr = self.q3 - self.q1
|
||||||
|
|
||||||
|
|
||||||
|
def annotate(
|
||||||
|
nlp: Language, docs: List[Doc], batch_size: Optional[int]
|
||||||
|
) -> numpy.ndarray:
|
||||||
|
docs = nlp.pipe(tqdm(docs, unit="doc"), batch_size=batch_size)
|
||||||
|
wps = []
|
||||||
|
while True:
|
||||||
|
with time_context() as elapsed:
|
||||||
|
batch_docs = list(
|
||||||
|
islice(docs, batch_size if batch_size else nlp.batch_size)
|
||||||
|
)
|
||||||
|
if len(batch_docs) == 0:
|
||||||
|
break
|
||||||
|
n_tokens = count_tokens(batch_docs)
|
||||||
|
wps.append(n_tokens / elapsed.elapsed)
|
||||||
|
|
||||||
|
return numpy.array(wps)
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark(
|
||||||
|
nlp: Language,
|
||||||
|
docs: List[Doc],
|
||||||
|
n_batches: int,
|
||||||
|
batch_size: int,
|
||||||
|
shuffle: bool,
|
||||||
|
) -> numpy.ndarray:
|
||||||
|
if shuffle:
|
||||||
|
bench_docs = [
|
||||||
|
nlp.make_doc(random.choice(docs).text)
|
||||||
|
for _ in range(n_batches * batch_size)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
bench_docs = [
|
||||||
|
nlp.make_doc(docs[i % len(docs)].text)
|
||||||
|
for i in range(n_batches * batch_size)
|
||||||
|
]
|
||||||
|
|
||||||
|
return annotate(nlp, bench_docs, batch_size)
|
||||||
|
|
||||||
|
|
||||||
|
def bootstrap(x, statistic=numpy.mean, iterations=10000) -> numpy.ndarray:
|
||||||
|
"""Apply a statistic to repeated random samples of an array."""
|
||||||
|
return numpy.fromiter(
|
||||||
|
(
|
||||||
|
statistic(numpy.random.choice(x, len(x), replace=True))
|
||||||
|
for _ in range(iterations)
|
||||||
|
),
|
||||||
|
numpy.float64,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def count_tokens(docs: Iterable[Doc]) -> int:
|
||||||
|
return sum(len(doc) for doc in docs)
|
||||||
|
|
||||||
|
|
||||||
|
def print_mean_with_ci(sample: numpy.ndarray):
|
||||||
|
mean = numpy.mean(sample)
|
||||||
|
bootstrap_means = bootstrap(sample)
|
||||||
|
bootstrap_means.sort()
|
||||||
|
|
||||||
|
# 95% confidence interval
|
||||||
|
low = bootstrap_means[int(len(bootstrap_means) * 0.025)]
|
||||||
|
high = bootstrap_means[int(len(bootstrap_means) * 0.975)]
|
||||||
|
|
||||||
|
print(f"Mean: {mean:.1f} words/s (95% CI: {low-mean:.1f} +{high-mean:.1f})")
|
||||||
|
|
||||||
|
|
||||||
|
def print_outliers(sample: numpy.ndarray):
|
||||||
|
quartiles = Quartiles(sample)
|
||||||
|
|
||||||
|
n_outliers = numpy.sum(
|
||||||
|
(sample < (quartiles.q1 - 1.5 * quartiles.iqr))
|
||||||
|
| (sample > (quartiles.q3 + 1.5 * quartiles.iqr))
|
||||||
|
)
|
||||||
|
n_extreme_outliers = numpy.sum(
|
||||||
|
(sample < (quartiles.q1 - 3.0 * quartiles.iqr))
|
||||||
|
| (sample > (quartiles.q3 + 3.0 * quartiles.iqr))
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"Outliers: {(100 * n_outliers) / len(sample):.1f}%, extreme outliers: {(100 * n_extreme_outliers) / len(sample)}%"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def warmup(
|
||||||
|
nlp: Language, docs: List[Doc], warmup_epochs: int, batch_size: Optional[int]
|
||||||
|
) -> numpy.ndarray:
|
||||||
|
docs = warmup_epochs * docs
|
||||||
|
return annotate(nlp, docs, batch_size)
|
|
@ -7,12 +7,15 @@ from thinc.api import fix_random_seed
|
||||||
|
|
||||||
from ..training import Corpus
|
from ..training import Corpus
|
||||||
from ..tokens import Doc
|
from ..tokens import Doc
|
||||||
from ._util import app, Arg, Opt, setup_gpu, import_code
|
from ._util import app, Arg, Opt, setup_gpu, import_code, benchmark_cli
|
||||||
from ..scorer import Scorer
|
from ..scorer import Scorer
|
||||||
from .. import util
|
from .. import util
|
||||||
from .. import displacy
|
from .. import displacy
|
||||||
|
|
||||||
|
|
||||||
|
@benchmark_cli.command(
|
||||||
|
"accuracy",
|
||||||
|
)
|
||||||
@app.command("evaluate")
|
@app.command("evaluate")
|
||||||
def evaluate_cli(
|
def evaluate_cli(
|
||||||
# fmt: off
|
# fmt: off
|
||||||
|
@ -36,7 +39,7 @@ def evaluate_cli(
|
||||||
dependency parses in a HTML file, set as output directory as the
|
dependency parses in a HTML file, set as output directory as the
|
||||||
displacy_path argument.
|
displacy_path argument.
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/cli#evaluate
|
DOCS: https://spacy.io/api/cli#benchmark-accuracy
|
||||||
"""
|
"""
|
||||||
import_code(code_path)
|
import_code(code_path)
|
||||||
evaluate(
|
evaluate(
|
||||||
|
|
|
@ -31,3 +31,12 @@ def test_convert_auto_conflict():
|
||||||
assert "All input files must be same type" in result.stdout
|
assert "All input files must be same type" in result.stdout
|
||||||
out_files = os.listdir(d_out)
|
out_files = os.listdir(d_out)
|
||||||
assert len(out_files) == 0
|
assert len(out_files) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_benchmark_accuracy_alias():
|
||||||
|
# Verify that the `evaluate` alias works correctly.
|
||||||
|
result_benchmark = CliRunner().invoke(app, ["benchmark", "accuracy", "--help"])
|
||||||
|
result_evaluate = CliRunner().invoke(app, ["evaluate", "--help"])
|
||||||
|
assert result_benchmark.stdout == result_evaluate.stdout.replace(
|
||||||
|
"spacy evaluate", "spacy benchmark accuracy"
|
||||||
|
)
|
||||||
|
|
|
@ -12,6 +12,7 @@ menu:
|
||||||
- ['train', 'train']
|
- ['train', 'train']
|
||||||
- ['pretrain', 'pretrain']
|
- ['pretrain', 'pretrain']
|
||||||
- ['evaluate', 'evaluate']
|
- ['evaluate', 'evaluate']
|
||||||
|
- ['benchmark', 'benchmark']
|
||||||
- ['apply', 'apply']
|
- ['apply', 'apply']
|
||||||
- ['find-threshold', 'find-threshold']
|
- ['find-threshold', 'find-threshold']
|
||||||
- ['assemble', 'assemble']
|
- ['assemble', 'assemble']
|
||||||
|
@ -1135,8 +1136,19 @@ $ python -m spacy pretrain [config_path] [output_dir] [--code] [--resume-path] [
|
||||||
|
|
||||||
## evaluate {id="evaluate",version="2",tag="command"}
|
## evaluate {id="evaluate",version="2",tag="command"}
|
||||||
|
|
||||||
Evaluate a trained pipeline. Expects a loadable spaCy pipeline (package name or
|
The `evaluate` subcommand is superseded by
|
||||||
path) and evaluation data in the
|
[`spacy benchmark accuracy`](#benchmark-accuracy). `evaluate` is provided as an
|
||||||
|
alias to `benchmark accuracy` for compatibility.
|
||||||
|
|
||||||
|
## benchmark {id="benchmark", version="3.5"}
|
||||||
|
|
||||||
|
The `spacy benchmark` CLI includes commands for benchmarking the accuracy and
|
||||||
|
speed of your spaCy pipelines.
|
||||||
|
|
||||||
|
### accuracy {id="benchmark-accuracy", version="3.5", tag="command"}
|
||||||
|
|
||||||
|
Evaluate the accuracy of a trained pipeline. Expects a loadable spaCy pipeline
|
||||||
|
(package name or path) and evaluation data in the
|
||||||
[binary `.spacy` format](/api/data-formats#binary-training). The
|
[binary `.spacy` format](/api/data-formats#binary-training). The
|
||||||
`--gold-preproc` option sets up the evaluation examples with gold-standard
|
`--gold-preproc` option sets up the evaluation examples with gold-standard
|
||||||
sentences and tokens for the predictions. Gold preprocessing helps the
|
sentences and tokens for the predictions. Gold preprocessing helps the
|
||||||
|
@ -1147,7 +1159,7 @@ skew. To render a sample of dependency parses in a HTML file using the
|
||||||
`--displacy-path` argument.
|
`--displacy-path` argument.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
$ python -m spacy evaluate [model] [data_path] [--output] [--code] [--gold-preproc] [--gpu-id] [--displacy-path] [--displacy-limit]
|
$ python -m spacy benchmark accuracy [model] [data_path] [--output] [--code] [--gold-preproc] [--gpu-id] [--displacy-path] [--displacy-limit]
|
||||||
```
|
```
|
||||||
|
|
||||||
| Name | Description |
|
| Name | Description |
|
||||||
|
@ -1163,6 +1175,29 @@ $ python -m spacy evaluate [model] [data_path] [--output] [--code] [--gold-prepr
|
||||||
| `--help`, `-h` | Show help message and available arguments. ~~bool (flag)~~ |
|
| `--help`, `-h` | Show help message and available arguments. ~~bool (flag)~~ |
|
||||||
| **CREATES** | Training results and optional metrics and visualizations. |
|
| **CREATES** | Training results and optional metrics and visualizations. |
|
||||||
|
|
||||||
|
### speed {id="benchmark-speed", version="3.5", tag="command"}
|
||||||
|
|
||||||
|
Benchmark the speed of a trained pipeline with a 95% confidence interval.
|
||||||
|
Expects a loadable spaCy pipeline (package name or path) and benchmark data in
|
||||||
|
the [binary `.spacy` format](/api/data-formats#binary-training). The pipeline is
|
||||||
|
warmed up before any measurements are taken.
|
||||||
|
|
||||||
|
```cli
|
||||||
|
$ python -m spacy benchmark speed [model] [data_path] [--batch_size] [--no-shuffle] [--gpu-id] [--batches] [--warmup]
|
||||||
|
```
|
||||||
|
|
||||||
|
| Name | Description |
|
||||||
|
| -------------------- | -------------------------------------------------------------------------------------------------------- |
|
||||||
|
| `model` | Pipeline to benchmark the speed of. Can be a package or a path to a data directory. ~~str (positional)~~ |
|
||||||
|
| `data_path` | Location of benchmark data in spaCy's [binary format](/api/data-formats#training). ~~Path (positional)~~ |
|
||||||
|
| `--batch-size`, `-b` | Set the batch size. If not set, the pipeline's batch size is used. ~~Optional[int] \(option)~~ |
|
||||||
|
| `--no-shuffle` | Do not shuffle documents in the benchmark data. ~~bool (flag)~~ |
|
||||||
|
| `--gpu-id`, `-g` | GPU to use, if any. Defaults to `-1` for CPU. ~~int (option)~~ |
|
||||||
|
| `--batches` | Number of batches to benchmark on. Defaults to `50`. ~~Optional[int] \(option)~~ |
|
||||||
|
| `--warmup`, `-w` | Iterations over the benchmark data for warmup. Defaults to `3` ~~Optional[int] \(option)~~ |
|
||||||
|
| `--help`, `-h` | Show help message and available arguments. ~~bool (flag)~~ |
|
||||||
|
| **PRINTS** | Pipeline speed in words per second with a 95% confidence interval. |
|
||||||
|
|
||||||
## apply {id="apply", version="3.5", tag="command"}
|
## apply {id="apply", version="3.5", tag="command"}
|
||||||
|
|
||||||
Applies a trained pipeline to data and stores the resulting annotated documents
|
Applies a trained pipeline to data and stores the resulting annotated documents
|
||||||
|
@ -1176,7 +1211,7 @@ input formats are:
|
||||||
|
|
||||||
When a directory is provided it is traversed recursively to collect all files.
|
When a directory is provided it is traversed recursively to collect all files.
|
||||||
|
|
||||||
```cli
|
```bash
|
||||||
$ python -m spacy apply [model] [data-path] [output-file] [--code] [--text-key] [--force-overwrite] [--gpu-id] [--batch-size] [--n-process]
|
$ python -m spacy apply [model] [data-path] [output-file] [--code] [--text-key] [--force-overwrite] [--gpu-id] [--batch-size] [--n-process]
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -1194,7 +1229,6 @@ $ python -m spacy apply [model] [data-path] [output-file] [--code] [--text-key]
|
||||||
| `--help`, `-h` | Show help message and available arguments. ~~bool (flag)~~ |
|
| `--help`, `-h` | Show help message and available arguments. ~~bool (flag)~~ |
|
||||||
| **CREATES** | A `DocBin` with the annotations from the `model` for all the files found in `data-path`. |
|
| **CREATES** | A `DocBin` with the annotations from the `model` for all the files found in `data-path`. |
|
||||||
|
|
||||||
|
|
||||||
## find-threshold {id="find-threshold",version="3.5",tag="command"}
|
## find-threshold {id="find-threshold",version="3.5",tag="command"}
|
||||||
|
|
||||||
Runs prediction trials for a trained model with varying tresholds to maximize
|
Runs prediction trials for a trained model with varying tresholds to maximize
|
||||||
|
|
Loading…
Reference in New Issue