From eeca171e46c2b3102e0f3483399dd49789a3c403 Mon Sep 17 00:00:00 2001
From: Sina
Date: Sat, 18 Apr 2020 19:06:32 -0700
Subject: [PATCH] - improved filtering of paraphrasing dataset - better
 normalization during generation for punctuation and special tokens -
 normalization for cased paraphrasing models

---
 .../clean_paraphrasing_dataset.py |  88 +++++++++++++
 .../split_dataset.py              |  32 +++++
 genienlp/util.py                  | 122 ++++++++++++++++--
 3 files changed, 234 insertions(+), 8 deletions(-)
 create mode 100644 genienlp/data_manipulation_scripts/clean_paraphrasing_dataset.py
 create mode 100644 genienlp/data_manipulation_scripts/split_dataset.py

diff --git a/genienlp/data_manipulation_scripts/clean_paraphrasing_dataset.py b/genienlp/data_manipulation_scripts/clean_paraphrasing_dataset.py
new file mode 100644
index 00000000..8b13a2b7
--- /dev/null
+++ b/genienlp/data_manipulation_scripts/clean_paraphrasing_dataset.py
@@ -0,0 +1,88 @@
+from argparse import ArgumentParser
+import csv
+import sys
+from tqdm import tqdm
+from genienlp.util import detokenize, get_number_of_lines
+
+csv.field_size_limit(sys.maxsize)
+
+def is_english(s):
+    try:
+        s.encode(encoding='utf-8').decode('ascii')
+    except UnicodeDecodeError:
+        return False
+    else:
+        return True
+
+def remove_quotation(s):
+    s = s.replace('``', '')
+    s = s.replace('\'\'', '')
+    s = s.replace('"', '')
+    if s.startswith('\''):
+        s = s[1:]
+    if s.endswith('\''):
+        s = s[:-1]
+    return s
+
+def is_valid(s):
+    return 'http' not in s and s.count('-') <= 4 and s.count('.') <= 4 and is_english(s) \
+        and '_' not in s and '%' not in s and '/' not in s and '*' not in s and '\\' not in s \
+        and 'www' not in s and sum(c.isdigit() for c in s) <= 10 and s.count('(') == s.count(')')
+
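+# Rough illustration of what is_valid() keeps and drops (made-up sentences, assuming
+# typical paraphrase data; not exhaustive):
+#   is_valid('I would like to book a table for two.')  -> True
+#   is_valid('check out https://example.com/path')     -> False (contains a URL)
+#   is_valid('café au lait')                           -> False (non-ASCII characters)
+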
+def main():
+    parser = ArgumentParser()
+    parser.add_argument('input', type=str,
+                        help='The path to the input .tsv file.')
+    parser.add_argument('output', type=str,
+                        help='The path to the output .txt file.')
+
+    # By default, we swap the columns so that the target of paraphrasing will be a grammatically correct sentence, i.e. written by a human, not an NMT model
+    parser.add_argument('--first_column', type=int, default=1, help='The column index in the input file to put in the first column of the output file')
+    parser.add_argument('--second_column', type=int, default=0, help='The column index in the input file to put in the second column of the output file')
+
+    parser.add_argument('--min_length', type=int, default=30, help='Minimum number of characters that each phrase should have in order to be included')
+    parser.add_argument('--max_length', type=int, default=150, help='Maximum number of characters that each phrase should have in order to be included')
+    parser.add_argument('--skip_check', action='store_true', help='Skip validity check.')
+    parser.add_argument('--skip_normalization', action='store_true', help='Do not remove quotation marks or detokenize.')
+    parser.add_argument('--lower_case', action='store_true', help='Convert everything to lower case.')
+    parser.add_argument('--max_output_size', type=int, default=1e10, help='Maximum number of examples in the output.')
+
+    args = parser.parse_args()
+
+    drop_count = 0
+    # number_of_lines = get_number_of_lines(args.input)
+    output_size = 0
+    with open(args.input, 'r') as input_file, open(args.output, 'w') as output_file:
+        writer = csv.writer(output_file, delimiter='\t')
+        reader = csv.reader(input_file, delimiter='\t')
+        for row in tqdm(reader, desc='Lines'):
+            first = row[args.first_column]    # input sequence
+            second = row[args.second_column]  # output sequence
+            # print(first)
+            # print(second)
+            if not args.skip_check and \
+                    (len(first) < args.min_length or len(second) < args.min_length \
+                     or len(first) > args.max_length or len(second) > args.max_length \
+                     or not is_valid(first) or not is_valid(second)):
+                drop_count += 1
+                continue
+            if not args.skip_normalization:
+                first = remove_quotation(detokenize(first))
+                second = remove_quotation(detokenize(second))
+            first = first.strip()
+            second = second.strip()
+            if args.lower_case:
+                first = first.lower()
+                second = second.lower()
+            if first.lower() == second.lower() or first == '' or second == '':
+                drop_count += 1
+                continue
+            writer.writerow([first, second])
+            output_size += 1
+            if output_size >= args.max_output_size:
+                break
+    print('Dropped', drop_count, 'examples')
+
+if __name__ == '__main__':
+    main()
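+
+# A typical invocation could look like the following; the file names and flag values
+# are placeholders, assuming the script is run from the repository root:
+#   python genienlp/data_manipulation_scripts/clean_paraphrasing_dataset.py \
+#       paraphrases.tsv cleaned.tsv --min_length 30 --max_length 150 --lower_case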
diff --git a/genienlp/data_manipulation_scripts/split_dataset.py b/genienlp/data_manipulation_scripts/split_dataset.py
new file mode 100644
index 00000000..f28fddac
--- /dev/null
+++ b/genienlp/data_manipulation_scripts/split_dataset.py
@@ -0,0 +1,32 @@
+from argparse import ArgumentParser
+from tqdm import tqdm
+import random
+
+
+def main():
+    parser = ArgumentParser()
+    parser.add_argument('input', type=str,
+                        help='The path to the input.')
+    parser.add_argument('output1', type=str,
+                        help='The path to the output train file.')
+    parser.add_argument('output2', type=str,
+                        help='The path to the output dev file.')
+    parser.add_argument('--output1_ratio', type=float, required=True,
+                        help='The ratio of input examples that go to output1')
+    parser.add_argument('--seed', default=123, type=int, help='Random seed.')
+
+
+    args = parser.parse_args()
+    random.seed(args.seed)
+
+    with open(args.input, 'r') as input_file, open(args.output1, 'w') as output_file1, open(args.output2, 'w') as output_file2:
+        for line in tqdm(input_file):
+            r = random.random()
+            if r < args.output1_ratio:
+                output_file1.write(line)
+            else:
+                output_file2.write(line)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/genienlp/util.py b/genienlp/util.py
index 6ec7a177..ade8e7f4 100644
--- a/genienlp/util.py
+++ b/genienlp/util.py
@@ -45,18 +45,124 @@ from .data_utils.iterator import Iterator
 
 logger = logging.getLogger(__name__)
 
+class SpecialTokenMap:
+    def __init__(self, pattern, forward_func, backward_func=None):
+        """
+        Inputs:
+            pattern: a regex pattern
+            forward_func: a function with signature forward_func(str) -> str
+            backward_func: a function with signature backward_func(str) -> list[str]
+        """
+        if isinstance(forward_func, list):
+            self.forward_func = lambda x: forward_func[int(x)%len(forward_func)]
+        else:
+            self.forward_func = forward_func
+
+        if isinstance(backward_func, list):
+            self.backward_func = lambda x: backward_func[int(x)%len(backward_func)]
+        else:
+            self.backward_func = backward_func
+
+        self.pattern = pattern
+
+    def forwad(self, s: str):
+        reverse_map = []
+        matches = re.finditer(self.pattern, s)
+        if matches is None:
+            return s, reverse_map
+        for match in matches:
+            occurance = match.group(0)
+            # print('occurance = ', occurance)
+            parameter = match.group(1)
+            replacement = self.forward_func(parameter)
+            s = s.replace(occurance, replacement)
+            reverse_map.append((self, occurance))
+        return s, reverse_map
+
+    def backward(self, s: str, occurance: str):
+        match = re.match(self.pattern, occurance)
+        parameter = match.group(1)
+        if self.backward_func is None:
+            list_of_strings_to_match = [self.forward_func(parameter)]
+        else:
+            list_of_strings_to_match = sorted(self.backward_func(parameter), key=lambda x: len(x), reverse=True)
+        # print('list_of_strings_to_match = ', list_of_strings_to_match)
+        for string_to_match in list_of_strings_to_match:
+            l = [' '+string_to_match+' ', string_to_match+' ', ' '+string_to_match]
+            o = [' '+occurance+' ', occurance+' ', ' '+occurance]
+            new_s = s
+            for i in range(len(l)):
+                new_s = re.sub(l[i], o[i], s, flags=re.IGNORECASE)
+                if s != new_s:
+                    break
+            if s != new_s:
+                s = new_s
+                break
+
+        return s
+
+
 def tokenizer(s):
     return s.split()
 
 
-def detokenize(text):
-    tokens = ["'d", "n't", "'ve", "'m", "'re", "'ll", ".", ",", "?", "'s", ")"]
+def mask_special_tokens(string: str):
+    exceptions = [match.group(0) for match in re.finditer('[A-Za-z:_.]+_[0-9]+', string)]
+    for e in exceptions:
+        string = string.replace(e, '<temp>', 1)
+    return string, exceptions
+
+def unmask_special_tokens(string: str, exceptions: list):
+    for e in exceptions:
+        string = string.replace('<temp>', e, 1)
+    return string
+
+
+def detokenize(string: str):
+    string, exceptions = mask_special_tokens(string)
+    tokens = ["'d", "n't", "'ve", "'m", "'re", "'ll", ".", ",", "?", "!", "'s", ")", ":"]
     for t in tokens:
-        text = text.replace(' ' + t, t)
-    text = text.replace("( ", "(")
-    text = text.replace('gon na', 'gonna')
-    text = text.replace('wan na', 'wanna')
-    return text
-
+        string = string.replace(' ' + t, t)
+    string = string.replace("( ", "(")
+    string = string.replace('gon na', 'gonna')
+    string = string.replace('wan na', 'wanna')
+    string = unmask_special_tokens(string, exceptions)
+    return string
+
+def tokenize(string: str):
+    string, exceptions = mask_special_tokens(string)
+    tokens = ["'d", "n't", "'ve", "'m", "'re", "'ll", ".", ",", "?", "!", "'s", ")", ":"]
+    for t in tokens:
+        string = string.replace(t, ' ' + t)
+    string = string.replace("(", "( ")
+    string = string.replace('gonna', 'gon na')
+    string = string.replace('wanna', 'wan na')
+    string = re.sub('\s+', ' ', string)
+    string = unmask_special_tokens(string, exceptions)
+    return string.strip()
+
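+# Rough, illustrative behaviour of the helpers above (made-up inputs):
+#   detokenize("i do n't like it , do you ?")  -> "i don't like it, do you?"
+#   tokenize("i don't like it, do you?")       -> "i do n't like it , do you ?"
+# mask_special_tokens/unmask_special_tokens keep placeholders such as DATE_0 or
+# QUOTED_STRING_0 intact while the surrounding punctuation is normalized.
+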
+def lower_case(string):
+    string, exceptions = mask_special_tokens(string)
+    string = string.lower()
+    string = unmask_special_tokens(string, exceptions)
+    return string
+
+def remove_thingtalk_quotes(thingtalk):
+    quote_values = []
+    while True:
+        # print('before: ', thingtalk)
+        l1 = thingtalk.find('"')
+        if l1 < 0:
+            break
+        l2 = thingtalk.find('"', l1+1)
+        if l2 < 0:
+            # ThingTalk code is not syntactic
+            return thingtalk, None
+        quote_values.append(thingtalk[l1+1: l2].strip())
+        thingtalk = thingtalk[:l1] + '<temp>' + thingtalk[l2+1:]
+        # print('after: ', thingtalk)
+    thingtalk = thingtalk.replace('<temp>', '""')
+    return thingtalk, quote_values
+
 def get_number_of_lines(file_path):
     count = 0
     with open(file_path) as f:
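
# Illustrative behaviour of remove_thingtalk_quotes on a made-up ThingTalk string
# (hedged example):
#   remove_thingtalk_quotes('now => @com.twitter.post param:status = " hello world " ;')
#   -> ('now => @com.twitter.post param:status = "" ;', ['hello world'])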