Update debug data further for v3 (#7602)

* Update debug data further for v3

* Remove new/existing label distinction (new labels are not immediately
distinguishable because the pipeline is already initialized)
* Warn on missing labels in training data for all components except parser
* Separate textcat and textcat_multilabel sections
* Add section for morphologizer

* Reword missing label warnings
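
The recurring pattern behind the reworded warnings is a plain set difference: compare the labels the already-initialized component knows against the labels actually observed in the training data, and warn about anything the data never covers. A minimal standalone sketch of that check (the label values and variable names are illustrative, and wasabi's Printer stands in for the CLI's msg object):

from wasabi import Printer

msg = Printer()

# Labels the initialized pipeline component already has ...
model_labels = {"PERSON", "ORG", "GPE"}
# ... versus labels actually seen in the training corpus.
train_labels = {"PERSON", "ORG"}

missing_labels = model_labels - train_labels
if missing_labels:
    msg.warn(
        "Some model labels are not present in the train data. The "
        "model performance may be degraded for these labels after "
        f"training: {', '.join(sorted(missing_labels))}."
    )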
Adriane Boyd, 2021-04-09 11:53:42 +02:00 (committed by GitHub)
parent 2516896849
commit 73a8c0f992
1 changed file with 122 additions and 46 deletions

@@ -1,4 +1,4 @@
-from typing import List, Sequence, Dict, Any, Tuple, Optional
+from typing import List, Sequence, Dict, Any, Tuple, Optional, Set
 from pathlib import Path
 from collections import Counter
 import sys
@@ -13,6 +13,8 @@ from ..training.initialize import get_sourced_components
 from ..schemas import ConfigSchemaTraining
 from ..pipeline._parser_internals import nonproj
 from ..pipeline._parser_internals.nonproj import DELIMITER
+from ..pipeline import Morphologizer
+from ..morphology import Morphology
 from ..language import Language
 from ..util import registry, resolve_dot_names
 from .. import util
@@ -194,32 +196,32 @@ def debug_data(
         )
         label_counts = gold_train_data["ner"]
         model_labels = _get_labels_from_model(nlp, "ner")
-        new_labels = [l for l in labels if l not in model_labels]
-        existing_labels = [l for l in labels if l in model_labels]
         has_low_data_warning = False
         has_no_neg_warning = False
         has_ws_ents_error = False
         has_punct_ents_warning = False

         msg.divider("Named Entity Recognition")
-        msg.info(
-            f"{len(new_labels)} new label(s), {len(existing_labels)} existing label(s)"
-        )
+        msg.info(f"{len(model_labels)} label(s)")
         missing_values = label_counts["-"]
         msg.text(f"{missing_values} missing value(s) (tokens with '-' label)")
-        for label in new_labels:
+        for label in labels:
             if len(label) == 0:
-                msg.fail("Empty label found in new labels")
-        if new_labels:
-            labels_with_counts = [
-                (label, count)
-                for label, count in label_counts.most_common()
-                if label != "-"
-            ]
-            labels_with_counts = _format_labels(labels_with_counts, counts=True)
-            msg.text(f"New: {labels_with_counts}", show=verbose)
-        if existing_labels:
-            msg.text(f"Existing: {_format_labels(existing_labels)}", show=verbose)
+                msg.fail("Empty label found in train data")
+        labels_with_counts = [
+            (label, count)
+            for label, count in label_counts.most_common()
+            if label != "-"
+        ]
+        labels_with_counts = _format_labels(labels_with_counts, counts=True)
+        msg.text(f"Labels in train data: {_format_labels(labels)}", show=verbose)
+        missing_labels = model_labels - labels
+        if missing_labels:
+            msg.warn(
+                "Some model labels are not present in the train data. The "
+                "model performance may be degraded for these labels after "
+                f"training: {_format_labels(missing_labels)}."
+            )
         if gold_train_data["ws_ents"]:
             msg.fail(f"{gold_train_data['ws_ents']} invalid whitespace entity spans")
             has_ws_ents_error = True
@@ -228,10 +230,10 @@ def debug_data(
             msg.warn(f"{gold_train_data['punct_ents']} entity span(s) with punctuation")
             has_punct_ents_warning = True

-        for label in new_labels:
+        for label in labels:
             if label_counts[label] <= NEW_LABEL_THRESHOLD:
                 msg.warn(
-                    f"Low number of examples for new label '{label}' ({label_counts[label]})"
+                    f"Low number of examples for label '{label}' ({label_counts[label]})"
                 )
                 has_low_data_warning = True
@@ -276,22 +278,52 @@ def debug_data(
             )

     if "textcat" in factory_names:
-        msg.divider("Text Classification")
-        labels = [label for label in gold_train_data["cats"]]
-        model_labels = _get_labels_from_model(nlp, "textcat")
-        new_labels = [l for l in labels if l not in model_labels]
-        existing_labels = [l for l in labels if l in model_labels]
-        msg.info(
-            f"Text Classification: {len(new_labels)} new label(s), "
-            f"{len(existing_labels)} existing label(s)"
-        )
-        if new_labels:
-            labels_with_counts = _format_labels(
-                gold_train_data["cats"].most_common(), counts=True
-            )
-            msg.text(f"New: {labels_with_counts}", show=verbose)
-        if existing_labels:
-            msg.text(f"Existing: {_format_labels(existing_labels)}", show=verbose)
+        msg.divider("Text Classification (Exclusive Classes)")
+        labels = _get_labels_from_model(nlp, "textcat")
+        msg.info(f"Text Classification: {len(labels)} label(s)")
+        msg.text(f"Labels: {_format_labels(labels)}", show=verbose)
+        labels_with_counts = _format_labels(
+            gold_train_data["cats"].most_common(), counts=True
+        )
+        msg.text(f"Labels in train data: {labels_with_counts}", show=verbose)
+        missing_labels = labels - set(gold_train_data["cats"].keys())
+        if missing_labels:
+            msg.warn(
+                "Some model labels are not present in the train data. The "
+                "model performance may be degraded for these labels after "
+                f"training: {_format_labels(missing_labels)}."
+            )
+        if gold_train_data["n_cats_multilabel"] > 0:
+            # Note: you should never get here because you run into E895 on
+            # initialization first.
+            msg.warn(
+                "The train data contains instances without "
+                "mutually-exclusive classes. Use the component "
+                "'textcat_multilabel' instead of 'textcat'."
+            )
+        if gold_dev_data["n_cats_multilabel"] > 0:
+            msg.fail(
+                "Train/dev mismatch: the dev data contains instances "
+                "without mutually-exclusive classes while the train data "
+                "contains only instances with mutually-exclusive classes."
+            )
+
+    if "textcat_multilabel" in factory_names:
+        msg.divider("Text Classification (Multilabel)")
+        labels = _get_labels_from_model(nlp, "textcat_multilabel")
+        msg.info(f"Text Classification: {len(labels)} label(s)")
+        msg.text(f"Labels: {_format_labels(labels)}", show=verbose)
+        labels_with_counts = _format_labels(
+            gold_train_data["cats"].most_common(), counts=True
+        )
+        msg.text(f"Labels in train data: {labels_with_counts}", show=verbose)
+        missing_labels = labels - set(gold_train_data["cats"].keys())
+        if missing_labels:
+            msg.warn(
+                "Some model labels are not present in the train data. The "
+                "model performance may be degraded for these labels after "
+                f"training: {_format_labels(missing_labels)}."
+            )
         if set(gold_train_data["cats"]) != set(gold_dev_data["cats"]):
             msg.fail(
                 f"The train and dev labels are not the same. "
@@ -299,11 +331,6 @@ def debug_data(
                 f"Dev labels: {_format_labels(gold_dev_data['cats'])}."
             )
         if gold_train_data["n_cats_multilabel"] > 0:
-            msg.info(
-                "The train data contains instances without "
-                "mutually-exclusive classes. Use '--textcat-multilabel' "
-                "when training."
-            )
             if gold_dev_data["n_cats_multilabel"] == 0:
                 msg.warn(
                     "Potential train/dev mismatch: the train data contains "
@@ -311,9 +338,10 @@ def debug_data(
                     "dev data does not."
                 )
         else:
-            msg.info(
+            msg.warn(
                 "The train data contains only instances with "
-                "mutually-exclusive classes."
+                "mutually-exclusive classes. You can potentially use the "
+                "component 'textcat' instead of 'textcat_multilabel'."
             )
             if gold_dev_data["n_cats_multilabel"] > 0:
                 msg.fail(
@@ -325,13 +353,37 @@ def debug_data(
     if "tagger" in factory_names:
         msg.divider("Part-of-speech Tagging")
         labels = [label for label in gold_train_data["tags"]]
-        # TODO: does this need to be updated?
-        msg.info(f"{len(labels)} label(s) in data")
+        model_labels = _get_labels_from_model(nlp, "tagger")
+        msg.info(f"{len(labels)} label(s) in train data")
+        missing_labels = model_labels - set(labels)
+        if missing_labels:
+            msg.warn(
+                "Some model labels are not present in the train data. The "
+                "model performance may be degraded for these labels after "
+                f"training: {_format_labels(missing_labels)}."
+            )
         labels_with_counts = _format_labels(
             gold_train_data["tags"].most_common(), counts=True
         )
         msg.text(labels_with_counts, show=verbose)

+    if "morphologizer" in factory_names:
+        msg.divider("Morphologizer (POS+Morph)")
+        labels = [label for label in gold_train_data["morphs"]]
+        model_labels = _get_labels_from_model(nlp, "morphologizer")
+        msg.info(f"{len(labels)} label(s) in train data")
+        missing_labels = model_labels - set(labels)
+        if missing_labels:
+            msg.warn(
+                "Some model labels are not present in the train data. The "
+                "model performance may be degraded for these labels after "
+                f"training: {_format_labels(missing_labels)}."
+            )
+        labels_with_counts = _format_labels(
+            gold_train_data["morphs"].most_common(), counts=True
+        )
+        msg.text(labels_with_counts, show=verbose)
+
     if "parser" in factory_names:
         has_low_data_warning = False
         msg.divider("Dependency Parsing")
@@ -491,6 +543,7 @@ def _compile_gold(
         "ner": Counter(),
         "cats": Counter(),
         "tags": Counter(),
+        "morphs": Counter(),
         "deps": Counter(),
         "words": Counter(),
         "roots": Counter(),
@@ -544,13 +597,36 @@ def _compile_gold(
                     data["ner"][combined_label] += 1
                 elif label == "-":
                     data["ner"]["-"] += 1
-        if "textcat" in factory_names:
+        if "textcat" in factory_names or "textcat_multilabel" in factory_names:
             data["cats"].update(gold.cats)
             if list(gold.cats.values()).count(1.0) != 1:
                 data["n_cats_multilabel"] += 1
         if "tagger" in factory_names:
             tags = eg.get_aligned("TAG", as_string=True)
             data["tags"].update([x for x in tags if x is not None])
+        if "morphologizer" in factory_names:
+            pos_tags = eg.get_aligned("POS", as_string=True)
+            morphs = eg.get_aligned("MORPH", as_string=True)
+            for pos, morph in zip(pos_tags, morphs):
+                # POS may align (same value for multiple tokens) when morph
+                # doesn't, so if either is misaligned (None), treat the
+                # annotation as missing so that truths doesn't end up with an
+                # unknown morph+POS combination
+                if pos is None or morph is None:
+                    pass
+                # If both are unset, the annotation is missing (empty morph
+                # converted from int is "_" rather than "")
+                elif pos == "" and morph == "":
+                    pass
+                # Otherwise, generate the combined label
+                else:
+                    label_dict = Morphology.feats_to_dict(morph)
+                    if pos:
+                        label_dict[Morphologizer.POS_FEAT] = pos
+                    label = eg.reference.vocab.strings[
+                        eg.reference.vocab.morphology.add(label_dict)
+                    ]
+                    data["morphs"].update([label])
         if "parser" in factory_names:
             aligned_heads, aligned_deps = eg.get_aligned_parse(projectivize=make_proj)
             data["deps"].update([x for x in aligned_deps if x is not None])
@@ -584,8 +660,8 @@ def _get_examples_without_label(data: Sequence[Example], label: str) -> int:
     return count


-def _get_labels_from_model(nlp: Language, pipe_name: str) -> Sequence[str]:
+def _get_labels_from_model(nlp: Language, pipe_name: str) -> Set[str]:
     if pipe_name not in nlp.pipe_names:
         return set()
     pipe = nlp.get_pipe(pipe_name)
-    return pipe.labels
+    return set(pipe.labels)
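
A short note on the return-type change at the end: pipe.labels is usually a tuple, and the new sections rely on set arithmetic against the training labels, so returning set(pipe.labels) keeps every call site a one-liner. Illustrative only, with made-up labels:

pipe_labels = ("PERSON", "ORG", "GPE")   # roughly what pipe.labels looks like
train_labels = {"PERSON", "ORG"}
# pipe_labels - train_labels             # would raise TypeError for a tuple
print(set(pipe_labels) - train_labels)   # {'GPE'} -> triggers the new warning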