Add option of data augmentation noise

This commit is contained in:
Matthew Honnibal 2017-06-04 20:16:57 -05:00
parent 7b2ede783d
commit 9bc4a26213
1 changed files with 35 additions and 6 deletions

View File

@ -199,14 +199,16 @@ class GoldCorpus(object):
return n
def train_docs(self, nlp, gold_preproc=False,
projectivize=False, max_length=None):
projectivize=False, max_length=None,
noise_level=0.0):
train_tuples = self.train_tuples
if projectivize:
train_tuples = nonproj.preprocess_training_data(
self.train_tuples)
random.shuffle(train_tuples)
gold_docs = self.iter_gold_docs(nlp, train_tuples, gold_preproc,
max_length=max_length)
max_length=max_length,
noise_level=noise_level)
yield from gold_docs
def dev_docs(self, nlp, gold_preproc=False):
@ -215,7 +217,8 @@ class GoldCorpus(object):
yield from gold_docs
@classmethod
def iter_gold_docs(cls, nlp, tuples, gold_preproc, max_length=None):
def iter_gold_docs(cls, nlp, tuples, gold_preproc, max_length=None,
noise_level=0.0):
for raw_text, paragraph_tuples in tuples:
if gold_preproc:
raw_text = None
@ -223,18 +226,20 @@ class GoldCorpus(object):
paragraph_tuples = merge_sents(paragraph_tuples)
docs = cls._make_docs(nlp, raw_text, paragraph_tuples,
gold_preproc)
gold_preproc, noise_level=noise_level)
golds = cls._make_golds(docs, paragraph_tuples)
for doc, gold in zip(docs, golds):
if (not max_length) or len(doc) < max_length:
yield doc, gold
@classmethod
def _make_docs(cls, nlp, raw_text, paragraph_tuples, gold_preproc):
def _make_docs(cls, nlp, raw_text, paragraph_tuples, gold_preproc,
noise_level=0.0):
if raw_text is not None:
raw_text = add_noise(raw_text, noise_level)
return [nlp.make_doc(raw_text)]
else:
return [Doc(nlp.vocab, words=sent_tuples[1])
return [Doc(nlp.vocab, words=add_noise(sent_tuples[1], noise_level))
for (sent_tuples, brackets) in paragraph_tuples]
@classmethod
@ -266,6 +271,30 @@ class GoldCorpus(object):
return locs
def add_noise(orig, noise_level):
if random.random() >= noise_level:
return orig
elif type(orig) == list:
corrupted = [_corrupt(word, noise_level) for word in orig]
corrupted = [w for w in corrupted if w]
return corrupted
else:
return ''.join(_corrupt(c, noise_level) for c in orig)
def _corrupt(c, noise_level):
if random.random() >= noise_level:
return c
elif c == ' ':
return '\n'
elif c == '\n':
return ' '
elif c in ['.', "'", "!", "?"]:
return ''
else:
return c.lower()
def read_json_file(loc, docs_filter=None, limit=None):
loc = ensure_path(loc)
if loc.is_dir():