Merge pull request #5930 from svlandeg/feature/init-config-fix

UX for init config
This commit is contained in:
Ines Montani 2020-08-21 12:06:33 +02:00 committed by GitHub
commit 3826cfb8fe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 12 additions and 9 deletions

View File

@ -24,7 +24,7 @@ class Optimizations(str, Enum):
@init_cli.command("config")
def init_config_cli(
# fmt: off
output_file: Path = Arg("-", help="File to save config.cfg to (or - for stdout)", allow_dash=True),
output_file: Path = Arg(..., help="File to save config.cfg to or - for stdout (will only output config and no additional logging info)", allow_dash=True),
lang: Optional[str] = Opt("en", "--lang", "-l", help="Two-letter code of the language to use"),
pipeline: Optional[str] = Opt("tagger,parser,ner", "--pipeline", "-p", help="Comma-separated names of trainable pipeline components to include in the model (without 'tok2vec' or 'transformer')"),
optimize: Optimizations = Opt(Optimizations.efficiency.value, "--optimize", "-o", help="Whether to optimize for efficiency (faster inference, smaller model, lower memory consumption) or higher accuracy (potentially larger and slower model). This will impact the choice of architecture, pretrained weights and related hyperparameters."),
@ -110,6 +110,13 @@ def init_config(
"word_vectors": reco["word_vectors"],
"has_letters": reco["has_letters"],
}
if variables["transformer_data"] and not has_spacy_transformers():
msg.warn(
"To generate a more effective transformer-based config (GPU-only), "
"install the spacy-transformers package and re-run this command. "
"The config generated now does not use transformers."
)
variables["transformer_data"] = None
base_template = template.render(variables).strip()
# Giving up on getting the newlines right in jinja for now
base_template = re.sub(r"\n\n\n+", "\n\n", base_template)
@ -126,8 +133,6 @@ def init_config(
for label, value in use_case.items():
msg.text(f"- {label}: {value}")
use_transformer = bool(template_vars.use_transformer)
if use_transformer:
require_spacy_transformers(msg)
with show_validation_error(hint_fill=False):
config = util.load_config_from_str(base_template)
nlp, _ = util.load_model_from_config(config, auto_fill=True)
@ -149,12 +154,10 @@ def save_config(config: Config, output_file: Path, is_stdout: bool = False) -> N
print(f"{COMMAND} train {output_file.parts[-1]} {' '.join(variables)}")
def require_spacy_transformers(msg: Printer) -> None:
def has_spacy_transformers() -> bool:
try:
import spacy_transformers # noqa: F401
return True
except ImportError:
msg.fail(
"Using a transformer-based pipeline requires spacy-transformers "
"to be installed.",
exits=1,
)
return False