Fix augmenter

This commit is contained in:
Matthew Honnibal 2020-10-05 14:14:49 +02:00
parent 549758f67d
commit d2b9aafb8c
1 changed files with 8 additions and 8 deletions

View File

@ -120,8 +120,8 @@ def make_orth_variants(
ndsv = orth_variants.get("single", [])
ndpv = orth_variants.get("paired", [])
logger.debug(f"Data augmentation: {len(ndsv)} single / {len(ndpv)} paired variants")
words = token_dict.get("words", [])
tags = token_dict.get("tags", [])
words = token_dict.get("ORTH", [])
tags = token_dict.get("TAG", [])
# keep unmodified if words or tags are not defined
if words and tags:
if lower:
@ -131,7 +131,7 @@ def make_orth_variants(
for word_idx in range(len(words)):
for punct_idx in range(len(ndsv)):
if (
tags[word_idx] in ndsv[punct_idx]["tags"]
tags[word_idx] in ndsv[punct_idx]["TAG"]
and words[word_idx] in ndsv[punct_idx]["variants"]
):
words[word_idx] = punct_choices[punct_idx]
@ -139,14 +139,14 @@ def make_orth_variants(
punct_choices = [random.choice(x["variants"]) for x in ndpv]
for word_idx in range(len(words)):
for punct_idx in range(len(ndpv)):
if tags[word_idx] in ndpv[punct_idx]["tags"] and words[
if tags[word_idx] in ndpv[punct_idx]["TAG"] and words[
word_idx
] in itertools.chain.from_iterable(ndpv[punct_idx]["variants"]):
# backup option: random left vs. right from pair
pair_idx = random.choice([0, 1])
# best option: rely on paired POS tags like `` / ''
if len(ndpv[punct_idx]["tags"]) == 2:
pair_idx = ndpv[punct_idx]["tags"].index(tags[word_idx])
if len(ndpv[punct_idx]["TAG"]) == 2:
pair_idx = ndpv[punct_idx]["TAG"].index(tags[word_idx])
# next best option: rely on position in variants
# (may not be unambiguous, so order of variants matters)
else:
@ -154,8 +154,8 @@ def make_orth_variants(
if words[word_idx] in pair:
pair_idx = pair.index(words[word_idx])
words[word_idx] = punct_choices[punct_idx][pair_idx]
token_dict["words"] = words
token_dict["tags"] = tags
token_dict["ORTH"] = words
token_dict["TAG"] = tags
# modify raw
if raw is not None:
variants = []