From d2b9aafb8c8d91ea74c2418d9fb32f1ce8812bbf Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Mon, 5 Oct 2020 14:14:49 +0200 Subject: [PATCH] Fix augmenter --- spacy/training/augment.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/spacy/training/augment.py b/spacy/training/augment.py index e6d10a195..06656bdd8 100644 --- a/spacy/training/augment.py +++ b/spacy/training/augment.py @@ -120,8 +120,8 @@ def make_orth_variants( ndsv = orth_variants.get("single", []) ndpv = orth_variants.get("paired", []) logger.debug(f"Data augmentation: {len(ndsv)} single / {len(ndpv)} paired variants") - words = token_dict.get("words", []) - tags = token_dict.get("tags", []) + words = token_dict.get("ORTH", []) + tags = token_dict.get("TAG", []) # keep unmodified if words or tags are not defined if words and tags: if lower: @@ -131,7 +131,7 @@ def make_orth_variants( for word_idx in range(len(words)): for punct_idx in range(len(ndsv)): if ( - tags[word_idx] in ndsv[punct_idx]["tags"] + tags[word_idx] in ndsv[punct_idx]["TAG"] and words[word_idx] in ndsv[punct_idx]["variants"] ): words[word_idx] = punct_choices[punct_idx] @@ -139,14 +139,14 @@ def make_orth_variants( punct_choices = [random.choice(x["variants"]) for x in ndpv] for word_idx in range(len(words)): for punct_idx in range(len(ndpv)): - if tags[word_idx] in ndpv[punct_idx]["tags"] and words[ + if tags[word_idx] in ndpv[punct_idx]["TAG"] and words[ word_idx ] in itertools.chain.from_iterable(ndpv[punct_idx]["variants"]): # backup option: random left vs. right from pair pair_idx = random.choice([0, 1]) # best option: rely on paired POS tags like `` / '' - if len(ndpv[punct_idx]["tags"]) == 2: - pair_idx = ndpv[punct_idx]["tags"].index(tags[word_idx]) + if len(ndpv[punct_idx]["TAG"]) == 2: + pair_idx = ndpv[punct_idx]["TAG"].index(tags[word_idx]) # next best option: rely on position in variants # (may not be unambiguous, so order of variants matters) else: @@ -154,8 +154,8 @@ def make_orth_variants( if words[word_idx] in pair: pair_idx = pair.index(words[word_idx]) words[word_idx] = punct_choices[punct_idx][pair_idx] - token_dict["words"] = words - token_dict["tags"] = tags + token_dict["ORTH"] = words + token_dict["TAG"] = tags # modify raw if raw is not None: variants = []