- improved filtering of paraphrasing dataset

- better normalization during generation for punctuation and special tokens
- normalization for cased paraphrasing models
Author: Sina
Date: 2020-04-18 19:06:32 -07:00
parent b0a0398576
commit eeca171e46
3 changed files with 234 additions and 8 deletions


@@ -0,0 +1,88 @@
from argparse import ArgumentParser
import csv
import sys
from tqdm import tqdm

from genienlp.util import detokenize, get_number_of_lines

csv.field_size_limit(sys.maxsize)


def is_english(s):
    # True iff the string contains only ASCII characters
    try:
        s.encode(encoding='utf-8').decode('ascii')
    except UnicodeDecodeError:
        return False
    else:
        return True


def remove_quotation(s):
    # strip PTB-style and regular quotation marks
    s = s.replace('``', '')
    s = s.replace('\'\'', '')
    s = s.replace('"', '')
    if s.startswith('\''):
        s = s[1:]
    if s.endswith('\''):
        s = s[:-1]
    return s


def is_valid(s):
    # heuristic filter: drop URLs, code-like strings, and overly numeric or non-English text
    return 'http' not in s and s.count('-') <= 4 and s.count('.') <= 4 and is_english(s) \
        and '_' not in s and '%' not in s and '/' not in s and '*' not in s and '\\' not in s \
        and 'www' not in s and sum(c.isdigit() for c in s) <= 10 and s.count('(') == s.count(')')
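# Quick illustration of how the filters behave (the inputs below are assumed
# examples, not taken from this commit):
#
#   is_valid('Show me restaurants nearby.')          # True
#   is_valid('see http://example.com for details')   # False: contains a URL
#   remove_quotation('he said "hello" to me')        # 'he said hello to me'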
def main():
    parser = ArgumentParser()
    parser.add_argument('input', type=str,
                        help='The path to the input .tsv file.')
    parser.add_argument('output', type=str,
                        help='The path to the output .txt file.')
    # By default, we swap the columns so that the target of paraphrasing will be a grammatically
    # correct sentence, i.e. written by a human, not the output of an NMT model
    parser.add_argument('--first_column', type=int, default=1, help='The column index in the input file to put in the first column of the output file')
    parser.add_argument('--second_column', type=int, default=0, help='The column index in the input file to put in the second column of the output file')
    parser.add_argument('--min_length', type=int, default=30, help='Minimum number of characters that each phrase should have in order to be included')
    parser.add_argument('--max_length', type=int, default=150, help='Maximum number of characters that each phrase should have in order to be included')
    parser.add_argument('--skip_check', action='store_true', help='Skip validity check.')
    parser.add_argument('--skip_normalization', action='store_true', help='Do not remove quotation marks or detokenize.')
    parser.add_argument('--lower_case', action='store_true', help='Convert everything to lower case.')
    parser.add_argument('--max_output_size', type=int, default=int(1e10), help='Maximum number of examples in the output.')
    args = parser.parse_args()

    drop_count = 0
    output_size = 0
    with open(args.input, 'r') as input_file, open(args.output, 'w') as output_file:
        writer = csv.writer(output_file, delimiter='\t')
        reader = csv.reader(input_file, delimiter='\t')
        for row in tqdm(reader, desc='Lines'):
            first = row[args.first_column]  # input sequence
            second = row[args.second_column]  # output sequence
            if not args.skip_check and \
                    (len(first) < args.min_length or len(second) < args.min_length
                     or len(first) > args.max_length or len(second) > args.max_length
                     or not is_valid(first) or not is_valid(second)):
                drop_count += 1
                continue
            if not args.skip_normalization:
                first = remove_quotation(detokenize(first))
                second = remove_quotation(detokenize(second))
            first = first.strip()
            second = second.strip()
            if args.lower_case:
                first = first.lower()
                second = second.lower()
            # drop empty pairs and pairs that are (case-insensitively) identical
            if first.lower() == second.lower() or first == '' or second == '':
                drop_count += 1
                continue
            writer.writerow([first, second])
            output_size += 1
            if output_size >= args.max_output_size:
                break
    print('Dropped', drop_count, 'examples')


if __name__ == '__main__':
    main()
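# Usage sketch (the script and file names below are placeholders, not from
# this commit):
#
#   python3 clean_paraphrasing_dataset.py paraphrases.tsv filtered.tsv \
#       --min_length 30 --max_length 150 --lower_case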


@@ -0,0 +1,32 @@
from argparse import ArgumentParser
import random

from tqdm import tqdm


def main():
    parser = ArgumentParser()
    parser.add_argument('input', type=str,
                        help='The path to the input.')
    parser.add_argument('output1', type=str,
                        help='The path to the output train file.')
    parser.add_argument('output2', type=str,
                        help='The path to the output dev file.')
    parser.add_argument('--output1_ratio', type=float, required=True,
                        help='The ratio of input examples that go to output1')
    parser.add_argument('--seed', default=123, type=int, help='Random seed.')
    args = parser.parse_args()

    random.seed(args.seed)
    with open(args.input, 'r') as input_file, open(args.output1, 'w') as output_file1, \
            open(args.output2, 'w') as output_file2:
        for line in tqdm(input_file):
            # send each line to output1 with probability output1_ratio
            if random.random() < args.output1_ratio:
                output_file1.write(line)
            else:
                output_file2.write(line)


if __name__ == '__main__':
    main()
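# Usage sketch (paths are placeholders). Note that the split is probabilistic:
# with --output1_ratio 0.95, roughly, not exactly, 95% of lines land in the
# train file, and the seed makes the split reproducible:
#
#   python3 split_dataset.py all.txt train.txt dev.txt --output1_ratio 0.95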


@@ -45,18 +45,124 @@ from .data_utils.iterator import Iterator
logger = logging.getLogger(__name__)


class SpecialTokenMap:
    def __init__(self, pattern, forward_func, backward_func=None):
        """
        Inputs:
            pattern: a regex pattern
            forward_func: a function with signature forward_func(str) -> str
            backward_func: a function with signature backward_func(str) -> list[str]
        """
        if isinstance(forward_func, list):
            self.forward_func = lambda x: forward_func[int(x) % len(forward_func)]
        else:
            self.forward_func = forward_func
        if isinstance(backward_func, list):
            self.backward_func = lambda x: backward_func[int(x) % len(backward_func)]
        else:
            self.backward_func = backward_func
        self.pattern = pattern

    def forward(self, s: str):
        # replace every match of the pattern with its surface form, and record
        # the original occurrences so the substitution can be undone later
        reverse_map = []
        for match in re.finditer(self.pattern, s):
            occurrence = match.group(0)
            parameter = match.group(1)
            replacement = self.forward_func(parameter)
            s = s.replace(occurrence, replacement)
            reverse_map.append((self, occurrence))
        return s, reverse_map

    def backward(self, s: str, occurrence: str):
        match = re.match(self.pattern, occurrence)
        parameter = match.group(1)
        if self.backward_func is None:
            list_of_strings_to_match = [self.forward_func(parameter)]
        else:
            list_of_strings_to_match = sorted(self.backward_func(parameter), key=lambda x: len(x), reverse=True)
        for string_to_match in list_of_strings_to_match:
            l = [' ' + string_to_match + ' ', string_to_match + ' ', ' ' + string_to_match]
            o = [' ' + occurrence + ' ', occurrence + ' ', ' ' + occurrence]
            new_s = s
            for i in range(len(l)):
                new_s = re.sub(l[i], o[i], s, flags=re.IGNORECASE)
                if s != new_s:
                    break
            if s != new_s:
                s = new_s
                break
        return s
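# Hedged example of SpecialTokenMap in action (the pattern and surface forms
# below are illustrative assumptions, not part of this commit):
#
#   number_map = SpecialTokenMap('NUMBER_([0-9]+)', ['two', 'three'])
#   s, reverse_map = number_map.forward('i want NUMBER_0 pizzas')
#   # s == 'i want two pizzas'
#   for token_map, occurrence in reverse_map:
#       s = token_map.backward(s, occurrence)
#   # s == 'i want NUMBER_0 pizzas'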
def tokenizer(s):
    return s.split()


def mask_special_tokens(string: str):
    # temporarily replace placeholders like QUOTED_STRING_0 with <temp> so
    # that the punctuation rules below cannot alter them
    exceptions = [match.group(0) for match in re.finditer('[A-Za-z:_.]+_[0-9]+', string)]
    for e in exceptions:
        string = string.replace(e, '<temp>', 1)
    return string, exceptions


def unmask_special_tokens(string: str, exceptions: list):
    for e in exceptions:
        string = string.replace('<temp>', e, 1)
    return string


def detokenize(string: str):
    string, exceptions = mask_special_tokens(string)
    tokens = ["'d", "n't", "'ve", "'m", "'re", "'ll", ".", ",", "?", "!", "'s", ")", ":"]
    for t in tokens:
        string = string.replace(' ' + t, t)
    string = string.replace("( ", "(")
    string = string.replace('gon na', 'gonna')
    string = string.replace('wan na', 'wanna')
    string = unmask_special_tokens(string, exceptions)
    return string


def tokenize(string: str):
    string, exceptions = mask_special_tokens(string)
    tokens = ["'d", "n't", "'ve", "'m", "'re", "'ll", ".", ",", "?", "!", "'s", ")", ":"]
    for t in tokens:
        string = string.replace(t, ' ' + t)
    string = string.replace("(", "( ")
    string = string.replace('gonna', 'gon na')
    string = string.replace('wanna', 'wan na')
    string = re.sub(r'\s+', ' ', string)
    string = unmask_special_tokens(string, exceptions)
    return string.strip()
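# Round-trip sketch (inputs assumed; the placeholder QUOTED_STRING_0 is masked
# first, so the punctuation handling leaves it intact):
#
#   detokenize('who is QUOTED_STRING_0 ?')   # -> 'who is QUOTED_STRING_0?'
#   tokenize('who is QUOTED_STRING_0?')      # -> 'who is QUOTED_STRING_0 ?'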
def lower_case(string):
    # lower-case everything except masked placeholders
    string, exceptions = mask_special_tokens(string)
    string = string.lower()
    string = unmask_special_tokens(string, exceptions)
    return string


def remove_thingtalk_quotes(thingtalk):
    quote_values = []
    while True:
        l1 = thingtalk.find('"')
        if l1 < 0:
            break
        l2 = thingtalk.find('"', l1 + 1)
        if l2 < 0:
            # ThingTalk code is not syntactic
            return thingtalk, None
        quote_values.append(thingtalk[l1 + 1: l2].strip())
        thingtalk = thingtalk[:l1] + '<temp>' + thingtalk[l2 + 1:]
    thingtalk = thingtalk.replace('<temp>', '""')
    return thingtalk, quote_values
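# Sketch with an assumed ThingTalk-like string (the syntax here is illustrative
# only): quoted values are extracted and replaced by empty quotes.
#
#   remove_thingtalk_quotes('now => @com.twitter.post param:status = " hello world " ;')
#   # -> ('now => @com.twitter.post param:status = "" ;', ['hello world'])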


def get_number_of_lines(file_path):
    count = 0
    with open(file_path) as f: