From 42072f4468b353d785214a82b67ace38b728f9b5 Mon Sep 17 00:00:00 2001 From: Lj Miranda <12949683+ljvmiranda921@users.noreply.github.com> Date: Mon, 7 Feb 2022 22:03:36 +0800 Subject: [PATCH] Add spancat pipeline in spacy debug data (#10070) * Setup debug data for spancat * Add check for missing labels * Add low-level data warning error * Improve logic when compiling the gold train data * Implement check for negative examples * Remove breakpoint * Remove ws_ents and missing entity checks * Fix mypy errors * Make variable name spans_key consistent * Rename pipeline -> component for consistency * Account for missing labels per spans_key * Cleanup variable names for consistency * Improve brevity of conditional statements * Remove unused variables * Include spans_key as an argument for _get_examples * Add a conditional check for spans_key * Update spancat debug data based on new API - Instead of using _get_labels_from_model(), I'm now using _get_labels_from_spancat() (cf. https://github.com/explosion/spaCy/pull10079) - The way information is displayed was also changed (text -> table) * Rename model_labels to ensure mypy works * Update wording on warning messages Use "span type" instead of "entity type" in wording the warning messages. This is because Spans aren't necessarily entities. * Update component type into a Literal This is to make it clear that the component parameter should only accept either 'spancat' or 'ner'. * Update checks to include actual model span_keys Instead of looking at everything in the data, we only check those span_keys from the actual spancat component. Instead of doing the filter inside the for-loop, I just made another dictionary, data_labels_in_component to hold this value. * Update spacy/cli/debug_data.py * Show label counts only when verbose is True Co-authored-by: Adriane Boyd Co-authored-by: Adriane Boyd --- spacy/cli/debug_data.py | 102 +++++++++++++++++++++++++++++++++++++--- 1 file changed, 95 insertions(+), 7 deletions(-) diff --git a/spacy/cli/debug_data.py b/spacy/cli/debug_data.py index 4be749204..a63795148 100644 --- a/spacy/cli/debug_data.py +++ b/spacy/cli/debug_data.py @@ -193,6 +193,70 @@ def debug_data( else: msg.info("No word vectors present in the package") + if "spancat" in factory_names: + model_labels_spancat = _get_labels_from_spancat(nlp) + has_low_data_warning = False + has_no_neg_warning = False + + msg.divider("Span Categorization") + msg.table(model_labels_spancat, header=["Spans Key", "Labels"], divider=True) + + msg.text("Label counts in train data: ", show=verbose) + for spans_key, data_labels in gold_train_data["spancat"].items(): + msg.text( + f"Key: {spans_key}, {_format_labels(data_labels.items(), counts=True)}", + show=verbose, + ) + # Data checks: only take the spans keys in the actual spancat components + data_labels_in_component = { + spans_key: gold_train_data["spancat"][spans_key] + for spans_key in model_labels_spancat.keys() + } + for spans_key, data_labels in data_labels_in_component.items(): + for label, count in data_labels.items(): + # Check for missing labels + spans_key_in_model = spans_key in model_labels_spancat.keys() + if (spans_key_in_model) and ( + label not in model_labels_spancat[spans_key] + ): + msg.warn( + f"Label '{label}' is not present in the model labels of key '{spans_key}'. " + "Performance may degrade after training." + ) + # Check for low number of examples per label + if count <= NEW_LABEL_THRESHOLD: + msg.warn( + f"Low number of examples for label '{label}' in key '{spans_key}' ({count})" + ) + has_low_data_warning = True + # Check for negative examples + with msg.loading("Analyzing label distribution..."): + neg_docs = _get_examples_without_label( + train_dataset, label, "spancat", spans_key + ) + if neg_docs == 0: + msg.warn(f"No examples for texts WITHOUT new label '{label}'") + has_no_neg_warning = True + + if has_low_data_warning: + msg.text( + f"To train a new span type, your data should include at " + f"least {NEW_LABEL_THRESHOLD} instances of the new label", + show=verbose, + ) + else: + msg.good("Good amount of examples for all labels") + + if has_no_neg_warning: + msg.text( + "Training data should always include examples of spans " + "in context, as well as examples without a given span " + "type.", + show=verbose, + ) + else: + msg.good("Examples without ocurrences available for all labels") + if "ner" in factory_names: # Get all unique NER labels present in the data labels = set( @@ -238,7 +302,7 @@ def debug_data( has_low_data_warning = True with msg.loading("Analyzing label distribution..."): - neg_docs = _get_examples_without_label(train_dataset, label) + neg_docs = _get_examples_without_label(train_dataset, label, "ner") if neg_docs == 0: msg.warn(f"No examples for texts WITHOUT new label '{label}'") has_no_neg_warning = True @@ -573,6 +637,7 @@ def _compile_gold( "deps": Counter(), "words": Counter(), "roots": Counter(), + "spancat": dict(), "ws_ents": 0, "boundary_cross_ents": 0, "n_words": 0, @@ -617,6 +682,15 @@ def _compile_gold( data["boundary_cross_ents"] += 1 elif label == "-": data["ner"]["-"] += 1 + if "spancat" in factory_names: + for span_key in list(eg.reference.spans.keys()): + if span_key not in data["spancat"]: + data["spancat"][span_key] = Counter() + for i, span in enumerate(eg.reference.spans[span_key]): + if span.label_ is None: + continue + else: + data["spancat"][span_key][span.label_] += 1 if "textcat" in factory_names or "textcat_multilabel" in factory_names: data["cats"].update(gold.cats) if any(val not in (0, 1) for val in gold.cats.values()): @@ -687,14 +761,28 @@ def _format_labels( return ", ".join([f"'{l}'" for l in cast(Iterable[str], labels)]) -def _get_examples_without_label(data: Sequence[Example], label: str) -> int: +def _get_examples_without_label( + data: Sequence[Example], + label: str, + component: Literal["ner", "spancat"] = "ner", + spans_key: Optional[str] = "sc", +) -> int: count = 0 for eg in data: - labels = [ - label.split("-")[1] - for label in eg.get_aligned_ner() - if label not in ("O", "-", None) - ] + if component == "ner": + labels = [ + label.split("-")[1] + for label in eg.get_aligned_ner() + if label not in ("O", "-", None) + ] + + if component == "spancat": + labels = ( + [span.label_ for span in eg.reference.spans[spans_key]] + if spans_key in eg.reference.spans + else [] + ) + if label not in labels: count += 1 return count