Fix debug data [ci skip]

This commit is contained in:
Ines Montani 2020-09-29 23:07:11 +02:00
parent a2aa1f6882
commit 9bb958fd0a
2 changed files with 14 additions and 43 deletions

View File

@ -8,12 +8,12 @@ import typer
from ._util import app, Arg, Opt, show_validation_error, parse_config_overrides from ._util import app, Arg, Opt, show_validation_error, parse_config_overrides
from ._util import import_code, debug_cli from ._util import import_code, debug_cli
from ..training import Corpus, Example from ..training import Example
from ..training.initialize import get_sourced_components from ..training.initialize import get_sourced_components
from ..schemas import ConfigSchemaTraining from ..schemas import ConfigSchemaTraining
from ..pipeline._parser_internals import nonproj from ..pipeline._parser_internals import nonproj
from ..language import Language from ..language import Language
from ..util import registry from ..util import registry, resolve_dot_names
from .. import util from .. import util
@ -37,8 +37,6 @@ BLANK_MODEL_THRESHOLD = 2000
def debug_data_cli( def debug_data_cli(
# fmt: off # fmt: off
ctx: typer.Context, # This is only used to read additional arguments ctx: typer.Context, # This is only used to read additional arguments
train_path: Path = Arg(..., help="Location of JSON-formatted training data", exists=True),
dev_path: Path = Arg(..., help="Location of JSON-formatted development data", exists=True),
config_path: Path = Arg(..., help="Path to config file", exists=True), config_path: Path = Arg(..., help="Path to config file", exists=True),
code_path: Optional[Path] = Opt(None, "--code-path", "-c", help="Path to Python file with additional code (registered functions) to be imported"), code_path: Optional[Path] = Opt(None, "--code-path", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
ignore_warnings: bool = Opt(False, "--ignore-warnings", "-IW", help="Ignore warnings, only show stats and errors"), ignore_warnings: bool = Opt(False, "--ignore-warnings", "-IW", help="Ignore warnings, only show stats and errors"),
@ -62,8 +60,6 @@ def debug_data_cli(
overrides = parse_config_overrides(ctx.args) overrides = parse_config_overrides(ctx.args)
import_code(code_path) import_code(code_path)
debug_data( debug_data(
train_path,
dev_path,
config_path, config_path,
config_overrides=overrides, config_overrides=overrides,
ignore_warnings=ignore_warnings, ignore_warnings=ignore_warnings,
@ -74,8 +70,6 @@ def debug_data_cli(
def debug_data( def debug_data(
train_path: Path,
dev_path: Path,
config_path: Path, config_path: Path,
*, *,
config_overrides: Dict[str, Any] = {}, config_overrides: Dict[str, Any] = {},
@ -88,18 +82,11 @@ def debug_data(
no_print=silent, pretty=not no_format, ignore_warnings=ignore_warnings no_print=silent, pretty=not no_format, ignore_warnings=ignore_warnings
) )
# Make sure all files and paths exist if they are needed # Make sure all files and paths exist if they are needed
if not train_path.exists():
msg.fail("Training data not found", train_path, exits=1)
if not dev_path.exists():
msg.fail("Development data not found", dev_path, exits=1)
if not config_path.exists():
msg.fail("Config file not found", config_path, exists=1)
with show_validation_error(config_path): with show_validation_error(config_path):
cfg = util.load_config(config_path, overrides=config_overrides) cfg = util.load_config(config_path, overrides=config_overrides)
nlp = util.load_model_from_config(cfg) nlp = util.load_model_from_config(cfg)
T = registry.resolve( config = nlp.config.interpolate()
nlp.config.interpolate()["training"], schema=ConfigSchemaTraining T = registry.resolve(config["training"], schema=ConfigSchemaTraining)
)
# Use original config here, not resolved version # Use original config here, not resolved version
sourced_components = get_sourced_components(cfg) sourced_components = get_sourced_components(cfg)
frozen_components = T["frozen_components"] frozen_components = T["frozen_components"]
@ -109,25 +96,15 @@ def debug_data(
msg.divider("Data file validation") msg.divider("Data file validation")
# Create the gold corpus to be able to better analyze data # Create the gold corpus to be able to better analyze data
loading_train_error_message = "" dot_names = [T["train_corpus"], T["dev_corpus"]]
loading_dev_error_message = "" train_corpus, dev_corpus = resolve_dot_names(config, dot_names)
with msg.loading("Loading corpus..."): train_dataset = list(train_corpus(nlp))
try: dev_dataset = list(dev_corpus(nlp))
train_dataset = list(Corpus(train_path)(nlp))
except ValueError as e:
loading_train_error_message = f"Training data cannot be loaded: {e}"
try:
dev_dataset = list(Corpus(dev_path)(nlp))
except ValueError as e:
loading_dev_error_message = f"Development data cannot be loaded: {e}"
if loading_train_error_message or loading_dev_error_message:
if loading_train_error_message:
msg.fail(loading_train_error_message)
if loading_dev_error_message:
msg.fail(loading_dev_error_message)
sys.exit(1)
msg.good("Corpus is loadable") msg.good("Corpus is loadable")
nlp.initialize(lambda: train_dataset)
msg.good("Pipeline can be initialized with data")
# Create all gold data here to avoid iterating over the train_dataset constantly # Create all gold data here to avoid iterating over the train_dataset constantly
gold_train_data = _compile_gold(train_dataset, factory_names, nlp, make_proj=True) gold_train_data = _compile_gold(train_dataset, factory_names, nlp, make_proj=True)
gold_train_unpreprocessed_data = _compile_gold( gold_train_unpreprocessed_data = _compile_gold(
@ -348,17 +325,11 @@ def debug_data(
msg.divider("Part-of-speech Tagging") msg.divider("Part-of-speech Tagging")
labels = [label for label in gold_train_data["tags"]] labels = [label for label in gold_train_data["tags"]]
# TODO: does this need to be updated? # TODO: does this need to be updated?
tag_map = nlp.vocab.morphology.tag_map msg.info(f"{len(labels)} label(s) in data")
msg.info(f"{len(labels)} label(s) in data ({len(tag_map)} label(s) in tag map)")
labels_with_counts = _format_labels( labels_with_counts = _format_labels(
gold_train_data["tags"].most_common(), counts=True gold_train_data["tags"].most_common(), counts=True
) )
msg.text(labels_with_counts, show=verbose) msg.text(labels_with_counts, show=verbose)
non_tagmap = [l for l in labels if l not in tag_map]
if not non_tagmap:
msg.good(f"All labels present in tag map for language '{nlp.lang}'")
for label in non_tagmap:
msg.fail(f"Label '{label}' not found in tag map for language '{nlp.lang}'")
if "parser" in factory_names: if "parser" in factory_names:
has_low_data_warning = False has_low_data_warning = False

View File

@ -436,6 +436,7 @@ $ python -m spacy debug data [config_path] [--code] [--ignore-warnings] [--verbo
``` ```
=========================== Data format validation =========================== =========================== Data format validation ===========================
✔ Corpus is loadable ✔ Corpus is loadable
✔ Pipeline can be initialized with data
=============================== Training stats =============================== =============================== Training stats ===============================
Training pipeline: tagger, parser, ner Training pipeline: tagger, parser, ner
@ -465,7 +466,7 @@ New: 'ORG' (23860), 'PERSON' (21395), 'GPE' (21193), 'DATE' (18080), 'CARDINAL'
✔ No entities consisting of or starting/ending with whitespace ✔ No entities consisting of or starting/ending with whitespace
=========================== Part-of-speech Tagging =========================== =========================== Part-of-speech Tagging ===========================
49 labels in data (57 labels in tag map) 49 labels in data
'NN' (266331), 'IN' (227365), 'DT' (185600), 'NNP' (164404), 'JJ' (119830), 'NN' (266331), 'IN' (227365), 'DT' (185600), 'NNP' (164404), 'JJ' (119830),
'NNS' (110957), '.' (101482), ',' (92476), 'RB' (90090), 'PRP' (90081), 'VB' 'NNS' (110957), '.' (101482), ',' (92476), 'RB' (90090), 'PRP' (90081), 'VB'
(74538), 'VBD' (68199), 'CC' (62862), 'VBZ' (50712), 'VBP' (43420), 'VBN' (74538), 'VBD' (68199), 'CC' (62862), 'VBZ' (50712), 'VBP' (43420), 'VBN'
@ -476,7 +477,6 @@ New: 'ORG' (23860), 'PERSON' (21395), 'GPE' (21193), 'DATE' (18080), 'CARDINAL'
'-RRB-' (2825), '-LRB-' (2788), 'PDT' (2078), 'XX' (1316), 'RBS' (1142), 'FW' '-RRB-' (2825), '-LRB-' (2788), 'PDT' (2078), 'XX' (1316), 'RBS' (1142), 'FW'
(794), 'NFP' (557), 'SYM' (440), 'WP$' (294), 'LS' (293), 'ADD' (191), 'AFX' (794), 'NFP' (557), 'SYM' (440), 'WP$' (294), 'LS' (293), 'ADD' (191), 'AFX'
(24) (24)
✔ All labels present in tag map for language 'en'
============================= Dependency Parsing ============================= ============================= Dependency Parsing =============================
Found 111703 sentences with an average length of 18.6 words. Found 111703 sentences with an average length of 18.6 words.