spaCy/spacy/tests/training/test_augmenters.py

242 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import pytest
from spacy.pipeline._parser_internals.nonproj import contains_cycle
from spacy.training import Corpus, Example
from spacy.training.augment import create_orth_variants_augmenter
from spacy.training.augment import create_lower_casing_augmenter
from spacy.training.augment import make_whitespace_variant
from spacy.lang.en import English
from spacy.tokens import DocBin, Doc, Span
from contextlib import contextmanager
import random
from ..util import make_tempdir
@contextmanager
def make_docbin(docs, name="roundtrip.spacy"):
    """Write `docs` to a DocBin file and yield the path of that file.

    The file is created inside the directory handed out by `make_tempdir()`
    and is only guaranteed to be usable while the context is open.

    docs: The Doc objects to serialize.
    name: File name to use inside the directory.
    YIELDS: Path of the written DocBin file.
    """
    with make_tempdir() as tmp_root:
        target = tmp_root / name
        doc_bin = DocBin(docs=docs)
        doc_bin.to_disk(target)
        yield target
@pytest.fixture
def nlp():
    """Fixture handing each test its own freshly constructed `English()` object."""
    pipeline = English()
    return pipeline
@pytest.fixture
def doc(nlp):
    """Fixture building a ten-token Doc with tags, POS, entities and cats.

    The entity list deliberately contains one empty string (index 3), i.e.
    that token gets neither a B/I label nor an explicit "O".
    """
    # fmt: off
    annotations = {
        "words": ["Sarah", "'s", "sister", "flew", "to", "Silicon", "Valley", "via", "London", "."],
        "tags": ["NNP", "POS", "NN", "VBD", "IN", "NNP", "NNP", "IN", "NNP", "."],
        "pos": ["PROPN", "PART", "NOUN", "VERB", "ADP", "PROPN", "PROPN", "ADP", "PROPN", "PUNCT"],
        "ents": ["B-PERSON", "I-PERSON", "O", "", "O", "B-LOC", "I-LOC", "O", "B-GPE", "O"],
    }
    # fmt: on
    annotated = Doc(nlp.vocab, **annotations)
    # Text categories are attached after construction.
    annotated.cats = {"TRAVEL": 1.0, "BAKING": 0.0}
    return annotated
@pytest.mark.filterwarnings("ignore::UserWarning")
def test_make_orth_variants(nlp):
    """Exercise the orth-variants augmenter through a Corpus round trip.

    Four scenarios are covered: a random configuration that only has to run
    without errors, full lowercasing with tags, full lowercasing without
    tags, and no lowercasing at all (lower=0.0) without tags.
    """
    # NOTE(review): several entries below are empty strings; they look like
    # Unicode punctuation variants that may have been lost when this file was
    # extracted -- confirm against the upstream test before relying on them.
    single = [
        {"tags": ["NFP"], "variants": ["", "..."]},
        {"tags": [":"], "variants": ["-", "", "", "--", "---", "——"]},
    ]
    # fmt: off
    words = ["\n\n", "A", "\t", "B", "a", "b", "", "...", "-", "", "", "--", "---", "——"]
    tags = ["_SP", "NN", "\t", "NN", "NN", "NN", "NFP", "NFP", ":", ":", ":", ":", ":", ":"]
    # fmt: on
    variants = {"single": single}

    def augmented_examples(source_doc, level, lower):
        # Serialize ten copies of the doc, read them back through a Corpus
        # with a newly created augmenter and materialize all examples while
        # the file still exists.
        augmenter = create_orth_variants_augmenter(
            level=level, lower=lower, orth_variants=variants
        )
        with make_docbin([source_doc] * 10) as output_file:
            reader = Corpus(output_file, augmenter=augmenter)
            return list(reader(nlp))

    # Tokens 0 and 2 have no trailing space, every other token does.
    spaces = [index not in (0, 2) for index in range(len(words))]
    tagged_doc = Doc(nlp.vocab, words=words, spaces=spaces, tags=tags)

    # Due to randomness, only test that it works without errors
    augmented_examples(tagged_doc, level=0.2, lower=0.5)

    # check that the following settings lowercase everything
    for example in augmented_examples(tagged_doc, level=1.0, lower=1.0):
        for token in example.reference:
            assert token.text == token.text.lower()

    # check that lowercasing is applied without tags
    untagged_doc = Doc(nlp.vocab, words=words, spaces=[True] * len(words))
    for example in augmented_examples(untagged_doc, level=1.0, lower=1.0):
        for ex_token, doc_token in zip(example.reference, untagged_doc):
            assert ex_token.text == doc_token.text.lower()

    # check that no lowercasing is applied with lower=0.0
    untagged_doc = Doc(nlp.vocab, words=words, spaces=[True] * len(words))
    for example in augmented_examples(untagged_doc, level=1.0, lower=0.0):
        for ex_token, doc_token in zip(example.reference, untagged_doc):
            assert ex_token.text == doc_token.text
def test_lowercase_augmenter(nlp, doc):
    """Check the lower-casing augmenter on an annotated and on a plain Doc.

    With level=1.0 the test expects both sides of the example to carry the
    lowercased text while entity offsets, IOB values and POS tags stay equal
    to those of the source doc.
    """

    def ent_offsets(source):
        # (start, end, label) triple for every entity of the given doc.
        return [(e.start, e.end, e.label) for e in source.ents]

    augmenter = create_lower_casing_augmenter(level=1.0)

    with make_docbin([doc]) as output_file:
        corpus = list(Corpus(output_file, augmenter=augmenter)(nlp))
    eg = corpus[0]
    lowered = doc.text.lower()
    assert eg.reference.text == lowered
    assert eg.predicted.text == lowered
    assert ent_offsets(eg.reference) == ent_offsets(doc)
    for ref_ent, orig_ent in zip(eg.reference.ents, doc.ents):
        assert ref_ent.text == orig_ent.text.lower()
    assert [t.ent_iob for t in doc] == [t.ent_iob for t in eg.reference]
    assert [t.pos_ for t in eg.reference] == [t.pos_ for t in doc]

    # check that augmentation works when lowercasing leads to different
    # predicted tokenization
    words = ["A", "B", "CCC."]
    plain_doc = Doc(nlp.vocab, words=words)
    with make_docbin([plain_doc]) as output_file:
        corpus = list(Corpus(output_file, augmenter=augmenter)(nlp))
    eg = corpus[0]
    lowered = plain_doc.text.lower()
    assert eg.reference.text == lowered
    assert eg.predicted.text == lowered
    # The reference keeps the original token boundaries ...
    assert [t.text for t in eg.reference] == [w.lower() for w in words]
    # ... while the predicted side matches a fresh tokenization of the
    # lowercased text.
    retokenized = nlp.make_doc(plain_doc.text.lower())
    assert [t.text for t in eg.predicted] == [t.text for t in retokenized]
@pytest.mark.filterwarnings("ignore::UserWarning")
def test_custom_data_augmentation(nlp, doc):
    """Plug a user-defined augmenter into a Corpus and check both outputs.

    The augmenter yields the untouched example followed by a copy whose
    characters alternate between upper and lower case; entity offsets are
    expected to be the same on both.
    """

    def create_spongebob_augmenter(randomize: bool = False):
        def augment(nlp, example):
            text = example.text

            def use_lower(index):
                # Random coin flip per character, or strictly alternating
                # case where odd character offsets are lowercased.
                if randomize:
                    return random.random() < 0.5
                return index % 2 == 1

            mocked_text = "".join(
                c.lower() if use_lower(i) else c.upper() for i, c in enumerate(text)
            )
            example_dict = example.to_dict()
            mocked_doc = nlp.make_doc(mocked_text)
            example_dict["token_annotation"]["ORTH"] = [t.text for t in mocked_doc]
            # Original example first, then the re-cased variant.
            yield example
            yield example.from_dict(mocked_doc, example_dict)

        return augment

    def ent_offsets(source):
        return [(e.start, e.end, e.label) for e in source.ents]

    with make_docbin([doc]) as output_file:
        reader = Corpus(output_file, augmenter=create_spongebob_augmenter())
        corpus = list(reader(nlp))

    orig_text = "Sarah 's sister flew to Silicon Valley via London . "
    augmented = "SaRaH 's sIsTeR FlEw tO SiLiCoN VaLlEy vIa lOnDoN . "
    for eg, expected in ((corpus[0], orig_text), (corpus[1], augmented)):
        assert eg.text == expected
        assert eg.reference.text == expected
        assert eg.predicted.text == expected

    expected_ents = ent_offsets(doc)
    assert ent_offsets(corpus[0].reference) == expected_ents
    assert ent_offsets(corpus[1].reference) == expected_ents
def test_make_whitespace_variant(nlp):
    """Check `make_whitespace_variant` on a fully annotated two-sentence doc.

    Covers, in order:
    - the effect of an inserted space on the text of the first entity,
      depending on whether the position is inside or outside the span,
    - insertion at every token position (annotation of the new token,
      acyclic dependency tree, head of the new token, entity boundaries),
    - a second insertion on top of an already modified example,
    - the cases in which the example is expected to come back unmodified
      (partial dependencies, span groups, entity links).
    """
    # fmt: off
    text = "They flew to New York City.\nThen they drove to Washington, D.C."
    words = ["They", "flew", "to", "New", "York", "City", ".", "\n", "Then", "they", "drove", "to", "Washington", ",", "D.C."]
    spaces = [True, True, True, True, True, False, False, False, True, True, True, True, False, True, False]
    tags = ["PRP", "VBD", "IN", "NNP", "NNP", "NNP", ".", "_SP", "RB", "PRP", "VBD", "IN", "NNP", ",", "NNP"]
    lemmas = ["they", "fly", "to", "New", "York", "City", ".", "\n", "then", "they", "drive", "to", "Washington", ",", "D.C."]
    heads = [1, 1, 1, 4, 5, 2, 1, 10, 10, 10, 10, 10, 11, 12, 12]
    deps = ["nsubj", "ROOT", "prep", "compound", "compound", "pobj", "punct", "dep", "advmod", "nsubj", "ROOT", "prep", "pobj", "punct", "appos"]
    ents = ["O", "", "O", "B-GPE", "I-GPE", "I-GPE", "O", "O", "O", "O", "O", "O", "B-GPE", "O", "B-GPE"]
    # fmt: on
    doc = Doc(
        nlp.vocab,
        words=words,
        spaces=spaces,
        tags=tags,
        lemmas=lemmas,
        heads=heads,
        deps=deps,
        ents=ents,
    )
    assert doc.text == text
    example = Example(nlp.make_doc(text), doc)
    # whitespace is only added internally in entity spans
    # Position 3 is directly before the entity and position 6 directly after
    # it, so the entity text is unchanged there. Positions 4 and 5 fall
    # between tokens of the entity, so the inserted space becomes part of the
    # entity text (the preceding token keeps its own trailing space).
    mod_ex = make_whitespace_variant(nlp, example, " ", 3)
    assert mod_ex.reference.ents[0].text == "New York City"
    mod_ex = make_whitespace_variant(nlp, example, " ", 4)
    assert mod_ex.reference.ents[0].text == "New  York City"
    mod_ex = make_whitespace_variant(nlp, example, " ", 5)
    assert mod_ex.reference.ents[0].text == "New York  City"
    mod_ex = make_whitespace_variant(nlp, example, " ", 6)
    assert mod_ex.reference.ents[0].text == "New York City"
    # add a space at every possible position
    for i in range(len(doc) + 1):
        mod_ex = make_whitespace_variant(nlp, example, " ", i)
        assert mod_ex.reference[i].is_space
        # adds annotation when the doc contains at least partial annotation
        assert [t.tag_ for t in mod_ex.reference] == tags[:i] + ["_SP"] + tags[i:]
        assert [t.lemma_ for t in mod_ex.reference] == lemmas[:i] + [" "] + lemmas[i:]
        assert [t.dep_ for t in mod_ex.reference] == deps[:i] + ["dep"] + deps[i:]
        # does not add partial annotation if doc does not contain this feature
        assert not mod_ex.reference.has_annotation("POS")
        assert not mod_ex.reference.has_annotation("MORPH")
        # produces well-formed trees
        assert not contains_cycle([t.head.i for t in mod_ex.reference])
        # NOTE(review): this counts sentences of the original `doc`, not of
        # the modified reference -- confirm whether that is intended.
        assert len(list(doc.sents)) == 2
        # The new token attaches to the preceding token, or to token 1 when
        # it is inserted at the very beginning.
        if i == 0:
            assert mod_ex.reference[i].head.i == 1
        else:
            assert mod_ex.reference[i].head.i == i - 1
        # adding another space also produces well-formed trees
        for j in (3, 8, 10):
            mod_ex2 = make_whitespace_variant(nlp, mod_ex, "\t\t\n", j)
            assert not contains_cycle([t.head.i for t in mod_ex2.reference])
            assert len(list(doc.sents)) == 2
            assert mod_ex2.reference[j].head.i == j - 1
        # entities are well-formed
        assert len(doc.ents) == len(mod_ex.reference.ents)
        # there is one token with missing entity information
        assert any(t.ent_iob == 0 for t in mod_ex.reference)
        for ent in mod_ex.reference.ents:
            assert not ent[0].is_space
            assert not ent[-1].is_space

    # no modifications if:
    # partial dependencies
    example.reference[0].dep_ = ""
    mod_ex = make_whitespace_variant(nlp, example, " ", 5)
    assert mod_ex.text == example.reference.text
    example.reference[0].dep_ = "nsubj"  # reset

    # spans
    example.reference.spans["spans"] = [example.reference[0:5]]
    mod_ex = make_whitespace_variant(nlp, example, " ", 5)
    assert mod_ex.text == example.reference.text
    del example.reference.spans["spans"]  # reset

    # links
    example.reference.ents = [Span(doc, 0, 2, label="ENT", kb_id="Q123")]
    mod_ex = make_whitespace_variant(nlp, example, " ", 5)
    assert mod_ex.text == example.reference.text