From c41a4332e4f21627db1a7c5e057c3cfd70f5fea7 Mon Sep 17 00:00:00 2001 From: Ines Montani Date: Fri, 2 Oct 2020 11:37:56 +0200 Subject: [PATCH] Add test for custom data augmentation --- spacy/tests/training/test_training.py | 35 ++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/spacy/tests/training/test_training.py b/spacy/tests/training/test_training.py index c53042ef1..7d41c8908 100644 --- a/spacy/tests/training/test_training.py +++ b/spacy/tests/training/test_training.py @@ -7,11 +7,11 @@ from spacy.training.converters import json_to_docs from spacy.training.augment import create_orth_variants_augmenter from spacy.lang.en import English from spacy.tokens import Doc, DocBin -from spacy.lookups import Lookups from spacy.util import get_words_and_spaces, minibatch from thinc.api import compounding import pytest import srsly +import random from ..util import make_tempdir @@ -515,6 +515,39 @@ def test_make_orth_variants(doc): list(reader(nlp)) +@pytest.mark.filterwarnings("ignore::UserWarning") +def test_custom_data_augmentation(doc): + def create_spongebob_augmenter(randomize: bool = False): + def augment(nlp, example): + text = example.text + if randomize: + ch = [c.lower() if random.random() < 0.5 else c.upper() for c in text] + else: + ch = [c.lower() if i % 2 else c.upper() for i, c in enumerate(text)] + example_dict = example.to_dict() + doc = nlp.make_doc("".join(ch)) + example_dict["token_annotation"]["ORTH"] = [t.text for t in doc] + yield example + yield example.from_dict(doc, example_dict) + + return augment + + nlp = English() + with make_tempdir() as tmpdir: + output_file = tmpdir / "roundtrip.spacy" + DocBin(docs=[doc]).to_disk(output_file) + reader = Corpus(output_file, augmenter=create_spongebob_augmenter()) + corpus = list(reader(nlp)) + orig_text = "Sarah 's sister flew to Silicon Valley via London . " + augmented = "SaRaH 's sIsTeR FlEw tO SiLiCoN VaLlEy vIa lOnDoN . " + assert corpus[0].text == orig_text + assert corpus[0].reference.text == orig_text + assert corpus[0].predicted.text == orig_text + assert corpus[1].text == augmented + assert corpus[1].reference.text == augmented + assert corpus[1].predicted.text == augmented + + @pytest.mark.skip("Outdated") @pytest.mark.parametrize( "tokens_a,tokens_b,expected",