Mirror of https://github.com/explosion/spaCy.git
Fix augmenter: use "ORTH"/"TAG" keys instead of "words"/"tags" in make_orth_variants
This commit is contained in:
parent
549758f67d
commit
d2b9aafb8c
|
@ -120,8 +120,8 @@ def make_orth_variants(
|
||||||
ndsv = orth_variants.get("single", [])
|
ndsv = orth_variants.get("single", [])
|
||||||
ndpv = orth_variants.get("paired", [])
|
ndpv = orth_variants.get("paired", [])
|
||||||
logger.debug(f"Data augmentation: {len(ndsv)} single / {len(ndpv)} paired variants")
|
logger.debug(f"Data augmentation: {len(ndsv)} single / {len(ndpv)} paired variants")
|
||||||
words = token_dict.get("words", [])
|
words = token_dict.get("ORTH", [])
|
||||||
tags = token_dict.get("tags", [])
|
tags = token_dict.get("TAG", [])
|
||||||
# keep unmodified if words or tags are not defined
|
# keep unmodified if words or tags are not defined
|
||||||
if words and tags:
|
if words and tags:
|
||||||
if lower:
|
if lower:
|
||||||
|
@ -131,7 +131,7 @@ def make_orth_variants(
|
||||||
for word_idx in range(len(words)):
|
for word_idx in range(len(words)):
|
||||||
for punct_idx in range(len(ndsv)):
|
for punct_idx in range(len(ndsv)):
|
||||||
if (
|
if (
|
||||||
tags[word_idx] in ndsv[punct_idx]["tags"]
|
tags[word_idx] in ndsv[punct_idx]["TAG"]
|
||||||
and words[word_idx] in ndsv[punct_idx]["variants"]
|
and words[word_idx] in ndsv[punct_idx]["variants"]
|
||||||
):
|
):
|
||||||
words[word_idx] = punct_choices[punct_idx]
|
words[word_idx] = punct_choices[punct_idx]
|
||||||
|
@ -139,14 +139,14 @@ def make_orth_variants(
|
||||||
punct_choices = [random.choice(x["variants"]) for x in ndpv]
|
punct_choices = [random.choice(x["variants"]) for x in ndpv]
|
||||||
for word_idx in range(len(words)):
|
for word_idx in range(len(words)):
|
||||||
for punct_idx in range(len(ndpv)):
|
for punct_idx in range(len(ndpv)):
|
||||||
if tags[word_idx] in ndpv[punct_idx]["tags"] and words[
|
if tags[word_idx] in ndpv[punct_idx]["TAG"] and words[
|
||||||
word_idx
|
word_idx
|
||||||
] in itertools.chain.from_iterable(ndpv[punct_idx]["variants"]):
|
] in itertools.chain.from_iterable(ndpv[punct_idx]["variants"]):
|
||||||
# backup option: random left vs. right from pair
|
# backup option: random left vs. right from pair
|
||||||
pair_idx = random.choice([0, 1])
|
pair_idx = random.choice([0, 1])
|
||||||
# best option: rely on paired POS tags like `` / ''
|
# best option: rely on paired POS tags like `` / ''
|
||||||
if len(ndpv[punct_idx]["tags"]) == 2:
|
if len(ndpv[punct_idx]["TAG"]) == 2:
|
||||||
pair_idx = ndpv[punct_idx]["tags"].index(tags[word_idx])
|
pair_idx = ndpv[punct_idx]["TAG"].index(tags[word_idx])
|
||||||
# next best option: rely on position in variants
|
# next best option: rely on position in variants
|
||||||
# (may not be unambiguous, so order of variants matters)
|
# (may not be unambiguous, so order of variants matters)
|
||||||
else:
|
else:
|
||||||
|
@ -154,8 +154,8 @@ def make_orth_variants(
|
||||||
if words[word_idx] in pair:
|
if words[word_idx] in pair:
|
||||||
pair_idx = pair.index(words[word_idx])
|
pair_idx = pair.index(words[word_idx])
|
||||||
words[word_idx] = punct_choices[punct_idx][pair_idx]
|
words[word_idx] = punct_choices[punct_idx][pair_idx]
|
||||||
token_dict["words"] = words
|
token_dict["ORTH"] = words
|
||||||
token_dict["tags"] = tags
|
token_dict["TAG"] = tags
|
||||||
# modify raw
|
# modify raw
|
||||||
if raw is not None:
|
if raw is not None:
|
||||||
variants = []
|
variants = []
|
||||||
|
|
Loading…
Reference in New Issue