Fix augmenter

This commit is contained in:
Matthew Honnibal 2020-10-05 14:14:49 +02:00
parent 549758f67d
commit d2b9aafb8c
1 changed file with 8 additions and 8 deletions

View File

@@ -120,8 +120,8 @@ def make_orth_variants(
ndsv = orth_variants.get("single", []) ndsv = orth_variants.get("single", [])
ndpv = orth_variants.get("paired", []) ndpv = orth_variants.get("paired", [])
logger.debug(f"Data augmentation: {len(ndsv)} single / {len(ndpv)} paired variants") logger.debug(f"Data augmentation: {len(ndsv)} single / {len(ndpv)} paired variants")
words = token_dict.get("words", []) words = token_dict.get("ORTH", [])
tags = token_dict.get("tags", []) tags = token_dict.get("TAG", [])
# keep unmodified if words or tags are not defined # keep unmodified if words or tags are not defined
if words and tags: if words and tags:
if lower: if lower:
@@ -131,7 +131,7 @@ def make_orth_variants(
for word_idx in range(len(words)): for word_idx in range(len(words)):
for punct_idx in range(len(ndsv)): for punct_idx in range(len(ndsv)):
if ( if (
tags[word_idx] in ndsv[punct_idx]["tags"] tags[word_idx] in ndsv[punct_idx]["TAG"]
and words[word_idx] in ndsv[punct_idx]["variants"] and words[word_idx] in ndsv[punct_idx]["variants"]
): ):
words[word_idx] = punct_choices[punct_idx] words[word_idx] = punct_choices[punct_idx]
@@ -139,14 +139,14 @@ def make_orth_variants(
punct_choices = [random.choice(x["variants"]) for x in ndpv] punct_choices = [random.choice(x["variants"]) for x in ndpv]
for word_idx in range(len(words)): for word_idx in range(len(words)):
for punct_idx in range(len(ndpv)): for punct_idx in range(len(ndpv)):
if tags[word_idx] in ndpv[punct_idx]["tags"] and words[ if tags[word_idx] in ndpv[punct_idx]["TAG"] and words[
word_idx word_idx
] in itertools.chain.from_iterable(ndpv[punct_idx]["variants"]): ] in itertools.chain.from_iterable(ndpv[punct_idx]["variants"]):
# backup option: random left vs. right from pair # backup option: random left vs. right from pair
pair_idx = random.choice([0, 1]) pair_idx = random.choice([0, 1])
# best option: rely on paired POS tags like `` / '' # best option: rely on paired POS tags like `` / ''
if len(ndpv[punct_idx]["tags"]) == 2: if len(ndpv[punct_idx]["TAG"]) == 2:
pair_idx = ndpv[punct_idx]["tags"].index(tags[word_idx]) pair_idx = ndpv[punct_idx]["TAG"].index(tags[word_idx])
# next best option: rely on position in variants # next best option: rely on position in variants
# (may not be unambiguous, so order of variants matters) # (may not be unambiguous, so order of variants matters)
else: else:
@@ -154,8 +154,8 @@ def make_orth_variants(
if words[word_idx] in pair: if words[word_idx] in pair:
pair_idx = pair.index(words[word_idx]) pair_idx = pair.index(words[word_idx])
words[word_idx] = punct_choices[punct_idx][pair_idx] words[word_idx] = punct_choices[punct_idx][pair_idx]
token_dict["words"] = words token_dict["ORTH"] = words
token_dict["tags"] = tags token_dict["TAG"] = tags
# modify raw # modify raw
if raw is not None: if raw is not None:
variants = [] variants = []