- improved filtering of paraphrasing dataset
- better normalization during generation for punctuation and special tokens
- normalization for cased paraphrasing models
This commit is contained in:
parent
b0a0398576
commit
eeca171e46
@@ -0,0 +1,88 @@
from argparse import ArgumentParser
import csv
import sys
from tqdm import tqdm

from genienlp.util import detokenize, get_number_of_lines

csv.field_size_limit(sys.maxsize)


def is_english(s):
    try:
        s.encode(encoding='utf-8').decode('ascii')
    except UnicodeDecodeError:
        return False
    else:
        return True


def remove_quotation(s):
    s = s.replace('``', '')
    s = s.replace('\'\'', '')
    s = s.replace('"', '')
    if s.startswith('\''):
        s = s[1:]
    if s.endswith('\''):
        s = s[:-1]
    return s


def is_valid(s):
    return 'http' not in s and s.count('-') <= 4 and s.count('.') <= 4 and is_english(s) \
        and '_' not in s and '%' not in s and '/' not in s and '*' not in s and '\\' not in s \
        and 'www' not in s and sum(c.isdigit() for c in s) <= 10 and s.count('(') == s.count(')')


def main():
    parser = ArgumentParser()
    parser.add_argument('input', type=str,
                        help='The path to the input .tsv file.')
    parser.add_argument('output', type=str,
                        help='The path to the output .txt file.')

    # By default, we swap the columns so that the target of paraphrasing will be a grammatically
    # correct sentence, i.e. written by a human, not by an NMT model.
    parser.add_argument('--first_column', type=int, default=1,
                        help='The column index in the input file to put in the first column of the output file')
    parser.add_argument('--second_column', type=int, default=0,
                        help='The column index in the input file to put in the second column of the output file')

    parser.add_argument('--min_length', type=int, default=30,
                        help='Minimum number of characters that each phrase should have in order to be included')
    parser.add_argument('--max_length', type=int, default=150,
                        help='Maximum number of characters that each phrase should have in order to be included')
    parser.add_argument('--skip_check', action='store_true', help='Skip validity check.')
    parser.add_argument('--skip_normalization', action='store_true', help='Do not remove quotation marks or detokenize.')
    parser.add_argument('--lower_case', action='store_true', help='Convert everything to lower case.')
    parser.add_argument('--max_output_size', type=int, default=1e10, help='Maximum number of examples in the output.')

    args = parser.parse_args()

    drop_count = 0
    # number_of_lines = get_number_of_lines(args.input)
    output_size = 0
    with open(args.input, 'r') as input_file, open(args.output, 'w') as output_file:
        writer = csv.writer(output_file, delimiter='\t')
        reader = csv.reader(input_file, delimiter='\t')
        for row in tqdm(reader, desc='Lines'):
            first = row[args.first_column]  # input sequence
            second = row[args.second_column]  # output sequence
            if not args.skip_check and \
                    (len(first) < args.min_length or len(second) < args.min_length
                     or len(first) > args.max_length or len(second) > args.max_length
                     or not is_valid(first) or not is_valid(second)):
                drop_count += 1
                continue
            if not args.skip_normalization:
                first = remove_quotation(detokenize(first))
                second = remove_quotation(detokenize(second))
            first = first.strip()
            second = second.strip()
            if args.lower_case:
                first = first.lower()
                second = second.lower()
            if first.lower() == second.lower() or first == '' or second == '':
                drop_count += 1
                continue
            writer.writerow([first, second])
            output_size += 1
            if output_size >= args.max_output_size:
                break
    print('Dropped', drop_count, 'examples')


if __name__ == '__main__':
    main()
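For illustration only (the sample phrases are made up, not taken from any dataset), this is roughly how the filtering helpers above behave, assuming is_valid and remove_quotation are in scope:

print(is_valid('show me restaurants that serve burgers'))   # True
print(is_valid('see http://example.com for details'))       # False: contains 'http' and '/'
# remove_quotation strips PTB-style `` '' pairs and plain double quotes;
# extra spaces remain where the quotes used to be.
print(remove_quotation('he said `` hello \'\' to the "crowd"'))   # he said  hello  to the crowd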
@@ -0,0 +1,32 @@
from argparse import ArgumentParser
from tqdm import tqdm
import random


def main():
    parser = ArgumentParser()
    parser.add_argument('input', type=str,
                        help='The path to the input.')
    parser.add_argument('output1', type=str,
                        help='The path to the output train file.')
    parser.add_argument('output2', type=str,
                        help='The path to the output dev file.')
    parser.add_argument('--output1_ratio', type=float, required=True,
                        help='The ratio of input examples that go to output1')
    parser.add_argument('--seed', default=123, type=int, help='Random seed.')

    args = parser.parse_args()
    random.seed(args.seed)

    with open(args.input, 'r') as input_file, open(args.output1, 'w') as output_file1, open(args.output2, 'w') as output_file2:
        for line in tqdm(input_file):
            r = random.random()
            if r < args.output1_ratio:
                output_file1.write(line)
            else:
                output_file2.write(line)


if __name__ == '__main__':
    main()
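A hypothetical invocation of the split script above (the file name split_dataset.py is illustrative; the excerpt does not show the actual file name). Since every line is routed by an independent random draw, the train/dev sizes only approximate --output1_ratio:

import subprocess

# Send roughly 95% of the lines to train.tsv and the rest to dev.tsv,
# with a fixed seed so the split is reproducible.
subprocess.run(['python', 'split_dataset.py', 'paraphrases.tsv', 'train.tsv', 'dev.tsv',
                '--output1_ratio', '0.95', '--seed', '123'], check=True)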
122  genienlp/util.py
@@ -45,18 +45,124 @@ from .data_utils.iterator import Iterator

logger = logging.getLogger(__name__)

class SpecialTokenMap:
    def __init__(self, pattern, forward_func, backward_func=None):
        """
        Inputs:
            pattern: a regex pattern
            forward_func: a function with signature forward_func(str) -> str
            backward_func: a function with signature backward_func(str) -> list[str]
        """
        if isinstance(forward_func, list):
            self.forward_func = lambda x: forward_func[int(x) % len(forward_func)]
        else:
            self.forward_func = forward_func

        if isinstance(backward_func, list):
            self.backward_func = lambda x: backward_func[int(x) % len(backward_func)]
        else:
            self.backward_func = backward_func

        self.pattern = pattern

    def forwad(self, s: str):
        # Replace every match of the pattern with its mapped surface form and
        # remember the original occurrences so they can be restored later.
        reverse_map = []
        matches = re.finditer(self.pattern, s)
        if matches is None:
            return s, reverse_map
        for match in matches:
            occurance = match.group(0)
            parameter = match.group(1)
            replacement = self.forward_func(parameter)
            s = s.replace(occurance, replacement)
            reverse_map.append((self, occurance))
        return s, reverse_map

    def backward(self, s: str, occurance: str):
        # Map a surface form back to the original special token, trying the
        # longest candidate strings first and matching case-insensitively.
        match = re.match(self.pattern, occurance)
        parameter = match.group(1)
        if self.backward_func is None:
            list_of_strings_to_match = [self.forward_func(parameter)]
        else:
            list_of_strings_to_match = sorted(self.backward_func(parameter), key=lambda x: len(x), reverse=True)
        for string_to_match in list_of_strings_to_match:
            l = [' ' + string_to_match + ' ', string_to_match + ' ', ' ' + string_to_match]
            o = [' ' + occurance + ' ', occurance + ' ', ' ' + occurance]
            new_s = s
            for i in range(len(l)):
                new_s = re.sub(l[i], o[i], s, flags=re.IGNORECASE)
                if s != new_s:
                    break
            if s != new_s:
                s = new_s
                break

        return s

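A minimal round-trip sketch of how SpecialTokenMap can be used (the pattern, surface forms, and sentence are illustrative; note the forward method is spelled forwad in the class). forwad substitutes each matched placeholder with a surface form and records the original occurrence, and backward restores it:

number_map = SpecialTokenMap('NUMBER_([0-9]+)', ['one', 'two', 'three'])
sentence, reverse_map = number_map.forwad('i need NUMBER_0 tickets')    # 'i need one tickets'
for token_map, occurance in reverse_map:
    sentence = token_map.backward(sentence, occurance)                  # 'i need NUMBER_0 tickets'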
def tokenizer(s):
    return s.split()


def mask_special_tokens(string: str):
    # Temporarily replace special tokens such as NUMBER_0 or param:title_0 with a placeholder
    # so that the punctuation normalization below does not split or alter them.
    exceptions = [match.group(0) for match in re.finditer('[A-Za-z:_.]+_[0-9]+', string)]
    for e in exceptions:
        string = string.replace(e, '<temp>', 1)
    return string, exceptions


def unmask_special_tokens(string: str, exceptions: list):
    for e in exceptions:
        string = string.replace('<temp>', e, 1)
    return string


def detokenize(string: str):
    string, exceptions = mask_special_tokens(string)
    tokens = ["'d", "n't", "'ve", "'m", "'re", "'ll", ".", ",", "?", "!", "'s", ")", ":"]
    for t in tokens:
        string = string.replace(' ' + t, t)
    string = string.replace("( ", "(")
    string = string.replace('gon na', 'gonna')
    string = string.replace('wan na', 'wanna')
    string = unmask_special_tokens(string, exceptions)
    return string


def tokenize(string: str):
    string, exceptions = mask_special_tokens(string)
    tokens = ["'d", "n't", "'ve", "'m", "'re", "'ll", ".", ",", "?", "!", "'s", ")", ":"]
    for t in tokens:
        string = string.replace(t, ' ' + t)
    string = string.replace("(", "( ")
    string = string.replace('gonna', 'gon na')
    string = string.replace('wanna', 'wan na')
    string = re.sub(r'\s+', ' ', string)
    string = unmask_special_tokens(string, exceptions)
    return string.strip()

def lower_case(string):
    string, exceptions = mask_special_tokens(string)
    string = string.lower()
    string = unmask_special_tokens(string, exceptions)
    return string

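As an illustration of the new normalization (the example sentence is made up), special tokens such as NUMBER_0 are masked before any replacement, so they survive detokenization, lower-casing, and re-tokenization unchanged:

s = detokenize("I do n't think NUMBER_0 is right .")   # "I don't think NUMBER_0 is right."
s = lower_case(s)                                      # "i don't think NUMBER_0 is right."
s = tokenize(s)                                        # "i do n't think NUMBER_0 is right ."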
def remove_thingtalk_quotes(thingtalk):
    quote_values = []
    while True:
        l1 = thingtalk.find('"')
        if l1 < 0:
            break
        l2 = thingtalk.find('"', l1 + 1)
        if l2 < 0:
            # ThingTalk code is not syntactic
            return thingtalk, None
        quote_values.append(thingtalk[l1 + 1: l2].strip())
        thingtalk = thingtalk[:l1] + '<temp>' + thingtalk[l2 + 1:]
    thingtalk = thingtalk.replace('<temp>', '""')
    return thingtalk, quote_values

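For illustration (the ThingTalk fragment is made up), remove_thingtalk_quotes extracts the quoted string values from a program and leaves empty quotes in their place:

program, values = remove_thingtalk_quotes('now => @com.spotify.play param:song = "shape of you" ;')
# program == 'now => @com.spotify.play param:song = "" ;'
# values  == ['shape of you']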
def get_number_of_lines(file_path):
    count = 0
    with open(file_path) as f:
        for _ in f:
            count += 1
    return count