diff --git a/MANIFEST.in b/MANIFEST.in
index 6502ff607..ef42138f1 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1,8 +1,9 @@
recursive-include include *.h
-recursive-include spacy *.pyx *.pxd *.txt *.cfg
+recursive-include spacy *.pyx *.pxd *.txt *.cfg *.jinja
include LICENSE
include README.md
include pyproject.toml
recursive-exclude spacy/lang *.json
recursive-include spacy/lang *.json.gz
+recursive-include spacy/cli *.json
recursive-include licenses *
diff --git a/pyproject.toml b/pyproject.toml
index d4aa25943..1b4972bd5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -6,7 +6,7 @@ requires = [
"cymem>=2.0.2,<2.1.0",
"preshed>=3.0.2,<3.1.0",
"murmurhash>=0.28.0,<1.1.0",
- "thinc>=8.0.0a23,<8.0.0a30",
+ "thinc>=8.0.0a27,<8.0.0a30",
"blis>=0.4.0,<0.5.0",
"pytokenizations",
"smart_open>=2.0.0,<3.0.0"
diff --git a/requirements.txt b/requirements.txt
index 4bb62742d..b4901a692 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,7 @@
# Our libraries
cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0
-thinc>=8.0.0a23,<8.0.0a30
+thinc>=8.0.0a27,<8.0.0a30
blis>=0.4.0,<0.5.0
ml_datasets>=0.1.1
murmurhash>=0.28.0,<1.1.0
@@ -26,3 +26,4 @@ pytest>=4.6.5
pytest-timeout>=1.3.0,<2.0.0
mock>=2.0.0,<3.0.0
flake8>=3.5.0,<3.6.0
+jinja2
diff --git a/setup.cfg b/setup.cfg
index f9da1adb9..a34c34e23 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -34,13 +34,13 @@ setup_requires =
cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0
murmurhash>=0.28.0,<1.1.0
- thinc>=8.0.0a23,<8.0.0a30
+ thinc>=8.0.0a27,<8.0.0a30
install_requires =
# Our libraries
murmurhash>=0.28.0,<1.1.0
cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0
- thinc>=8.0.0a23,<8.0.0a30
+ thinc>=8.0.0a27,<8.0.0a30
blis>=0.4.0,<0.5.0
wasabi>=0.7.1,<1.1.0
srsly>=2.1.0,<3.0.0
diff --git a/spacy/cli/__init__.py b/spacy/cli/__init__.py
index bc47ffdef..2b21e2f2b 100644
--- a/spacy/cli/__init__.py
+++ b/spacy/cli/__init__.py
@@ -15,7 +15,7 @@ from .debug_model import debug_model # noqa: F401
from .evaluate import evaluate # noqa: F401
from .convert import convert # noqa: F401
from .init_model import init_model # noqa: F401
-from .init_config import init_config # noqa: F401
+from .init_config import init_config, fill_config # noqa: F401
from .validate import validate # noqa: F401
from .project.clone import project_clone # noqa: F401
from .project.assets import project_assets # noqa: F401
diff --git a/spacy/cli/_util.py b/spacy/cli/_util.py
index 93ec9f31e..5613fa317 100644
--- a/spacy/cli/_util.py
+++ b/spacy/cli/_util.py
@@ -179,13 +179,13 @@ def show_validation_error(
file_path: Optional[Union[str, Path]] = None,
*,
title: str = "Config validation error",
- hint_init: bool = True,
+ hint_fill: bool = True,
):
"""Helper to show custom config validation errors on the CLI.
file_path (str / Path): Optional file path of config file, used in hints.
title (str): Title of the custom formatted error.
- hint_init (bool): Show hint about filling config.
+ hint_fill (bool): Show hint about filling config.
"""
try:
yield
@@ -195,14 +195,14 @@ def show_validation_error(
# helper for this in Thinc
err_text = str(e).replace("Config validation error", "").strip()
print(err_text)
- if hint_init and "field required" in err_text:
+ if hint_fill and "field required" in err_text:
config_path = file_path if file_path is not None else "config.cfg"
msg.text(
"If your config contains missing values, you can run the 'init "
- "config' command to fill in all the defaults, if possible:",
+ "fill-config' command to fill in all the defaults, if possible:",
spaced=True,
)
- print(f"{COMMAND} init config {config_path} --base {config_path}\n")
+            print(f"{COMMAND} init fill-config {config_path} {config_path}\n")
sys.exit(1)
diff --git a/spacy/cli/debug_data.py b/spacy/cli/debug_data.py
index 6c8c85e30..3df224131 100644
--- a/spacy/cli/debug_data.py
+++ b/spacy/cli/debug_data.py
@@ -5,7 +5,6 @@ import sys
import srsly
from wasabi import Printer, MESSAGES, msg, diff_strings
import typer
-from thinc.api import Config
from ._util import app, Arg, Opt, show_validation_error, parse_config_overrides
from ._util import import_code, debug_cli, get_sourced_components
@@ -49,7 +48,7 @@ def debug_config_cli(
overrides = parse_config_overrides(ctx.args)
import_code(code_path)
with show_validation_error(config_path):
- config = Config().from_disk(config_path, overrides=overrides)
+ config = util.load_config(config_path, overrides=overrides)
try:
nlp, _ = util.load_model_from_config(config, auto_fill=auto_fill)
except ValueError as e:
@@ -134,7 +133,7 @@ def debug_data(
if not config_path.exists():
msg.fail("Config file not found", config_path, exists=1)
with show_validation_error(config_path):
- cfg = Config().from_disk(config_path, overrides=config_overrides)
+ cfg = util.load_config(config_path, overrides=config_overrides)
nlp, config = util.load_model_from_config(cfg)
# Use original config here, not resolved version
sourced_components = get_sourced_components(cfg)
diff --git a/spacy/cli/debug_model.py b/spacy/cli/debug_model.py
index cc6cb98ea..0143d2dbe 100644
--- a/spacy/cli/debug_model.py
+++ b/spacy/cli/debug_model.py
@@ -1,7 +1,7 @@
from typing import Dict, Any, Optional
from pathlib import Path
from wasabi import msg
-from thinc.api import require_gpu, fix_random_seed, set_dropout_rate, Adam, Config
+from thinc.api import require_gpu, fix_random_seed, set_dropout_rate, Adam
from thinc.api import Model, data_validation
import typer
@@ -49,16 +49,15 @@ def debug_model_cli(
}
config_overrides = parse_config_overrides(ctx.args)
with show_validation_error(config_path):
- cfg = Config().from_disk(config_path, overrides=config_overrides)
+ config = util.load_config(config_path, overrides=config_overrides)
try:
- nlp, config = util.load_model_from_config(cfg)
+        nlp, config = util.load_model_from_config(config)
except ValueError as e:
msg.fail(str(e), exits=1)
seed = config["pretraining"]["seed"]
if seed is not None:
msg.info(f"Fixing random seed: {seed}")
fix_random_seed(seed)
-
pipe = nlp.get_pipe(component)
if hasattr(pipe, "model"):
model = pipe.model
diff --git a/spacy/cli/init_config.py b/spacy/cli/init_config.py
index 01664ee40..5fbd237f8 100644
--- a/spacy/cli/init_config.py
+++ b/spacy/cli/init_config.py
@@ -1,81 +1,185 @@
-from typing import Optional, List
+from typing import Optional, List, Tuple
+from enum import Enum
from pathlib import Path
+from wasabi import Printer, diff_strings
from thinc.api import Config
-from wasabi import msg
+from pydantic import BaseModel
+import srsly
+import re
-from ..util import load_model_from_config, get_lang_class, load_model
-from ._util import init_cli, Arg, Opt, show_validation_error
+from .. import util
+from ._util import init_cli, Arg, Opt, show_validation_error, COMMAND
+
+
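+# Paths to the quickstart template (rendered into a starter config) and the
+# JSON file with per-language recommendations for word vectors and transformer weights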
+TEMPLATE_ROOT = Path(__file__).parent / "templates"
+TEMPLATE_PATH = TEMPLATE_ROOT / "quickstart_training.jinja"
+RECOMMENDATIONS_PATH = TEMPLATE_ROOT / "quickstart_training_recommendations.json"
+
+
+class Optimizations(str, Enum):
+ efficiency = "efficiency"
+ accuracy = "accuracy"
+
+
+class RecommendationsTrfItem(BaseModel):
+ name: str
+ size_factor: int
+
+
+class RecommendationsTrf(BaseModel):
+ efficiency: RecommendationsTrfItem
+ accuracy: RecommendationsTrfItem
+
+
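+# Validates one entry of quickstart_training_recommendations.json, e.g.:
+# {"word_vectors": "en_vectors_web_lg",
+#  "transformer": {"efficiency": {"name": "roberta-base", "size_factor": 3},
+#                  "accuracy": {"name": "roberta-base", "size_factor": 3}}}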
+class RecommendationSchema(BaseModel):
+ word_vectors: Optional[str] = None
+ transformer: Optional[RecommendationsTrf] = None
@init_cli.command("config")
def init_config_cli(
# fmt: off
- output_path: Path = Arg("-", help="Output path or - for stdout", allow_dash=True),
- base_path: Optional[Path] = Opt(None, "--base", "-b", help="Optional base config to fill", exists=True, dir_okay=False),
- model: Optional[str] = Opt(None, "--model", "-m", help="Optional model to copy config from"),
- lang: Optional[str] = Opt(None, "--lang", "-l", help="Optional language code for blank config"),
- pipeline: Optional[str] = Opt(None, "--pipeline", "-p", help="Optional pipeline components to use")
+ output_file: Path = Arg("-", help="File to save config.cfg to (or - for stdout)", allow_dash=True),
+ lang: Optional[str] = Opt("en", "--lang", "-l", help="Two-letter code of the language to use"),
+ pipeline: Optional[str] = Opt("tagger,parser,ner", "--pipeline", "-p", help="Comma-separated names of trainable pipeline components to include in the model (without 'tok2vec' or 'transformer')"),
+ optimize: Optimizations = Opt(Optimizations.efficiency.value, "--optimize", "-o", help="Whether to optimize for efficiency (faster inference, smaller model, lower memory consumption) or higher accuracy (potentially larger and slower model). This will impact the choice of architecture, pretrained weights and related hyperparameters."),
+ cpu: bool = Opt(False, "--cpu", "-C", help="Whether the model needs to run on CPU. This will impact the choice of architecture, pretrained weights and related hyperparameters."),
# fmt: on
):
- """Generate a starter config.cfg for training."""
- validate_cli_args(base_path, model, lang)
- is_stdout = str(output_path) == "-"
- pipeline = [p.strip() for p in pipeline.split(",")] if pipeline else []
- cfg = init_config(output_path, base_path, model, lang, pipeline, silent=is_stdout)
- if is_stdout:
- print(cfg.to_str())
+ """
+ Generate a starter config.cfg for training. Based on your requirements
+ specified via the CLI arguments, this command generates a config with the
+    optimal settings for your use case. This includes the choice of architecture,
+ pretrained weights and related hyperparameters.
+ """
+ if isinstance(optimize, Optimizations): # instance of enum from the CLI
+ optimize = optimize.value
+ pipeline = [p.strip() for p in pipeline.split(",")]
+ init_config(output_file, lang=lang, pipeline=pipeline, optimize=optimize, cpu=cpu)
+
+
+@init_cli.command("fill-config")
+def init_fill_config_cli(
+ # fmt: off
+ base_path: Path = Arg(..., help="Base config to fill", exists=True, dir_okay=False),
+ output_file: Path = Arg("-", help="File to save config.cfg to (or - for stdout)", allow_dash=True),
+ diff: bool = Opt(False, "--diff", "-D", help="Print a visual diff highlighting the changes")
+ # fmt: on
+):
+ """
+ Fill partial config.cfg with default values. Will add all missing settings
+ from the default config and will create all objects, check the registered
+ functions for their default values and update the base config. This command
+ can be used with a config generated via the training quickstart widget:
+ https://nightly.spacy.io/usage/training#quickstart
+ """
+ fill_config(output_file, base_path, diff=diff)
+
+
+def fill_config(
+ output_file: Path, base_path: Path, *, diff: bool = False
+) -> Tuple[Config, Config]:
+ is_stdout = str(output_file) == "-"
+ msg = Printer(no_print=is_stdout)
+ with show_validation_error(hint_fill=False):
+ with msg.loading("Auto-filling config..."):
+ config = util.load_config(base_path)
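+            # Creating the nlp object with auto_fill=True instantiates all registered
+            # functions and pulls in the defaults defined in their signatures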
+ try:
+ nlp, _ = util.load_model_from_config(config, auto_fill=True)
+ except ValueError as e:
+ msg.fail(str(e), exits=1)
+ before = config.to_str()
+ after = nlp.config.to_str()
+ if before == after:
+ msg.warn("Nothing to auto-fill: base config is already complete")
else:
- cfg.to_disk(output_path)
- msg.good("Saved config", output_path)
+ msg.good("Auto-filled config with all values")
+ if diff and not is_stdout:
+ if before == after:
+ msg.warn("No diff to show: nothing was auto-filled")
+ else:
+ msg.divider("START CONFIG DIFF")
+ print("")
+ print(diff_strings(before, after))
+ msg.divider("END CONFIG DIFF")
+ print("")
+    save_config(nlp.config, output_file, is_stdout=is_stdout)
+    return config, nlp.config
def init_config(
- output_path: Path,
- config_path: Optional[Path],
- model: Optional[str],
- lang: Optional[str],
- pipeline: Optional[List[str]],
- silent: bool = False,
-) -> Config:
- if config_path is not None:
- msg.info("Generating config from base config", show=not silent)
- with show_validation_error(config_path, hint_init=False):
- config = Config().from_disk(config_path)
- try:
- nlp, _ = load_model_from_config(config, auto_fill=True)
- except ValueError as e:
- msg.fail(str(e), exits=1)
- return nlp.config
- if model is not None:
- ext = f" with pipeline {pipeline}" if pipeline else ""
- msg.info(f"Generating config from model {model}{ext}", show=not silent)
- nlp = load_model(model)
- for existing_pipe_name in nlp.pipe_names:
- if existing_pipe_name not in pipeline:
- nlp.remove_pipe(existing_pipe_name)
- for pipe_name in pipeline:
- if pipe_name not in nlp.pipe_names:
- nlp.add_pipe(pipe_name)
- return nlp.config
- if lang is not None:
- ext = f" with pipeline {pipeline}" if pipeline else ""
- msg.info(f"Generating config for language '{lang}'{ext}", show=not silent)
- nlp = get_lang_class(lang)()
- for pipe_name in pipeline:
- nlp.add_pipe(pipe_name)
- return nlp.config
-
-
-def validate_cli_args(
- config_path: Optional[Path], model: Optional[str], lang: Optional[str]
+ output_file: Path, *, lang: str, pipeline: List[str], optimize: str, cpu: bool
) -> None:
- args = {"--base": config_path, "--model": model, "--lang": lang}
- if sum(arg is not None for arg in args.values()) != 1:
- existing = " ".join(f"{a} {v}" for a, v in args.items() if v is not None)
+ is_stdout = str(output_file) == "-"
+ msg = Printer(no_print=is_stdout)
+ try:
+ from jinja2 import Template
+ except ImportError:
+ msg.fail("This command requires jinja2", "pip install jinja2", exits=1)
+ recommendations = srsly.read_json(RECOMMENDATIONS_PATH)
+ lang_defaults = util.get_lang_class(lang).Defaults
+ has_letters = lang_defaults.writing_system.get("has_letters", True)
+ # Filter out duplicates since tok2vec and transformer are added by template
+ pipeline = [pipe for pipe in pipeline if pipe not in ("tok2vec", "transformer")]
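+    # Validate the recommendations for this language; languages without an entry
+    # fall back to the schema defaults (no vectors, no transformer)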
+ reco = RecommendationSchema(**recommendations.get(lang, {})).dict()
+ with TEMPLATE_PATH.open("r") as f:
+ template = Template(f.read())
+ variables = {
+ "lang": lang,
+ "components": pipeline,
+ "optimize": optimize,
+ "hardware": "cpu" if cpu else "gpu",
+ "transformer_data": reco["transformer"],
+ "word_vectors": reco["word_vectors"],
+ "has_letters": has_letters,
+ }
+ base_template = template.render(variables).strip()
+ # Giving up on getting the newlines right in jinja for now
+ base_template = re.sub(r"\n\n\n+", "\n\n", base_template)
+ # Access variables declared in templates
+ template_vars = template.make_module(variables)
+ use_case = {
+ "Language": lang,
+ "Pipeline": ", ".join(pipeline),
+ "Optimize for": optimize,
+ "Hardware": variables["hardware"].upper(),
+ "Transformer": template_vars.transformer.get("name", False),
+ }
+ msg.info("Generated template specific for your use case")
+ for label, value in use_case.items():
+ msg.text(f"- {label}: {value}")
+ use_transformer = bool(template_vars.use_transformer)
+ if use_transformer:
+ require_spacy_transformers(msg)
+ with show_validation_error(hint_fill=False):
+ config = util.load_config_from_str(base_template)
+ try:
+ nlp, _ = util.load_model_from_config(config, auto_fill=True)
+ except ValueError as e:
+ msg.fail(str(e), exits=1)
+ if use_transformer:
+ nlp.config.pop("pretraining", {}) # TODO: solve this better
+ msg.good("Auto-filled config with all values")
+ save_config(nlp.config, output_file, is_stdout=is_stdout)
+
+
+def save_config(config: Config, output_file: Path, is_stdout: bool = False) -> None:
+ msg = Printer(no_print=is_stdout)
+ if is_stdout:
+ print(config.to_str())
+ else:
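+        # interpolate=False keeps variable references like ${paths:train} intact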
+ config.to_disk(output_file, interpolate=False)
+ msg.good("Saved config", output_file)
+ msg.text("You can now add your data and train your model:")
+ variables = ["--paths.train ./train.spacy", "--paths.dev ./dev.spacy"]
+ print(f"{COMMAND} train {output_file.parts[-1]} {' '.join(variables)}")
+
+
+def require_spacy_transformers(msg: Printer) -> None:
+ try:
+ import spacy_transformers # noqa: F401
+ except ImportError:
msg.fail(
- "The init config command expects only one of the following arguments: "
- "--base (base config to fill and update), --lang (language code to "
- "use for blank config) or --model (base model to copy config from).",
- f"Got: {existing if existing else 'no arguments'}",
+ "Using a transformer-based pipeline requires spacy-transformers "
+ "to be installed.",
exits=1,
)
diff --git a/spacy/cli/pretrain.py b/spacy/cli/pretrain.py
index ce0eb27a0..82950f402 100644
--- a/spacy/cli/pretrain.py
+++ b/spacy/cli/pretrain.py
@@ -5,7 +5,7 @@ import time
import re
from collections import Counter
from pathlib import Path
-from thinc.api import use_pytorch_for_gpu_memory, require_gpu, Config
+from thinc.api import use_pytorch_for_gpu_memory, require_gpu
from thinc.api import set_dropout_rate, to_categorical, fix_random_seed
from thinc.api import CosineDistance, L2Distance
from wasabi import msg
@@ -88,7 +88,7 @@ def pretrain(
msg.info("Using CPU")
msg.info(f"Loading config from: {config_path}")
with show_validation_error(config_path):
- config = Config().from_disk(config_path, overrides=config_overrides)
+ config = util.load_config(config_path, overrides=config_overrides)
nlp, config = util.load_model_from_config(config)
# TODO: validate that [pretraining] block exists
if not output_dir.exists():
diff --git a/spacy/cli/templates/quickstart_training.jinja b/spacy/cli/templates/quickstart_training.jinja
new file mode 100644
index 000000000..4f5a2226e
--- /dev/null
+++ b/spacy/cli/templates/quickstart_training.jinja
@@ -0,0 +1,237 @@
+{# This is a template for training configs used for the quickstart widget in
+the docs and the init config command. It encodes various best practices and
+can help generate the best possible configuration, given a user's requirements. #}
+{%- set use_transformer = (transformer_data and hardware != "cpu") -%}
+{%- set transformer = transformer_data[optimize] if use_transformer else {} -%}
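+{#- transformer_data and hardware are passed in by init_config: the former from the
+per-language recommendations JSON, the latter from the --cpu flag. A transformer
+pipeline is only generated if recommended weights exist and hardware is not CPU. -#}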
+[paths]
+train = ""
+dev = ""
+
+[system]
+use_pytorch_for_gpu_memory = {{ "true" if use_transformer else "false" }}
+
+[nlp]
+lang = "{{ lang }}"
+{%- set full_pipeline = ["transformer" if use_transformer else "tok2vec"] + components %}
+pipeline = {{ full_pipeline|pprint()|replace("'", '"')|safe }}
+tokenizer = {"@tokenizers": "spacy.Tokenizer.v1"}
+
+[components]
+
+{# TRANSFORMER PIPELINE #}
+{%- if use_transformer -%}
+[components.transformer]
+factory = "transformer"
+
+[components.transformer.model]
+@architectures = "spacy-transformers.TransformerModel.v1"
+name = "{{ transformer["name"] }}"
+tokenizer_config = {"use_fast": true}
+
+[components.transformer.model.get_spans]
+@span_getters = "strided_spans.v1"
+window = 128
+stride = 96
+
+{% if "tagger" in components %}
+[components.tagger]
+factory = "tagger"
+
+[components.tagger.model]
+@architectures = "spacy.Tagger.v1"
+nO = null
+
+[components.tagger.model.tok2vec]
+@architectures = "spacy-transformers.Tok2VecListener.v1"
+grad_factor = 1.0
+
+[components.tagger.model.tok2vec.pooling]
+@layers = "reduce_mean.v1"
+{%- endif %}
+
+{% if "parser" in components -%}
+[components.parser]
+factory = "parser"
+
+[components.parser.model]
+@architectures = "spacy.TransitionBasedParser.v1"
+nr_feature_tokens = 8
+hidden_width = 128
+maxout_pieces = 3
+use_upper = false
+nO = null
+
+[components.parser.model.tok2vec]
+@architectures = "spacy-transformers.Tok2VecListener.v1"
+grad_factor = 1.0
+
+[components.parser.model.tok2vec.pooling]
+@layers = "reduce_mean.v1"
+{%- endif %}
+
+{% if "ner" in components -%}
+[components.ner]
+factory = "ner"
+
+[components.ner.model]
+@architectures = "spacy.TransitionBasedParser.v1"
+nr_feature_tokens = 3
+hidden_width = 64
+maxout_pieces = 2
+use_upper = false
+nO = null
+
+[components.ner.model.tok2vec]
+@architectures = "spacy-transformers.Tok2VecListener.v1"
+grad_factor = 1.0
+
+[components.ner.model.tok2vec.pooling]
+@layers = "reduce_mean.v1"
+{% endif -%}
+
+{# NON-TRANSFORMER PIPELINE #}
+{% else -%}
+
+{%- if hardware == "gpu" -%}
+# There are no recommended transformer weights available for language '{{ lang }}'
+# yet, so the pipeline described here is not transformer-based.
+{%- endif %}
+
+[components.tok2vec]
+factory = "tok2vec"
+
+[components.tok2vec.model]
+@architectures = "spacy.Tok2Vec.v1"
+
+[components.tok2vec.model.embed]
+@architectures = "spacy.MultiHashEmbed.v1"
+width = ${components.tok2vec.model.encode:width}
+rows = {{ 2000 if optimize == "efficiency" else 7000 }}
+also_embed_subwords = {{ true if has_letters else false }}
+also_use_static_vectors = {{ true if optimize == "accuracy" else false }}
+
+[components.tok2vec.model.encode]
+@architectures = "spacy.MaxoutWindowEncoder.v1"
+width = {{ 96 if optimize == "efficiency" else 256 }}
+depth = {{ 4 if optimize == "efficiency" else 8 }}
+window_size = 1
+maxout_pieces = 3
+
+{% if "tagger" in components %}
+[components.tagger]
+factory = "tagger"
+
+[components.tagger.model]
+@architectures = "spacy.Tagger.v1"
+nO = null
+
+[components.tagger.model.tok2vec]
+@architectures = "spacy.Tok2VecListener.v1"
+width = ${components.tok2vec.model.encode:width}
+{%- endif %}
+
+{% if "parser" in components -%}
+[components.parser]
+factory = "parser"
+
+[components.parser.model]
+@architectures = "spacy.TransitionBasedParser.v1"
+nr_feature_tokens = 8
+hidden_width = 128
+maxout_pieces = 3
+use_upper = true
+nO = null
+
+[components.parser.model.tok2vec]
+@architectures = "spacy.Tok2VecListener.v1"
+width = ${components.tok2vec.model.encode:width}
+{%- endif %}
+
+{% if "ner" in components %}
+[components.ner]
+factory = "ner"
+
+[components.ner.model]
+@architectures = "spacy.TransitionBasedParser.v1"
+nr_feature_tokens = 6
+hidden_width = 64
+maxout_pieces = 2
+use_upper = true
+nO = null
+
+[components.ner.model.tok2vec]
+@architectures = "spacy.Tok2VecListener.v1"
+width = ${components.tok2vec.model.encode:width}
+{% endif %}
+{% endif %}
+
+{% for pipe in components %}
+{% if pipe not in ["tagger", "parser", "ner"] %}
+{# Other components defined by the user: we just assume they're factories #}
+[components.{{ pipe }}]
+factory = "{{ pipe }}"
+{% endif %}
+{% endfor %}
+
+[training]
+{% if use_transformer or optimize == "efficiency" or not word_vectors -%}
+vectors = null
+{% else -%}
+vectors = "{{ word_vectors }}"
+{% endif -%}
+{% if use_transformer -%}
+accumulate_gradient = {{ transformer["size_factor"] }}
+{% endif %}
+
+[training.optimizer]
+@optimizers = "Adam.v1"
+
+[training.optimizer.learn_rate]
+@schedules = "warmup_linear.v1"
+warmup_steps = 250
+total_steps = 20000
+initial_rate = 5e-5
+
+[training.train_corpus]
+@readers = "spacy.Corpus.v1"
+path = ${paths:train}
+max_length = {{ 500 if hardware == "gpu" else 0 }}
+
+[training.dev_corpus]
+@readers = "spacy.Corpus.v1"
+path = ${paths:dev}
+max_length = 0
+
+{% if use_transformer %}
+[training.batcher]
+@batchers = "batch_by_padded.v1"
+discard_oversize = true
+size = 2000
+buffer = 256
+{%- else %}
+[training.batcher]
+@batchers = "batch_by_words.v1"
+discard_oversize = false
+tolerance = 0.2
+
+[training.batcher.size]
+@schedules = "compounding.v1"
+start = 100
+stop = 1000
+compound = 1.001
+{% endif %}
+
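+{#- Each trainable component gets an equal share (1 / number of components) of the
+combined score; secondary metrics are reported but weighted 0.0 -#}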
+[training.score_weights]
+{%- if "tagger" in components %}
+tag_acc = {{ (1.0 / components|length)|round(2) }}
+{%- endif -%}
+{%- if "parser" in components %}
+dep_uas = 0.0
+dep_las = {{ (1.0 / components|length)|round(2) }}
+sents_f = 0.0
+{%- endif %}
+{%- if "ner" in components %}
+ents_f = {{ (1.0 / components|length)|round(2) }}
+ents_p = 0.0
+ents_r = 0.0
+{%- endif -%}
diff --git a/spacy/cli/templates/quickstart_training_recommendations.json b/spacy/cli/templates/quickstart_training_recommendations.json
new file mode 100644
index 000000000..8a3acc438
--- /dev/null
+++ b/spacy/cli/templates/quickstart_training_recommendations.json
@@ -0,0 +1,13 @@
+{
+ "en": {
+ "word_vectors": "en_vectors_web_lg",
+ "transformer": {
+ "efficiency": { "name": "roberta-base", "size_factor": 3 },
+ "accuracy": { "name": "roberta-base", "size_factor": 3 }
+ }
+ },
+ "de": {
+ "word_vectors": null,
+ "transformer": null
+ }
+}
diff --git a/spacy/cli/train.py b/spacy/cli/train.py
index f2085ff80..375e64ffd 100644
--- a/spacy/cli/train.py
+++ b/spacy/cli/train.py
@@ -75,7 +75,7 @@ def train(
msg.info("Using CPU")
msg.info(f"Loading config and nlp from: {config_path}")
with show_validation_error(config_path):
- config = Config().from_disk(config_path, overrides=config_overrides)
+ config = util.load_config(config_path, overrides=config_overrides)
if config.get("training", {}).get("seed") is not None:
fix_random_seed(config["training"]["seed"])
# Use original config here before it's resolved to functions
diff --git a/spacy/language.py b/spacy/language.py
index efdf633d4..b67c55e3b 100644
--- a/spacy/language.py
+++ b/spacy/language.py
@@ -21,7 +21,7 @@ from .pipe_analysis import validate_attrs, analyze_pipes, print_pipe_analysis
from .gold import Example, validate_examples
from .scorer import Scorer
from .util import create_default_optimizer, registry
-from .util import SimpleFrozenDict, combine_score_weights
+from .util import SimpleFrozenDict, combine_score_weights, CONFIG_SECTION_ORDER
from .lang.tokenizer_exceptions import URL_MATCH, BASE_EXCEPTIONS
from .lang.punctuation import TOKENIZER_PREFIXES, TOKENIZER_SUFFIXES
from .lang.punctuation import TOKENIZER_INFIXES
@@ -36,7 +36,7 @@ from . import about
# This is the base config with all settings (training etc.)
DEFAULT_CONFIG_PATH = Path(__file__).parent / "default_config.cfg"
-DEFAULT_CONFIG = Config().from_disk(DEFAULT_CONFIG_PATH)
+DEFAULT_CONFIG = util.load_config(DEFAULT_CONFIG_PATH)
class BaseDefaults:
@@ -45,7 +45,7 @@ class BaseDefaults:
Language.Defaults.
"""
- config: Config = Config()
+ config: Config = Config(section_order=CONFIG_SECTION_ORDER)
tokenizer_exceptions: Dict[str, List[dict]] = BASE_EXCEPTIONS
prefixes: Optional[List[Union[str, Pattern]]] = TOKENIZER_PREFIXES
suffixes: Optional[List[Union[str, Pattern]]] = TOKENIZER_SUFFIXES
@@ -134,7 +134,7 @@ class Language:
# of the rest.
util.registry._entry_point_factories.get_all()
- self._config = util.deep_merge_configs(self.default_config, DEFAULT_CONFIG)
+ self._config = DEFAULT_CONFIG.merge(self.default_config)
self._meta = dict(meta)
self._path = None
self._optimizer = None
@@ -167,9 +167,7 @@ class Language:
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
- cls.default_config = util.deep_merge_configs(
- cls.Defaults.config, DEFAULT_CONFIG
- )
+ cls.default_config = DEFAULT_CONFIG.merge(cls.Defaults.config)
cls.default_config["nlp"]["lang"] = cls.lang
@property
@@ -532,6 +530,7 @@ class Language:
name: Optional[str] = None,
*,
config: Optional[Dict[str, Any]] = SimpleFrozenDict(),
+ raw_config: Optional[Config] = None,
validate: bool = True,
) -> Callable[[Doc], Doc]:
"""Create a pipeline component. Mostly used internally. To create and
@@ -542,6 +541,7 @@ class Language:
Defaults to factory name if not set.
config (Optional[Dict[str, Any]]): Config parameters to use for this
component. Will be merged with default config, if available.
+ raw_config (Optional[Config]): Internals: the non-interpolated config.
validate (bool): Whether to validate the component config against the
arguments and types expected by the factory.
RETURNS (Callable[[Doc], Doc]): The pipeline component.
@@ -568,7 +568,7 @@ class Language:
# This is unideal, but the alternative would mean you always need to
# specify the full config settings, which is not really viable.
if pipe_meta.default_config:
- config = util.deep_merge_configs(config, pipe_meta.default_config)
+ config = Config(pipe_meta.default_config).merge(config)
# We need to create a top-level key because Thinc doesn't allow resolving
# top-level references to registered functions. Also gives nicer errors.
# The name allows components to know their pipe name and use it in the
@@ -582,12 +582,14 @@ class Language:
cfg = {factory_name: config}
# We're calling the internal _fill here to avoid constructing the
# registered functions twice
- # TODO: customize validation to make it more readable / relate it to
- # pipeline component and why it failed, explain default config
resolved, filled = registry.resolve(cfg, validate=validate)
- filled = filled[factory_name]
+ filled = Config(filled[factory_name])
filled["factory"] = factory_name
filled.pop("@factories", None)
+ # Merge the final filled config with the raw config (including non-
+ # interpolated variables)
+ if raw_config:
+ filled = filled.merge(raw_config)
self._pipe_configs[name] = filled
return resolved[factory_name]
@@ -613,7 +615,10 @@ class Language:
)
)
pipe = source.get_pipe(source_name)
- pipe_config = util.copy_config(source.config["components"][source_name])
+ # Make sure the source config is interpolated so we don't end up with
+ # orphaned variables in our final config
+ source_config = source.config.interpolate()
+ pipe_config = util.copy_config(source_config["components"][source_name])
self._pipe_configs[name] = pipe_config
return pipe, pipe_config["factory"]
@@ -628,6 +633,7 @@ class Language:
last: Optional[bool] = None,
source: Optional["Language"] = None,
config: Optional[Dict[str, Any]] = SimpleFrozenDict(),
+ raw_config: Optional[Config] = None,
validate: bool = True,
) -> Callable[[Doc], Doc]:
"""Add a component to the processing pipeline. Valid components are
@@ -649,6 +655,7 @@ class Language:
component from.
config (Optional[Dict[str, Any]]): Config parameters to use for this
component. Will be merged with default config, if available.
+ raw_config (Optional[Config]): Internals: the non-interpolated config.
validate (bool): Whether to validate the component config against the
arguments and types expected by the factory.
RETURNS (Callable[[Doc], Doc]): The pipeline component.
@@ -678,7 +685,11 @@ class Language:
lang_code=self.lang,
)
pipe_component = self.create_pipe(
- factory_name, name=name, config=config, validate=validate,
+ factory_name,
+ name=name,
+ config=config,
+ raw_config=raw_config,
+ validate=validate,
)
pipe_index = self._get_pipe_index(before, after, first, last)
self._pipe_meta[name] = self.get_factory_meta(factory_name)
@@ -1379,7 +1390,9 @@ class Language:
DOCS: https://spacy.io/api/language#from_config
"""
if auto_fill:
- config = util.deep_merge_configs(config, cls.default_config)
+ config = Config(
+ cls.default_config, section_order=CONFIG_SECTION_ORDER
+ ).merge(config)
if "nlp" not in config:
raise ValueError(Errors.E985.format(config=config))
config_lang = config["nlp"]["lang"]
@@ -1417,16 +1430,20 @@ class Language:
or lang_cls is not cls
):
raise ValueError(Errors.E943.format(value=type(lang_cls)))
+ # Note that we don't load vectors here, instead they get loaded explicitly
+ # inside stuff like the spacy train function. If we loaded them here,
+ # then we would load them twice at runtime: once when we make from config,
+ # and then again when we load from disk.
nlp = lang_cls(vocab=vocab, create_tokenizer=create_tokenizer)
if after_creation is not None:
nlp = after_creation(nlp)
if not isinstance(nlp, cls):
raise ValueError(Errors.E942.format(name="creation", value=type(nlp)))
- # Note that we don't load vectors here, instead they get loaded explicitly
- # inside stuff like the spacy train function. If we loaded them here,
- # then we would load them twice at runtime: once when we make from config,
- # and then again when we load from disk.
- pipeline = config.get("components", {})
+ # To create the components we need to use the final interpolated config
+ # so all values are available (if component configs use variables).
+ # Later we replace the component config with the raw config again.
+ interpolated = filled.interpolate() if not filled.is_interpolated else filled
+ pipeline = interpolated.get("components", {})
# If components are loaded from a source (existing models), we cache
# them here so they're only loaded once
source_nlps = {}
@@ -1435,6 +1452,7 @@ class Language:
opts = ", ".join(pipeline.keys())
raise ValueError(Errors.E956.format(name=pipe_name, opts=opts))
pipe_cfg = util.copy_config(pipeline[pipe_name])
+ raw_config = Config(filled["components"][pipe_name])
if pipe_name not in disable:
if "factory" not in pipe_cfg and "source" not in pipe_cfg:
err = Errors.E984.format(name=pipe_name, config=pipe_cfg)
@@ -1444,7 +1462,11 @@ class Language:
# The pipe name (key in the config) here is the unique name
# of the component, not necessarily the factory
nlp.add_pipe(
- factory, name=pipe_name, config=pipe_cfg, validate=validate,
+ factory,
+ name=pipe_name,
+ config=pipe_cfg,
+ validate=validate,
+ raw_config=raw_config,
)
else:
model = pipe_cfg["source"]
diff --git a/spacy/tests/serialize/test_serialize_config.py b/spacy/tests/serialize/test_serialize_config.py
index 8e3c95823..1de137e81 100644
--- a/spacy/tests/serialize/test_serialize_config.py
+++ b/spacy/tests/serialize/test_serialize_config.py
@@ -4,7 +4,7 @@ import spacy
from spacy.lang.en import English
from spacy.lang.de import German
from spacy.language import Language
-from spacy.util import registry, deep_merge_configs, load_model_from_config
+from spacy.util import registry, load_model_from_config
from spacy.ml.models import build_Tok2Vec_model, build_tb_parser_model
from spacy.ml.models import MultiHashEmbed, MaxoutWindowEncoder
@@ -194,37 +194,6 @@ def test_serialize_parser():
assert upper.get_dim("nI") == 66
-def test_deep_merge_configs():
- config = {"a": "hello", "b": {"c": "d"}}
- defaults = {"a": "world", "b": {"c": "e", "f": "g"}}
- merged = deep_merge_configs(config, defaults)
- assert len(merged) == 2
- assert merged["a"] == "hello"
- assert merged["b"] == {"c": "d", "f": "g"}
- config = {"a": "hello", "b": {"@test": "x", "foo": 1}}
- defaults = {"a": "world", "b": {"@test": "x", "foo": 100, "bar": 2}, "c": 100}
- merged = deep_merge_configs(config, defaults)
- assert len(merged) == 3
- assert merged["a"] == "hello"
- assert merged["b"] == {"@test": "x", "foo": 1, "bar": 2}
- assert merged["c"] == 100
- config = {"a": "hello", "b": {"@test": "x", "foo": 1}, "c": 100}
- defaults = {"a": "world", "b": {"@test": "y", "foo": 100, "bar": 2}}
- merged = deep_merge_configs(config, defaults)
- assert len(merged) == 3
- assert merged["a"] == "hello"
- assert merged["b"] == {"@test": "x", "foo": 1}
- assert merged["c"] == 100
- # Test that leaving out the factory just adds to existing
- config = {"a": "hello", "b": {"foo": 1}, "c": 100}
- defaults = {"a": "world", "b": {"@test": "y", "foo": 100, "bar": 2}}
- merged = deep_merge_configs(config, defaults)
- assert len(merged) == 3
- assert merged["a"] == "hello"
- assert merged["b"] == {"@test": "y", "foo": 1, "bar": 2}
- assert merged["c"] == 100
-
-
def test_config_nlp_roundtrip():
"""Test that a config prduced by the nlp object passes training config
validation."""
@@ -311,3 +280,22 @@ def test_config_overrides():
nlp = spacy.load(d)
assert isinstance(nlp, English)
assert nlp.pipe_names == ["tok2vec", "tagger"]
+
+
+def test_config_interpolation():
+ config = Config().from_str(nlp_config_string, interpolate=False)
+ assert config["training"]["train_corpus"]["path"] == "${paths:train}"
+ interpolated = config.interpolate()
+ assert interpolated["training"]["train_corpus"]["path"] == ""
+ nlp = English.from_config(config)
+ assert nlp.config["training"]["train_corpus"]["path"] == "${paths:train}"
+ # Ensure that variables are preserved in nlp config
+ width = "${components.tok2vec.model:width}"
+ assert config["components"]["tagger"]["model"]["tok2vec"]["width"] == width
+ assert nlp.config["components"]["tagger"]["model"]["tok2vec"]["width"] == width
+ interpolated2 = nlp.config.interpolate()
+ assert interpolated2["training"]["train_corpus"]["path"] == ""
+ assert interpolated2["components"]["tagger"]["model"]["tok2vec"]["width"] == 342
+ nlp2 = English.from_config(interpolated)
+ assert nlp2.config["training"]["train_corpus"]["path"] == ""
+ assert nlp2.config["components"]["tagger"]["model"]["tok2vec"]["width"] == 342
diff --git a/spacy/tests/test_cli.py b/spacy/tests/test_cli.py
index b5cc6fff8..1da257fd5 100644
--- a/spacy/tests/test_cli.py
+++ b/spacy/tests/test_cli.py
@@ -1,11 +1,14 @@
import pytest
-
from spacy.gold import docs_to_json, biluo_tags_from_offsets
from spacy.gold.converters import iob2docs, conll_ner2docs, conllu2docs
from spacy.lang.en import English
from spacy.schemas import ProjectConfigSchema, validate
from spacy.cli.pretrain import make_docs
+from spacy.cli.init_config import init_config, RECOMMENDATIONS_PATH
+from spacy.cli.init_config import RecommendationSchema
from spacy.cli._util import validate_project_commands, parse_config_overrides
+from spacy.util import get_lang_class
+import srsly
def test_cli_converters_conllu2json():
@@ -319,3 +322,20 @@ def test_parse_config_overrides(args, expected):
def test_parse_config_overrides_invalid(args):
with pytest.raises(SystemExit):
parse_config_overrides(args)
+
+
+@pytest.mark.parametrize("lang", ["en", "nl"])
+@pytest.mark.parametrize(
+ "pipeline", [["tagger", "parser", "ner"], [], ["ner", "textcat", "sentencizer"]]
+)
+@pytest.mark.parametrize("optimize", ["efficiency", "accuracy"])
+def test_init_config(lang, pipeline, optimize):
+ # TODO: add more tests and also check for GPU with transformers
+ init_config("-", lang=lang, pipeline=pipeline, optimize=optimize, cpu=True)
+
+
+def test_model_recommendations():
+ recommendations = srsly.read_json(RECOMMENDATIONS_PATH)
+ for lang, data in recommendations.items():
+ assert get_lang_class(lang)
+ assert RecommendationSchema(**data)
diff --git a/spacy/util.py b/spacy/util.py
index 09b117952..3cf165a4f 100644
--- a/spacy/util.py
+++ b/spacy/util.py
@@ -58,6 +58,12 @@ if TYPE_CHECKING:
OOV_RANK = numpy.iinfo(numpy.uint64).max
LEXEME_NORM_LANGS = ["da", "de", "el", "en", "id", "lb", "pt", "ru", "sr", "ta", "th"]
+# Default order of sections in the config.cfg. Not all sections need to exist,
+# and additional sections are added at the end, in alphabetical order.
+# fmt: off
+CONFIG_SECTION_ORDER = ["paths", "variables", "system", "nlp", "components", "training", "pretraining"]
+# fmt: on
+
logging.basicConfig()
logger = logging.getLogger("spacy")
@@ -263,9 +269,7 @@ def load_model_from_path(
if not meta:
meta = get_model_meta(model_path)
config_path = model_path / "config.cfg"
- if not config_path.exists() or not config_path.is_file():
- raise IOError(Errors.E053.format(path=config_path, name="config.cfg"))
- config = Config().from_disk(config_path, overrides=dict_to_dot(config))
+ config = load_config(config_path, overrides=dict_to_dot(config))
nlp, _ = load_model_from_config(config, vocab=vocab, disable=disable)
return nlp.from_disk(model_path, exclude=disable)
@@ -316,6 +320,29 @@ def load_model_from_init_py(
)
+def load_config(
+ path: Union[str, Path],
+ overrides: Dict[str, Any] = SimpleFrozenDict(),
+ interpolate: bool = False,
+) -> Config:
+ """Load a config file. Takes care of path validation and section order."""
+ config_path = ensure_path(path)
+ if not config_path.exists() or not config_path.is_file():
+ raise IOError(Errors.E053.format(path=config_path, name="config.cfg"))
+ return Config(section_order=CONFIG_SECTION_ORDER).from_disk(
+ config_path, overrides=overrides, interpolate=interpolate
+ )
+
+
+def load_config_from_str(
+ text: str, overrides: Dict[str, Any] = SimpleFrozenDict(), interpolate: bool = False
+):
+ """Load a full config from a string."""
+ return Config(section_order=CONFIG_SECTION_ORDER).from_str(
+ text, overrides=overrides, interpolate=interpolate,
+ )
+
+
def get_installed_models() -> List[str]:
"""List all model packages currently installed in the environment.
@@ -901,45 +928,6 @@ def copy_config(config: Union[Dict[str, Any], Config]) -> Config:
raise ValueError(Errors.E961.format(config=config)) from None
-def deep_merge_configs(
- config: Union[Dict[str, Any], Config], defaults: Union[Dict[str, Any], Config]
-) -> Config:
- """Deep merge two configs, a base config and its defaults. Ignores
- references to registered functions to avoid filling in
-
- config (Dict[str, Any]): The config.
- destination (Dict[str, Any]): The config defaults.
- RETURNS (Dict[str, Any]): The merged config.
- """
- config = copy_config(config)
- merged = _deep_merge_configs(config, defaults)
- return Config(merged)
-
-
-def _deep_merge_configs(
- config: Union[Dict[str, Any], Config], defaults: Union[Dict[str, Any], Config]
-) -> Union[Dict[str, Any], Config]:
- for key, value in defaults.items():
- if isinstance(value, dict):
- node = config.setdefault(key, {})
- if not isinstance(node, dict):
- continue
- promises = [key for key in value if key.startswith("@")]
- promise = promises[0] if promises else None
- # We only update the block from defaults if it refers to the same
- # registered function
- if (
- promise
- and any(k.startswith("@") for k in node)
- and (promise in node and node[promise] != value[promise])
- ):
- continue
- defaults = _deep_merge_configs(node, value)
- elif key not in config:
- config[key] = value
- return config
-
-
def dot_to_dict(values: Dict[str, Any]) -> Dict[str, dict]:
"""Convert dot notation to a dict. For example: {"token.pos": True,
"token._.xyz": True} becomes {"token": {"pos": True, "_": {"xyz": True }}}.
diff --git a/website/docs/api/cli.md b/website/docs/api/cli.md
index 32aaee7b8..02b618f8a 100644
--- a/website/docs/api/cli.md
+++ b/website/docs/api/cli.md
@@ -101,39 +101,62 @@ files and model directories.
### init config {#init-config new="3"}
-Initialize and export a [`config.cfg` file](/usage/training#config) for training
-and update it with all default values, if possible. Config files used for
-training should always be complete and not contain any hidden defaults or
-missing values, so this command helps you create your final config. It takes
-**one** of the following options:
-
-- `--base`: Base **config** to auto-fill, e.g. created using the
- [training quickstart](/usage/training#quickstart) widget.
-- `--lang`: Base **language** code to use for blank config.
-- `--model`: Base **model** to copy config from.
+Initialize and save a [`config.cfg` file](/usage/training#config) using the
+**recommended settings** for your use case. It works just like the
+[quickstart widget](/usage/training#quickstart), only that it also auto-fills
+all default values and exports a [training](/usage/training#config)-ready
+config. The settings you specify will impact the suggested model architectures
+and pipeline setup, as well as the hyperparameters. You can also adjust and
+customize those settings in your config file later.
> ```bash
-> ### with base config {wrap="true"}
-> $ python -m spacy init config config.cfg --base base.cfg
-> ```
->
-> ```bash
-> ### blank language {wrap="true"}
-> $ python -m spacy init config config.cfg --lang en --pipeline tagger,parser
+> ### Example {wrap="true"}
+> $ python -m spacy init config config.cfg --lang en --pipeline ner,textcat --optimize accuracy
> ```
```bash
-$ python -m spacy init config [output] [--base] [--lang] [--model] [--pipeline]
+$ python -m spacy init config [output_file] [--lang] [--pipeline]
+[--optimize] [--cpu]
```
-| Argument | Type | Description |
-| ------------------ | ---------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| `output` | positional | Path to output `.cfg` file. If not set, the config is written to stdout so you can pipe it forward to a file. |
-| `--base`, `-b` | option | Optional base config file to auto-fill with defaults. |
-| `--lang`, `-l` | option | Optional language code to use for blank config. If a `--pipeline` is specified, the components will be added in order. |
-| `--model`, `-m` | option | Optional base model to copy config from. If a `--pipeline` is specified, only those components will be kept, and all other components not in the model will be added. |
-| `--pipeline`, `-p` | option | Optional comma-separated pipeline of components to add to blank language or model. |
-| **CREATES** | config | Complete and auto-filled config file for training. |
+| Argument | Type | Description |
+| ------------------ | ---------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `output_file` | positional | Path to output `.cfg` file. If not set, the config is written to stdout so you can pipe it forward to a file. |
+| `--lang`, `-l` | option | Optional code of the [language](/usage/models#languages) to use. Defaults to `"en"`. |
+| `--pipeline`, `-p` | option | Comma-separated list of trainable [pipeline components](/usage/processing-pipelines#built-in) to include in the model. Defaults to `"tagger,parser,ner"`. |
+| `--optimize`, `-o` | option | `"efficiency"` or `"accuracy"`. Whether to optimize for efficiency (faster inference, smaller model, lower memory consumption) or higher accuracy (potentially larger and slower model). This will impact the choice of architecture, pretrained weights and related hyperparameters. Defaults to `"efficiency"`. |
+| `--cpu`, `-C` | flag | Whether the model needs to run on CPU. This will impact the choice of architecture, pretrained weights and related hyperparameters. |
+| `--help`, `-h` | flag | Show help message and available arguments. |
+| **CREATES** | file | The config file for training. |
+
+### init fill-config {#init-fill-config new="3"}
+
+Auto-fill a partial [`config.cfg` file](/usage/training#config) with **all
+default values**, e.g. a config generated with the
+[quickstart widget](/usage/training#quickstart). Config files used for training
+should always be complete and not contain any hidden defaults or missing values,
+so this command helps you create your final training config. In order to find
+the available settings and defaults, all functions referenced in the config will
+be created, and their signatures are used to find the defaults. If your config
+contains a problem that can't be resolved automatically, spaCy will show you a
+validation error with more details.
+
+> ```bash
+> ### Example {wrap="true"}
+> $ python -m spacy init fill-config base.cfg config.cfg
+> ```
+
+```bash
+$ python -m spacy init fill-config [base_path] [output_file] [--diff]
+```
+
+| Argument | Type | Description |
+| -------------- | ---------- | ------------------------------------------------------------------------------------------------------------- |
+| `base_path` | positional | Path to base config to fill, e.g. generated by the [quickstart widget](/usage/training#quickstart). |
+| `output_file` | positional | Path to output `.cfg` file. If not set, the config is written to stdout so you can pipe it forward to a file. |
+| `--diff`, `-D` | flag | Print a visual diff highlighting the changes. |
+| `--help`, `-h` | flag | Show help message and available arguments. |
+| **CREATES** | file | Complete and auto-filled config file for training. |
### init model {#init-model new="2"}
diff --git a/website/docs/api/data-formats.md b/website/docs/api/data-formats.md
index 32633330e..6245c219f 100644
--- a/website/docs/api/data-formats.md
+++ b/website/docs/api/data-formats.md
@@ -20,8 +20,9 @@ Config files define the training process and model pipeline and can be passed to
[`spacy train`](/api/cli#train). They use
[Thinc's configuration system](https://thinc.ai/docs/usage-config) under the
hood. For details on how to use training configs, see the
-[usage documentation](/usage/training#config). To get started with a blank
-config or fill a partial config with all defaults, you can use the
+[usage documentation](/usage/training#config). To get started with the
+recommended settings for your use case, check out the
+[quickstart widget](/usage/training#quickstart) or run the
[`init config`](/api/cli#init-config) command.
> #### What does the @ mean?
diff --git a/website/docs/usage/training.md b/website/docs/usage/training.md
index a3a2e7102..fc1624ec1 100644
--- a/website/docs/usage/training.md
+++ b/website/docs/usage/training.md
@@ -37,27 +37,37 @@ The recommended way to train your spaCy models is via the
single [`config.cfg`](#config) **configuration file** that includes all settings
and hyperparameters. You can optionally [overwrite](#config-overrides)
settings on the command line, and load in a Python file to register
-[custom functions](#custom-code) and architectures.
+[custom functions](#custom-code) and architectures. This quickstart widget helps
+you generate a starter config with the **recommended settings** for your
+specific use case. It's also available in spaCy as the
+[`init config`](/api/cli#init-config) command.
-> #### Instructions
+> #### Instructions: widget
>
> 1. Select your requirements and settings.
> 2. Use the buttons at the bottom to save the result to your clipboard or a
> file `base_config.cfg`.
-> 3. Run [`init config`](/api/cli#init-config) to create a full training config.
+> 3. Run [`init fill-config`](/api/cli#init-fill-config) to create a full
+> config.
> 4. Run [`train`](/api/cli#train) with your config and data.
+>
+> #### Instructions: CLI
+>
+> 1. Run the [`init config`](/api/cli#init-config) command and specify your
+> requirements and settings as CLI arguments.
+> 2. Run [`train`](/api/cli#train) with the exported config and data.
import QuickstartTraining from 'widgets/quickstart-training.js'
After you've saved the starter config to a file `base_config.cfg`, you can use
-the [`init config`](/api/cli#init-config) command to fill in the remaining
-defaults. Training configs should always be **complete and without hidden
-defaults**, to keep your experiments reproducible.
+the [`init fill-config`](/api/cli#init-fill-config) command to fill in the
+remaining defaults. Training configs should always be **complete and without
+hidden defaults**, to keep your experiments reproducible.
```bash
-$ python -m spacy init config config.cfg --base base_config.cfg
+$ python -m spacy init fill-config base_config.cfg config.cfg
```
> #### Tip: Debug your data
@@ -70,10 +80,13 @@ $ python -m spacy init config config.cfg --base base_config.cfg
> $ python -m spacy debug data config.cfg --verbose
> ```
-You can now add your data and run [`train`](/api/cli#train) with your config.
-See the [`convert`](/api/cli#convert) command for details on how to convert your
-data to spaCy's binary `.spacy` format. You can either include the data paths in
-the `[paths]` section of your config, or pass them in via the command line.
+Instead of exporting your starter config from the quickstart widget and
+auto-filling it, you can also use the [`init config`](/api/cli#init-config)
+command and specify your requirements and settings as CLI arguments. You can now
+add your data and run [`train`](/api/cli#train) with your config. See the
+[`convert`](/api/cli#convert) command for details on how to convert your data to
+spaCy's binary `.spacy` format. You can either include the data paths in the
+`[paths]` section of your config, or pass them in via the command line.
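+
+For example, to generate the config directly on the CLI:
+
+```bash
+$ python -m spacy init config config.cfg --lang en --pipeline ner,textcat --optimize accuracy
+```
+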
```bash
$ python -m spacy train config.cfg --output ./output --paths.train ./train.spacy --paths.dev ./dev.spacy
@@ -601,7 +614,7 @@ settings in the block will be passed to the function as keyword arguments. Keep
in mind that the config shouldn't have any hidden defaults and all arguments on
the functions need to be represented in the config. If your function defines
**default argument values**, spaCy is able to auto-fill your config when you run
-[`init config`](/api/cli#init-config).
+[`init fill-config`](/api/cli#init-fill-config).
```ini
### config.cfg (excerpt)
diff --git a/website/docs/usage/transformers.md b/website/docs/usage/transformers.md
index e52417d13..c3130f57b 100644
--- a/website/docs/usage/transformers.md
+++ b/website/docs/usage/transformers.md
@@ -163,8 +163,9 @@ resolved, the function is created and passed into the model as an argument.
Remember that the `config.cfg` used for training should contain **no missing
values** and requires all settings to be defined. You don't want any hidden
defaults creeping in and changing your results! spaCy will tell you if settings
-are missing, and you can run [`spacy init config`](/api/cli#init-config) with to
-automatically fill in all defaults.
+are missing, and you can run
+[`spacy init fill-config`](/api/cli#init-fill-config) to automatically fill in
+all defaults.
diff --git a/website/docs/usage/v3.md b/website/docs/usage/v3.md
index 919af3ffb..a32f9cd86 100644
--- a/website/docs/usage/v3.md
+++ b/website/docs/usage/v3.md
@@ -152,7 +152,8 @@ The following methods, attributes and commands are new in spaCy v3.0.
| [`Language.config`](/api/language#config) | The [config](/usage/training#config) used to create the current `nlp` object. An instance of [`Config`](https://thinc.ai/docs/api-config#config) and can be saved to disk and used for training. |
| [`Pipe.score`](/api/pipe#score) | Method on trainable pipeline components that returns a dictionary of evaluation scores. |
| [`registry`](/api/top-level#registry) | Function registry to map functions to string names that can be referenced in [configs](/usage/training#config). |
-| [`init config`](/api/cli#init-config) | CLI command for initializing a [training config](/usage/training) file for a blank language or auto-filling a partial config. |
+| [`init config`](/api/cli#init-config) | CLI command for initializing a [training config](/usage/training) file with the recommended settings. |
+| [`init fill-config`](/api/cli#init-fill-config) | CLI command for auto-filling a partial config with all defaults and missing values. |
| [`debug config`](/api/cli#debug-config) | CLI command for debugging a [training config](/usage/training) file and showing validation errors. |
| [`project`](/api/cli#project) | Suite of CLI commands for cloning, running and managing [spaCy projects](/usage/projects). |
diff --git a/website/setup/jinja_to_js.py b/website/setup/jinja_to_js.py
index 459208d9b..a2c896151 100644
--- a/website/setup/jinja_to_js.py
+++ b/website/setup/jinja_to_js.py
@@ -1,4 +1,6 @@
# Forked from: https://github.com/jonbretman/jinja-to-js
+# With additional functionality: in/not in, replace, pprint, round, + for lists,
+# rendering empty dicts
# This script is mostly used to generate the JavaScript function for the
# training quickstart widget.
import contextlib
@@ -315,7 +317,7 @@ class JinjaToJS(object):
if callable(handler):
handler(node, **kwargs)
else:
- raise Exception("Unknown node %s" % node)
+ raise Exception(f"Unknown node {node} ({node_name})")
def _process_extends(self, node, **kwargs):
"""
@@ -431,6 +433,13 @@ class JinjaToJS(object):
self.output.write(node.name)
+ def _process_dict(self, node, **kwargs):
+ with self._interpolation():
+ with self._python_bool_wrapper(**kwargs):
+ if node.items:
+                    raise ValueError(f"Can't process non-empty dict in expression: {node}")
+ self.output.write("{}")
+
def _process_getattr(self, node, **kwargs):
"""
Processes a `GetAttr` node. e.g. {{ foo.bar }}
@@ -697,6 +706,27 @@ class JinjaToJS(object):
self._process_node(node.node, **new_kwargs)
self.output.write(")")
+ def _process_filter_replace(self, node, **kwargs):
+ # We're getting a quoted string from Python/Jinja as the pattern to
+ # replace, but to replace all occurrences in JS, we typically need a
+ # regex, which would be annoying to convert. So we're using split/join
+ # instead here.
+ with self._interpolation():
+ with self._python_bool_wrapper(**kwargs) as new_kwargs:
+ self._process_node(node.node, **new_kwargs)
+ self.output.write(".split(")
+ self._process_node(node.args[0], **new_kwargs)
+ self.output.write(").join(")
+ self._process_node(node.args[1], **new_kwargs)
+ self.output.write(")")
+
+ def _process_filter_pprint(self, node, **kwargs):
+ with self._interpolation():
+ with self._python_bool_wrapper(**kwargs) as new_kwargs:
+ self.output.write("JSON.stringify(")
+ self._process_node(node.node, **new_kwargs)
+ self.output.write(")")
+
def _process_filter_attr(self, node, **kwargs):
with self._interpolation():
with self._python_bool_wrapper(**kwargs) as new_kwargs:
@@ -746,7 +776,10 @@ class JinjaToJS(object):
with self._python_bool_wrapper(**kwargs) as new_kwargs:
self.output.write("Math.round((")
self._process_node(node.node, **new_kwargs)
- self.output.write("+ Number.EPSILON) * 100) / 100")
+ self.output.write("+ Number.EPSILON) * 10**")
+ self._process_node(node.args[0], **new_kwargs)
+ self.output.write(") / 10**")
+ self._process_node(node.args[0], **new_kwargs)
def _process_filter_last(self, node, **kwargs):
with self._interpolation():
@@ -867,8 +900,10 @@ class JinjaToJS(object):
)
with option(kwargs, use_python_bool_wrapper=False):
- if operand.op == "in":
+ if operand.op == "in" or operand.op == "notin":
# Special case for "in" operator
+ if operand.op == "notin":
+ self.output.write("!")
self._process_node(operand.expr, **kwargs)
self.output.write(".includes(")
self._process_node(node.expr, **kwargs)
@@ -1027,7 +1062,18 @@ class JinjaToJS(object):
self.output.write(")")
def _process_add(self, node, **kwargs):
- self._process_math(node, math_operator=" + ", **kwargs)
+ # Handle + operator for lists, which behaves differently in JS. Currently
+ # only works if we have an explicit list node on either side (in which
+ # case we assume both are lists).
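+ # e.g. ["tok2vec"] + components compiles to ["tok2vec"].concat(components)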
+ if isinstance(node.left, nodes.List) or isinstance(node.right, nodes.List):
+ with self._interpolation():
+ with self._python_bool_wrapper(**kwargs) as new_kwargs:
+ self._process_node(node.left, **new_kwargs)
+ self.output.write(".concat(")
+ self._process_node(node.right, **new_kwargs)
+ self.output.write(")")
+ else:
+ self._process_math(node, math_operator=" + ", **kwargs)
def _process_sub(self, node, **kwargs):
self._process_math(node, math_operator=" - ", **kwargs)
@@ -1190,16 +1236,22 @@ def main(
# fmt: off
template_path: Path = typer.Argument(..., exists=True, dir_okay=False, help="Path to .jinja file"),
output: Path = typer.Argument(None, help="Path to output module (stdout if unset)"),
+ data_path: Path = typer.Option(None, "--data", help="Optional JSON file with additional data to be included as DATA")
# fmt: on
):
"""Convert a jinja2 template to a JavaScript module."""
- compiler = JinjaToJS(
- template_path.parent, template_path.parts[-1], js_module_format="es6"
- )
+ data = "{}"
+ if data_path is not None:
+ with data_path.open("r", encoding="utf8") as f:
+ data = json.dumps(json.loads(f.read())) # load and re-dump for compactness
+ tpl_file = template_path.parts[-1]
+ compiler = JinjaToJS(template_path.parent, tpl_file, js_module_format="es6")
+ header = f"// This file was auto-generated by {__file__} based on {tpl_file}"
+ data_str = f"export const DATA = {data}"
result = compiler.get_output()
if output is not None:
with output.open("w") as f:
- f.write(result)
+ f.write(f"{header}\n{result}\n{data_str}")
print(f"Updated {output.parts[-1]}")
else:
print(result)
diff --git a/website/setup/quickstart_training_cpu.jinja b/website/setup/quickstart_training_cpu.jinja
deleted file mode 100644
index 2bfb80cdd..000000000
--- a/website/setup/quickstart_training_cpu.jinja
+++ /dev/null
@@ -1,107 +0,0 @@
-{# Template for "CPU" configs. The transformer will use a different template. #}
-# This is an auto-generated partial config for training a model.
-# To use it for training, auto-fill it with all default values.
-# python -m spacy init config config.cfg --base base_config.cfg
-[paths]
-train = ""
-dev = ""
-
-[nlp]
-lang = "{{ lang }}"
-pipeline = {{ pipeline|safe }}
-vectors = {{ ('"en_vectors_web_lg"' if optimize == "accuracy" else false)|safe }}
-tokenizer = {"@tokenizers": "spacy.Tokenizer.v1"}
-
-[components]
-
-[components.tok2vec]
-factory = "tok2vec"
-
-[components.tok2vec.model]
-@architectures = "spacy.Tok2Vec.v1"
-
-[components.tok2vec.model.embed]
-@architectures = "spacy.MultiHashEmbed.v1"
-width = ${components.tok2vec.model.encode:width}
-rows = {{ 2000 if optimize == "efficiency" else 7000 }}
-also_embed_subwords = {{ true if has_letters else false }}
-also_use_static_vectors = {{ true if optimize == "accuracy" else false }}
-
-[components.tok2vec.model.encode]
-@architectures = "spacy.MaxoutWindowEncoder.v1"
-width = {{ 96 if optimize == "efficiency" else 256 }}
-depth = {{ 4 if optimize == "efficiency" else 8 }}
-window_size = 1
-maxout_pieces = 3
-
-{% if "tagger" in components %}
-[components.tagger]
-factory = "tagger"
-
-[components.tagger.model]
-@architectures = "spacy.Tagger.v1"
-nO = null
-
-[components.tagger.model.tok2vec]
-@architectures = "spacy.Tok2VecListener.v1"
-width = ${components.tok2vec.model.encode:width}
-{%- endif %}
-
-{% if "parser" in components -%}
-[components.parser]
-factory = "parser"
-
-[components.parser.model]
-@architectures = "spacy.TransitionBasedParser.v1"
-nr_feature_tokens = 8
-hidden_width = 128
-maxout_pieces = 3
-use_upper = true
-nO = null
-
-[components.parser.model.tok2vec]
-@architectures = "spacy.Tok2VecListener.v1"
-width = ${components.tok2vec.model.encode:width}
-{%- endif %}
-
-{% if "ner" in components -%}
-[components.ner]
-factory = "ner"
-
-[components.ner.model]
-@architectures = "spacy.TransitionBasedParser.v1"
-nr_feature_tokens = 6
-hidden_width = 64
-maxout_pieces = 2
-use_upper = true
-nO = null
-
-[components.ner.model.tok2vec]
-@architectures = "spacy.Tok2VecListener.v1"
-width = ${components.tok2vec.model.encode:width}
-{% endif -%}
-
-[training]
-
-[training.train_corpus]
-@readers = "spacy.Corpus.v1"
-path = ${paths:train}
-
-[training.dev_corpus]
-@readers = "spacy.Corpus.v1"
-path = ${paths:dev}
-
-[training.score_weights]
-{%- if "tagger" in components %}
-tag_acc = {{ (1.0 / components|length)|round() }}
-{%- endif -%}
-{%- if "parser" in components %}
-dep_uas = 0.0
-dep_las = {{ (1.0 / components|length)|round() }}
-sents_f = 0.0
-{%- endif %}
-{%- if "ner" in components %}
-ents_f = {{ (1.0 / components|length)|round() }}
-ents_p = 0.0
-ents_r = 0.0
-{%- endif -%}
diff --git a/website/setup/quickstart_training_gpu.jinja b/website/setup/quickstart_training_gpu.jinja
deleted file mode 100644
index 989af980a..000000000
--- a/website/setup/quickstart_training_gpu.jinja
+++ /dev/null
@@ -1,139 +0,0 @@
-{# Template for "CPU" configs. The transformer will use a different template. #}
-# This is an auto-generated partial config for training a model.
-# To use it for training, auto-fill it with all default values.
-# python -m spacy init config config.cfg --base base_config.cfg
-[paths]
-train = ""
-dev = ""
-
-[nlp]
-lang = "{{ lang }}"
-pipeline = {{ pipeline|safe }}
-vectors = null
-tokenizer = {"@tokenizers": "spacy.Tokenizer.v1"}
-
-[components]
-
-[components.transformer]
-factory = "transformer"
-
-[components.transformer.model]
-@architectures = "spacy-transformers.TransformerModel.v1"
-{#- name = {{ transformer_info["name"] }} #}
-name = "roberta-base"
-tokenizer_config = {"use_fast": true}
-
-[components.transformer.model.get_spans]
-@span_getters = "strided_spans.v1"
-window = 128
-stride = 96
-
-{% if "tagger" in components %}
-[components.tagger]
-factory = "tagger"
-
-[components.tagger.model]
-@architectures = "spacy.Tagger.v1"
-nO = null
-
-[components.tagger.model.tok2vec]
-@architectures = "spacy-transformers.TransformerListener.v1"
-grad_factor = 1.0
-
-[components.ner.model.tok2vec.pooling]
-@layers = "reduce_mean.v1"
-{%- endif %}
-
-{% if "parser" in components -%}
-[components.parser]
-factory = "parser"
-
-[components.parser.model]
-@architectures = "spacy.TransitionBasedParser.v1"
-nr_feature_tokens = 8
-hidden_width = 128
-maxout_pieces = 3
-use_upper = false
-nO = null
-
-[components.parser.model.tok2vec]
-@architectures = "spacy-transformers.TransformerListener.v1"
-grad_factor = 1.0
-
-[components.ner.model.tok2vec.pooling]
-@layers = "reduce_mean.v1"
-{%- endif %}
-
-{% if "ner" in components -%}
-[components.ner]
-factory = "ner"
-
-[components.ner.model]
-@architectures = "spacy.TransitionBasedParser.v1"
-nr_feature_tokens = 3
-hidden_width = 64
-maxout_pieces = 2
-use_upper = false
-nO = null
-
-[components.ner.model.tok2vec]
-@architectures = "spacy-transformers.TransformerListener.v1"
-grad_factor = 1.0
-
-[components.parser.model.tok2vec.pooling]
-@layers = "reduce_mean.v1"
-{% endif -%}
-
-[training]
-{#- accumulate_gradient = {{ transformer_info["size_factor"] }} #}
-accumulate_gradient = 3
-
-[training.optimizer]
-@optimizers = "Adam.v1"
-beta1 = 0.9
-beta2 = 0.999
-L2_is_weight_decay = true
-L2 = 0.01
-grad_clip = 1.0
-use_averages = false
-eps = 1e-8
-
-[training.optimizer.learn_rate]
-@schedules = "warmup_linear.v1"
-warmup_steps = 250
-total_steps = 20000
-initial_rate = 5e-5
-
-[training.train_corpus]
-@readers = "spacy.Corpus.v1"
-path = ${paths:train}
-gold_preproc = false
-max_length = 500
-limit = 0
-
-[training.dev_corpus]
-@readers = "spacy.Corpus.v1"
-path = ${paths:dev}
-gold_preproc = false
-max_length = 0
-limit = 0
-
-[training.batcher]
-@batchers = "batch_by_padded.v1"
-discard_oversize = true
-batch_size = 2000
-
-[training.score_weights]
-{%- if "tagger" in components %}
-tag_acc = {{ (1.0 / components|length)|round(2) }}
-{%- endif -%}
-{%- if "parser" in components %}
-dep_uas = 0.0
-dep_las = {{ (1.0 / components|length)|round(2) }}
-sents_f = 0.0
-{%- endif %}
-{%- if "ner" in components %}
-ents_f = {{ (1.0 / components|length)|round(2) }}
-ents_p = 0.0
-ents_r = 0.0
-{%- endif -%}
diff --git a/website/setup/setup.sh b/website/setup/setup.sh
index 1d0e4f9bf..a6bbd3294 100755
--- a/website/setup/setup.sh
+++ b/website/setup/setup.sh
@@ -1 +1 @@
-python jinja_to_js.py quickstart_training_cpu.jinja ../src/widgets/quickstart-training-generator.js
+python jinja_to_js.py ../../spacy/cli/templates/quickstart_training.jinja ../src/widgets/quickstart-training-generator.js --data ../../spacy/cli/templates/quickstart_training_recommendations.json
diff --git a/website/src/styles/quickstart.module.sass b/website/src/styles/quickstart.module.sass
index 9ea112a45..91dd19f85 100644
--- a/website/src/styles/quickstart.module.sass
+++ b/website/src/styles/quickstart.module.sass
@@ -125,9 +125,9 @@
display: block
.small
- font-size: var(--font-size-sm)
+ font-size: var(--font-size-code)
line-height: 1.65
- white-space: pre
+ white-space: pre-wrap
max-height: 400px
overflow-y: auto
diff --git a/website/src/widgets/quickstart-training-generator.js b/website/src/widgets/quickstart-training-generator.js
index a85c72129..c7f856073 100644
--- a/website/src/widgets/quickstart-training-generator.js
+++ b/website/src/widgets/quickstart-training-generator.js
@@ -1,10 +1,12 @@
-import jinjaToJS from "jinja-to-js";export default function templateQuickstartTrainingCpu(ctx) {
+// This file was auto-generated by jinja_to_js.py based on quickstart_training.jinja
+import jinjaToJS from "jinja-to-js";export default function templateQuickstartTraining(ctx) {
var __result = "";
var __tmp;
var __runtime = jinjaToJS.runtime;
var __filters = jinjaToJS.filters;
var __globals = jinjaToJS.globals;
var context = jinjaToJS.createContext(ctx);
- __result += "\n# This is an auto-generated partial config for training a model.\n# To use it for training, auto-fill it with all default values.\n# python -m spacy init config config.cfg --base base_config.cfg\n[paths]\ntrain = \"\"\ndev = \"\"\n\n[nlp]\nlang = \"";__result += "" + __runtime.escape((__tmp = (context.lang)) == null ? "" : __tmp);__result += "\"\npipeline = ";__result += "" + ((__tmp = (context.pipeline)) == null ? "" : __tmp);__result += "\nvectors = ";__result += "" + ((__tmp = ((context.optimize==="accuracy" ? "\"en_vectors_web_lg\"" : false))) == null ? "" : __tmp);__result += "\ntokenizer = {\"@tokenizers\": \"spacy.Tokenizer.v1\"}\n\n[components]\n\n[components.tok2vec]\nfactory = \"tok2vec\"\n\n[components.tok2vec.model]\n@architectures = \"spacy.Tok2Vec.v1\"\n\n[components.tok2vec.model.embed]\n@architectures = \"spacy.MultiHashEmbed.v1\"\nwidth = ${components.tok2vec.model.encode:width}\nrows = ";__result += "" + __runtime.escape((__tmp = ((context.optimize==="efficiency" ? 2000 : 7000))) == null ? "" : __tmp);__result += "\nalso_embed_subwords = ";__result += "" + __runtime.escape((__tmp = ((context.has_letters ? true : false))) == null ? "" : __tmp);__result += "\nalso_use_static_vectors = ";__result += "" + __runtime.escape((__tmp = ((context.optimize==="accuracy" ? true : false))) == null ? "" : __tmp);__result += "\n\n[components.tok2vec.model.encode]\n@architectures = \"spacy.MaxoutWindowEncoder.v1\"\nwidth = ";__result += "" + __runtime.escape((__tmp = ((context.optimize==="efficiency" ? 96 : 256))) == null ? "" : __tmp);__result += "\ndepth = ";__result += "" + __runtime.escape((__tmp = ((context.optimize==="efficiency" ? 4 : 8))) == null ? "" : __tmp);__result += "\nwindow_size = 1\nmaxout_pieces = 3\n\n";if(context.components.includes("tagger")){__result += "\n[components.tagger]\nfactory = \"tagger\"\n\n[components.tagger.model]\n@architectures = \"spacy.Tagger.v1\"\nnO = null\n\n[components.tagger.model.tok2vec]\n@architectures = \"spacy.Tok2VecListener.v1\"\nwidth = ${components.tok2vec.model.encode:width}";}__result += "\n\n";if(context.components.includes("parser")){__result += "[components.parser]\nfactory = \"parser\"\n\n[components.parser.model]\n@architectures = \"spacy.TransitionBasedParser.v1\"\nnr_feature_tokens = 8\nhidden_width = 128\nmaxout_pieces = 3\nuse_upper = true\nnO = null\n\n[components.parser.model.tok2vec]\n@architectures = \"spacy.Tok2VecListener.v1\"\nwidth = ${components.tok2vec.model.encode:width}";}__result += "\n\n";if(context.components.includes("ner")){__result += "[components.ner]\nfactory = \"ner\"\n\n[components.ner.model]\n@architectures = \"spacy.TransitionBasedParser.v1\"\nnr_feature_tokens = 6\nhidden_width = 64\nmaxout_pieces = 2\nuse_upper = true\nnO = null\n\n[components.parser.model.tok2vec]\n@architectures = \"spacy.Tok2VecListener.v1\"\nwidth = ${components.tok2vec.model.encode:width}\n";}__result += "[training]\n\n[training.train_corpus]\n@readers = \"spacy.Corpus.v1\"\npath = ${paths:train}\n\n[training.dev_corpus]\n@readers = \"spacy.Corpus.v1\"\npath = ${paths:dev}\n\n[training.score_weights]";if(context.components.includes("tagger")){__result += "\ntag_acc = ";__result += "" + __runtime.escape((__tmp = (Math.round((1.0 / __filters.size(context.components)+ Number.EPSILON) * 100) / 100)) == null ? 
"" : __tmp);}if(context.components.includes("parser")){__result += "\ndep_uas = 0.0\ndep_las = ";__result += "" + __runtime.escape((__tmp = (Math.round((1.0 / __filters.size(context.components)+ Number.EPSILON) * 100) / 100)) == null ? "" : __tmp);__result += "\nsents_f = 0.0";}if(context.components.includes("ner")){__result += "\nents_f = ";__result += "" + __runtime.escape((__tmp = (Math.round((1.0 / __filters.size(context.components)+ Number.EPSILON) * 100) / 100)) == null ? "" : __tmp);__result += "\nents_p = 0.0\nents_r = 0.0";}
+ var use_transformer = context.transformer_data && context.hardware!=="cpu";var transformer = (use_transformer ? context.transformer_data[context.optimize] : {});__result += "[paths]\ntrain = \"\"\ndev = \"\"\n\n[system]\nuse_pytorch_for_gpu_memory = ";__result += "" + __runtime.escape((__tmp = ((use_transformer ? "true" : "false"))) == null ? "" : __tmp);__result += "\n\n[nlp]\nlang = \"";__result += "" + __runtime.escape((__tmp = (context.lang)) == null ? "" : __tmp);__result += "\"";var full_pipeline = [(use_transformer ? "transformer" : "tok2vec")].concat(context.components);__result += "\npipeline = ";__result += "" + ((__tmp = (JSON.stringify(full_pipeline).split("'").join("\""))) == null ? "" : __tmp);__result += "\ntokenizer = {\"@tokenizers\": \"spacy.Tokenizer.v1\"}\n\n[components]\n\n";if(__runtime.boolean(use_transformer)){__result += "[components.transformer]\nfactory = \"transformer\"\n\n[components.transformer.model]\n@architectures = \"spacy-transformers.TransformerModel.v1\"\nname = \"";__result += "" + __runtime.escape((__tmp = (transformer["name"])) == null ? "" : __tmp);__result += "\"\ntokenizer_config = {\"use_fast\": true}\n\n[components.transformer.model.get_spans]\n@span_getters = \"strided_spans.v1\"\nwindow = 128\nstride = 96\n\n";if(context.components.includes("tagger")){__result += "\n[components.tagger]\nfactory = \"tagger\"\n\n[components.tagger.model]\n@architectures = \"spacy.Tagger.v1\"\nnO = null\n\n[components.tagger.model.tok2vec]\n@architectures = \"spacy-transformers.Tok2VecListener.v1\"\ngrad_factor = 1.0\n\n[components.tagger.model.tok2vec.pooling]\n@layers = \"reduce_mean.v1\"";}__result += "\n\n";if(context.components.includes("parser")){__result += "[components.parser]\nfactory = \"parser\"\n\n[components.parser.model]\n@architectures = \"spacy.TransitionBasedParser.v1\"\nnr_feature_tokens = 8\nhidden_width = 128\nmaxout_pieces = 3\nuse_upper = false\nnO = null\n\n[components.parser.model.tok2vec]\n@architectures = \"spacy-transformers.Tok2VecListener.v1\"\ngrad_factor = 1.0\n\n[components.parser.model.tok2vec.pooling]\n@layers = \"reduce_mean.v1\"";}__result += "\n\n";if(context.components.includes("ner")){__result += "[components.ner]\nfactory = \"ner\"\n\n[components.ner.model]\n@architectures = \"spacy.TransitionBasedParser.v1\"\nnr_feature_tokens = 3\nhidden_width = 64\nmaxout_pieces = 2\nuse_upper = false\nnO = null\n\n[components.ner.model.tok2vec]\n@architectures = \"spacy-transformers.Tok2VecListener.v1\"\ngrad_factor = 1.0\n\n[components.ner.model.tok2vec.pooling]\n@layers = \"reduce_mean.v1\"\n";}__result += "\n";} else {if(context.hardware==="gpu"){__result += "# There are no recommended transformer weights available for language '";__result += "" + __runtime.escape((__tmp = (context.lang)) == null ? "" : __tmp);__result += "'\n# yet, so the pipeline described here is not transformer-based.";}__result += "\n\n[components.tok2vec]\nfactory = \"tok2vec\"\n\n[components.tok2vec.model]\n@architectures = \"spacy.Tok2Vec.v1\"\n\n[components.tok2vec.model.embed]\n@architectures = \"spacy.MultiHashEmbed.v1\"\nwidth = ${components.tok2vec.model.encode:width}\nrows = ";__result += "" + __runtime.escape((__tmp = ((context.optimize==="efficiency" ? 2000 : 7000))) == null ? "" : __tmp);__result += "\nalso_embed_subwords = ";__result += "" + __runtime.escape((__tmp = ((context.has_letters ? true : false))) == null ? "" : __tmp);__result += "\nalso_use_static_vectors = ";__result += "" + __runtime.escape((__tmp = ((context.optimize==="accuracy" ? 
true : false))) == null ? "" : __tmp);__result += "\n\n[components.tok2vec.model.encode]\n@architectures = \"spacy.MaxoutWindowEncoder.v1\"\nwidth = ";__result += "" + __runtime.escape((__tmp = ((context.optimize==="efficiency" ? 96 : 256))) == null ? "" : __tmp);__result += "\ndepth = ";__result += "" + __runtime.escape((__tmp = ((context.optimize==="efficiency" ? 4 : 8))) == null ? "" : __tmp);__result += "\nwindow_size = 1\nmaxout_pieces = 3\n\n";if(context.components.includes("tagger")){__result += "\n[components.tagger]\nfactory = \"tagger\"\n\n[components.tagger.model]\n@architectures = \"spacy.Tagger.v1\"\nnO = null\n\n[components.tagger.model.tok2vec]\n@architectures = \"spacy.Tok2VecListener.v1\"\nwidth = ${components.tok2vec.model.encode:width}";}__result += "\n\n";if(context.components.includes("parser")){__result += "[components.parser]\nfactory = \"parser\"\n\n[components.parser.model]\n@architectures = \"spacy.TransitionBasedParser.v1\"\nnr_feature_tokens = 8\nhidden_width = 128\nmaxout_pieces = 3\nuse_upper = true\nnO = null\n\n[components.parser.model.tok2vec]\n@architectures = \"spacy.Tok2VecListener.v1\"\nwidth = ${components.tok2vec.model.encode:width}";}__result += "\n\n";if(context.components.includes("ner")){__result += "\n[components.ner]\nfactory = \"ner\"\n\n[components.ner.model]\n@architectures = \"spacy.TransitionBasedParser.v1\"\nnr_feature_tokens = 6\nhidden_width = 64\nmaxout_pieces = 2\nuse_upper = true\nnO = null\n\n[components.ner.model.tok2vec]\n@architectures = \"spacy.Tok2VecListener.v1\"\nwidth = ${components.tok2vec.model.encode:width}\n";}__result += "\n";}__result += "\n\n";__runtime.each(context.components,function(pipe){var __$0 = context.pipe;context.pipe = pipe;__result += "\n";if(!["tagger","parser","ner"].includes(pipe)){__result += "\n";__result += "\n[components.";__result += "" + __runtime.escape((__tmp = (pipe)) == null ? "" : __tmp);__result += "]\nfactory = \"";__result += "" + __runtime.escape((__tmp = (pipe)) == null ? "" : __tmp);__result += "\"\n";}__result += "\n";context.pipe = __$0;});__result += "\n\n[training]\n";if(__runtime.boolean(use_transformer) || context.optimize==="efficiency" || !__runtime.boolean(context.word_vectors)){__result += "vectors = null\n";} else {__result += "vectors = \"";__result += "" + __runtime.escape((__tmp = (context.word_vectors)) == null ? "" : __tmp);__result += "\"\n";}if(__runtime.boolean(use_transformer)){__result += "accumulate_gradient = ";__result += "" + __runtime.escape((__tmp = (transformer["size_factor"])) == null ? "" : __tmp);__result += "\n";}__result += "\n\n[training.optimizer]\n@optimizers = \"Adam.v1\"\n\n[training.optimizer.learn_rate]\n@schedules = \"warmup_linear.v1\"\nwarmup_steps = 250\ntotal_steps = 20000\ninitial_rate = 5e-5\n\n[training.train_corpus]\n@readers = \"spacy.Corpus.v1\"\npath = ${paths:train}\nmax_length = ";__result += "" + __runtime.escape((__tmp = ((context.hardware==="gpu" ? 500 : 0))) == null ? 
"" : __tmp);__result += "\n\n[training.dev_corpus]\n@readers = \"spacy.Corpus.v1\"\npath = ${paths:dev}\nmax_length = 0\n\n";if(__runtime.boolean(use_transformer)){__result += "\n[training.batcher]\n@batchers = \"batch_by_padded.v1\"\ndiscard_oversize = true\nsize = 2000\nbuffer = 256";} else {__result += "\n[training.batcher]\n@batchers = \"batch_by_words.v1\"\ndiscard_oversize = false\ntolerance = 0.2\n\n[training.batcher.size]\n@schedules = \"compounding.v1\"\nstart = 100\nstop = 1000\ncompound = 1.001\n";}__result += "\n\n[training.score_weights]";if(context.components.includes("tagger")){__result += "\ntag_acc = ";__result += "" + __runtime.escape((__tmp = (Math.round((1.0 / __filters.size(context.components)+ Number.EPSILON) * 10**2) / 10**2)) == null ? "" : __tmp);}if(context.components.includes("parser")){__result += "\ndep_uas = 0.0\ndep_las = ";__result += "" + __runtime.escape((__tmp = (Math.round((1.0 / __filters.size(context.components)+ Number.EPSILON) * 10**2) / 10**2)) == null ? "" : __tmp);__result += "\nsents_f = 0.0";}if(context.components.includes("ner")){__result += "\nents_f = ";__result += "" + __runtime.escape((__tmp = (Math.round((1.0 / __filters.size(context.components)+ Number.EPSILON) * 10**2) / 10**2)) == null ? "" : __tmp);__result += "\nents_p = 0.0\nents_r = 0.0";}
return __result;
-}
\ No newline at end of file
+}
+export const DATA = {"en": {"word_vectors": "en_vectors_web_lg", "transformer": {"efficiency": {"name": "roberta-base", "size_factor": 3}, "accuracy": {"name": "roberta-base", "size_factor": 3}}}, "de": {"word_vectors": null, "transformer": null}}
\ No newline at end of file
diff --git a/website/src/widgets/quickstart-training.js b/website/src/widgets/quickstart-training.js
index 46bfd9dc1..4e379e5ec 100644
--- a/website/src/widgets/quickstart-training.js
+++ b/website/src/widgets/quickstart-training.js
@@ -2,14 +2,17 @@ import React, { useState } from 'react'
import { StaticQuery, graphql } from 'gatsby'
import highlightCode from 'gatsby-remark-prismjs/highlight-code.js'
-import { Quickstart, QS } from '../components/quickstart'
-import generator from './quickstart-training-generator'
+import { Quickstart } from '../components/quickstart'
+import generator, { DATA as GENERATOR_DATA } from './quickstart-training-generator'
import { isString, htmlToReact } from '../components/util'
const DEFAULT_LANG = 'en'
const DEFAULT_HARDWARE = 'gpu'
const DEFAULT_OPT = 'efficiency'
const COMPONENTS = ['tagger', 'parser', 'ner', 'textcat']
+const COMMENT = `# This is an auto-generated partial config. To use it with 'spacy train'
+# you can run spacy init fill-config to auto-fill all default settings:
+# python -m spacy init fill-config ./base_config.cfg ./config.cfg`
const DATA = [
{
@@ -61,14 +64,17 @@ export default function QuickstartTraining({ id, title, download = 'config.cfg'
hardware: setHardware,
optimize: setOptimize,
}
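+ // DATA from the generated module maps language codes to recommended
+ // settings, e.g. GENERATOR_DATA.en.word_vectors === 'en_vectors_web_lg'.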
+ const reco = GENERATOR_DATA[lang] || {}
const content = generator({
lang,
- pipeline: stringify(components),
components,
optimize,
hardware,
+ transformer_data: reco.transformer,
+ word_vectors: reco.word_vectors,
})
- const rawContent = content.trim().replace(/\n\n\n+/g, '\n\n')
+ const rawStr = content.trim().replace(/\n\n\n+/g, '\n\n')
+ const rawContent = `${COMMENT}\n${rawStr}`
const displayContent = highlightCode('ini', rawContent)
.split('\n')
.map(line => (line.startsWith('#') ? `` : line))