mirror of https://github.com/explosion/spaCy.git
Add test for custom data augmentation
This commit is contained in:
parent
3856048437
commit
c41a4332e4
|
@ -7,11 +7,11 @@ from spacy.training.converters import json_to_docs
|
|||
from spacy.training.augment import create_orth_variants_augmenter
|
||||
from spacy.lang.en import English
|
||||
from spacy.tokens import Doc, DocBin
|
||||
from spacy.lookups import Lookups
|
||||
from spacy.util import get_words_and_spaces, minibatch
|
||||
from thinc.api import compounding
|
||||
import pytest
|
||||
import srsly
|
||||
import random
|
||||
|
||||
from ..util import make_tempdir
|
||||
|
||||
|
@ -515,6 +515,39 @@ def test_make_orth_variants(doc):
|
|||
list(reader(nlp))
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("ignore::UserWarning")
def test_custom_data_augmentation(doc):
    """Check that a user-defined augmenter passed to ``Corpus`` is applied.

    The ``doc`` fixture is written to disk as a ``DocBin``, read back through
    a ``Corpus`` configured with a "spongebob case" augmenter, and the test
    asserts that the reader produces the untouched example followed by its
    alternating-case variant.
    """

    def create_spongebob_augmenter(randomize: bool = False):
        """Build an augmenter that re-cases every character of an example.

        With ``randomize`` set, each character is lowercased or uppercased
        based on a ``random.random() < 0.5`` draw; otherwise characters at
        even indices are uppercased and characters at odd indices lowercased.
        """

        def augment(nlp, example):
            # Yields the original example first, then the re-cased copy.
            source_text = example.text
            cased_chars = []
            for index, char in enumerate(source_text):
                if randomize:
                    make_lower = random.random() < 0.5
                else:
                    make_lower = index % 2
                cased_chars.append(char.lower() if make_lower else char.upper())
            annotations = example.to_dict()
            cased_doc = nlp.make_doc("".join(cased_chars))
            # Overwrite the token texts so they match the re-cased doc.
            annotations["token_annotation"]["ORTH"] = [
                token.text for token in cased_doc
            ]
            yield example
            yield example.from_dict(cased_doc, annotations)

        return augment

    nlp = English()
    with make_tempdir() as tmpdir:
        output_file = tmpdir / "roundtrip.spacy"
        DocBin(docs=[doc]).to_disk(output_file)
        reader = Corpus(output_file, augmenter=create_spongebob_augmenter())
        corpus = list(reader(nlp))
    orig_text = "Sarah 's sister flew to Silicon Valley via London . "
    augmented = "SaRaH 's sIsTeR FlEw tO SiLiCoN VaLlEy vIa lOnDoN . "
    # Index 0 must be the original example, index 1 the augmented one.
    for position, expected_text in enumerate((orig_text, augmented)):
        produced = corpus[position]
        assert produced.text == expected_text
        assert produced.reference.text == expected_text
        assert produced.predicted.text == expected_text
|
||||
|
||||
|
||||
@pytest.mark.skip("Outdated")
|
||||
@pytest.mark.parametrize(
|
||||
"tokens_a,tokens_b,expected",
|
||||
|
|
Loading…
Reference in New Issue