Merge pull request #97 from stanford-oval/wip/fix-para

Fix paraphrasing bugs
s-jse 2021-02-18 17:11:33 -08:00 committed by GitHub
commit 69c5029f5f
5 changed files with 106 additions and 13 deletions


@@ -28,8 +28,8 @@ special_pattern_mapping = [
['$13', 'thirteen dollars', '13 dollars', '$ 13', '$ 13.00', '13.00', '13']]),
SpecialTokenMap('DURATION_([0-9]+)', ['5 weeks', '6 weeks'], [['5 weeks', 'five weeks'], ['6 weeks', 'six weeks']]),
SpecialTokenMap('LOCATION_([0-9]+)', ['locatio1n', 'locatio2n'], [['locatio1n', 'locat1n'], ['locatio2n', 'locat2n']]),
SpecialTokenMap('QUOTED_STRING_([0-9]+)', lambda x: 'Chinese', lambda x: ['Chinese', 'chinese', 'china']), # TODO change to be more general than cuisine
SpecialTokenMap('GENERIC_ENTITY_uk.ac.cam.multiwoz.Restaurant:Restaurant_([0-9]+)', ["restaurant1", "restaurant2", "restaurant3"]) # TODO the only reason we can get away with this unnatural replacement is that actual backward is not going to be called for this
# SpecialTokenMap('QUOTED_STRING_([0-9]+)', ['Chinese', 'Italian'], [['Chinese', 'chinese', 'china'], ['Italian', 'italian']]), # TODO change to be more general than cuisine
# SpecialTokenMap('GENERIC_ENTITY_uk.ac.cam.multiwoz.Restaurant:Restaurant_([0-9]+)', ["restaurant1", "restaurant2", "restaurant3"]) # TODO the only reason we can get away with this unnatural replacement is that actual backward is not going to be called for this
]
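
For context on how these mappings are used: a SpecialTokenMap pairs a mask regex with natural-language replacements that are substituted in before paraphrasing and mapped back afterwards. The class itself is not part of this diff, so the Python sketch below is an assumption about its shape, not the actual genienlp implementation:

import re

class SpecialTokenMapSketch:
    # Hedged sketch; the real SpecialTokenMap may differ in detail.
    def __init__(self, raw_pattern, forward, backward=None):
        self.pattern = re.compile(raw_pattern)
        self.forward = forward    # list of phrases, or a callable index -> phrase
        self.backward = backward  # per-phrase lists of acceptable paraphrases

    def substitute_forward(self, sentence):
        # replace e.g. "DURATION_1" with the phrase chosen for index 1
        def repl(match):
            idx = int(match.group(1))
            if callable(self.forward):
                return self.forward(idx)
            return self.forward[idx % len(self.forward)]
        return self.pattern.sub(repl, sentence)

duration_map = SpecialTokenMapSketch(
    'DURATION_([0-9]+)', ['5 weeks', '6 weeks'],
    [['5 weeks', 'five weeks'], ['6 weeks', 'six weeks']])
print(duration_map.substitute_forward('i will stay for DURATION_1'))  # -> i will stay for 6 weeks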


@@ -4,16 +4,28 @@ import re
from ...util import tokenize, lower_case, remove_thingtalk_quotes
from ...data_utils.progbar import progress_bar
special_token_pattern = re.compile("(^|(?<= ))" + "[A-Z]+_[0-9]" + "($|(?= ))")
def find_special_tokens(s: str):
return list(sorted([a.group(0) for a in special_token_pattern.finditer(s)]))
def is_subset(set1, set2):
"""
Returns True if set1 is a subset of or equal to set2
"""
return all([e in set2 for e in set1])
def passes_heuristic_checks(row, args):
def passes_heuristic_checks(row, args, old_query=None):
if 'QUOTED_STRING' in row[args.utterance_column] or (old_query is not None and 'QUOTED_STRING' in old_query):
# remove quoted examples
return False
if old_query is not None:
# check that all the special tokens in utterance after paraphrasing are the same as before
if find_special_tokens(old_query) != find_special_tokens(row[args.utterance_column]):
return False
all_input_columns = ' '.join([row[c] for c in args.input_columns])
input_special_tokens = set(re.findall('[A-Za-z:_.]+_[0-9]', all_input_columns))
output_special_tokens = set(re.findall('[A-Za-z:_.]+_[0-9]', row[args.thingtalk_column]))
input_special_tokens = set(find_special_tokens(all_input_columns))
output_special_tokens = set(find_special_tokens(row[args.thingtalk_column]))
if not is_subset(output_special_tokens, input_special_tokens):
return False
_, quote_values = remove_thingtalk_quotes(row[args.thingtalk_column])
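
The effect of the new old_query argument: a paraphrase is kept only if it contains exactly the same special tokens as the original utterance (find_special_tokens returns a sorted list, so the comparison is order-insensitive). A standalone illustration using the pattern from this file:

import re

special_token_pattern = re.compile("(^|(?<= ))" + "[A-Z]+_[0-9]" + "($|(?= ))")

def find_special_tokens(s):
    return list(sorted(a.group(0) for a in special_token_pattern.finditer(s)))

old_query = 'book a table near LOCATION_0 for DURATION_0'
paraphrase = 'reserve a table for DURATION_0 close to LOCATION_0'
print(find_special_tokens(old_query) == find_special_tokens(paraphrase))  # True: paraphrase is kept

dropped = 'reserve a table close to LOCATION_0'
print(find_special_tokens(old_query) == find_special_tokens(dropped))  # False: paraphrase is discarded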
@@ -110,6 +122,7 @@ def main(args):
seen_examples = set()
all_thrown_away_rows = []
for row_idx, row in enumerate(progress_bar(reader, desc='Lines')):
old_query = None
output_rows = []
thrown_away_rows = []
if args.transformation == 'remove_thingtalk_quotes':
@@ -136,6 +149,7 @@ def main(args):
row[args.utterance_column] = new_query
output_rows.append(row)
elif args.transformation == 'replace_queries':
old_query = row[args.utterance_column]
for idx, new_query in enumerate(new_queries[row_idx]):
copy_row = row.copy()
copy_row[args.utterance_column] = new_query
@@ -152,7 +166,7 @@
for o in output_rows:
output_row = ""
if args.remove_with_heuristics:
if not passes_heuristic_checks(o, args):
if not passes_heuristic_checks(o, args, old_query=old_query):
heuristic_count += 1
continue
if args.remove_duplicates:
@@ -170,7 +184,7 @@
output_row += '\t'
output_file.write(output_row + '\n')
for o in thrown_away_rows:
if not args.remove_with_heuristics or (args.remove_with_heuristics and passes_heuristic_checks(o, args)):
if not args.remove_with_heuristics or (args.remove_with_heuristics and passes_heuristic_checks(o, args, old_query=old_query)):
all_thrown_away_rows.append(o)
if args.thrown_away is not None:


@@ -484,6 +484,83 @@ class NaturalSeq2Seq(BaseAlmondTask):
return Example.from_raw(self.name + '/' + example_id, context, question, answer,
preprocess=self.preprocess_field, lower=False)
def preprocess_field(self, sentence, field_name=None, answer=None):
if self.override_context is not None and field_name == 'context':
pad_feature = get_pad_feature(self.args.ned_features, self.args.ned_features_default_val, self.args.ned_features_size)
return self.override_context, [pad_feature] * len(self.override_context.split(' ')) if pad_feature else [], self.override_context
if self.override_question is not None and field_name == 'question':
pad_feature = get_pad_feature(self.args.ned_features, self.args.ned_features_default_val, self.args.ned_features_size)
return self.override_question, [pad_feature] * len(self.override_question.split(' ')) if pad_feature else [], self.override_question
if not sentence:
return '', [], ''
tokens = sentence.split(' ')
new_tokens = []
for token in tokens:
new_tokens.append(token)
tokens = new_tokens
new_sentence = ' '.join(tokens)
if self._almond_detokenize_sentence:
# BERT tokenizers by default add whitespace around any CJK character
# SPM-based tokenizers are trained on raw text and do better when they receive untokenized text
# In genienlp we detokenize CJK characters and leave tokenization to the model's tokenizer
# NOTE: input datasets for almond are usually pretokenized using genie-toolkit which
# inserts whitespace around any CJK character. This detokenization ensures that SPM-based tokenizers
# see the text without space between those characters
new_sentence = detokenize_cjk_chars(new_sentence)
tokens = new_sentence.split(' ')
new_sentence = ''
for token in tokens:
if token in (',', '.', '?', '!', ':', ')', ']', '}') or token.startswith("'"):
new_sentence += token
else:
new_sentence += ' ' + token
new_sentence = new_sentence.strip()
new_tokens = new_sentence.split(' ')
new_sentence_length = len(new_tokens)
tokens_type_ids, tokens_type_probs = None, None
if 'type_id' in self.args.ned_features and field_name != 'answer':
tokens_type_ids = [[self.args.ned_features_default_val[0]] * self.args.ned_features_size[0] for _ in
range(new_sentence_length)]
if 'type_prob' in self.args.ned_features and field_name != 'answer':
tokens_type_probs = [[self.args.ned_features_default_val[1]] * self.args.ned_features_size[1] for _ in
range(new_sentence_length)]
if self.args.do_ned and self.args.ned_retrieve_method != 'bootleg' and field_name not in self.no_feature_fields:
if 'type_id' in self.args.ned_features:
tokens_type_ids = self.find_type_ids(new_tokens, answer)
if 'type_prob' in self.args.ned_features:
tokens_type_probs = self.find_type_probs(new_tokens, self.args.ned_features_default_val[1],
self.args.ned_features_size[1])
if self.args.verbose and self.args.do_ned:
print()
print(
*[f'token: {token}\ttype: {token_type}' for token, token_type in zip(new_tokens, tokens_type_ids)],
sep='\n')
zip_list = []
if tokens_type_ids:
assert len(tokens_type_ids) == new_sentence_length
zip_list.append(tokens_type_ids)
if tokens_type_probs:
assert len(tokens_type_probs) == new_sentence_length
zip_list.append(tokens_type_probs)
features = [Feature(*tup) for tup in zip(*zip_list)]
sentence_plus_types = ''
if self.args.do_ned and self.args.add_types_to_text != 'no' and len(features):
sentence_plus_types = self.create_sentence_plus_types_tokens(new_sentence, features, self.args.add_types_to_text)
return new_sentence, features, sentence_plus_types
def get_splits(self, root, **kwargs):
return AlmondDataset.return_splits(path=os.path.join(root, 'almond'), make_example=self._make_example, **kwargs)
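
detokenize_cjk_chars is imported from elsewhere in genienlp and not shown in this diff; the sketch below is an assumption about the behavior the comments describe, collapsing the single spaces that genie-toolkit pretokenization inserts between adjacent CJK characters:

import re

# common CJK ranges; the real function may cover more blocks
_CJK = r'[\u4e00-\u9fff\u3040-\u30ff\uac00-\ud7af]'

def detokenize_cjk_chars_sketch(sentence):
    # collapse "中 文" -> "中文" while leaving "中 hello" unchanged
    prev = None
    while prev != sentence:
        prev = sentence
        sentence = re.sub(f'({_CJK}) ({_CJK})', r'\1\2', sentence)
    return sentence

print(detokenize_cjk_chars_sketch('我 想 吃 pizza'))  # -> 我想吃 pizza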
@@ -518,7 +595,7 @@ class Paraphrase(NaturalSeq2Seq):
sentence, reverse_map = input_heuristics(sentence, thingtalk=thingtalk, is_cased=True)
# this task especially needs example ids to be unique
if example_id in self.reverse_maps:
while example_id in self.reverse_maps:
example_id += '.'
self.reverse_maps[example_id] = reverse_map
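
The change from if to while matters when the same id collides more than once; appending a single '.' is not guaranteed to be unique, but looping until the id is free is:

reverse_maps = {'ex1': {}, 'ex1.': {}}
example_id = 'ex1'
while example_id in reverse_maps:
    example_id += '.'
print(example_id)  # -> ex1..  (the old single `if` would have produced the colliding ex1.)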


@@ -317,7 +317,7 @@ def unmask_special_tokens(string: str, exceptions: list):
def detokenize(string: str):
string, exceptions = mask_special_tokens(string)
tokens = ["'d", "n't", "'ve", "'m", "'re", "'ll", ".", ",", "?", "!", "'s", ")", ":"]
tokens = ["'d", "n't", "'ve", "'m", "'re", "'ll", ".", ",", "?", "!", "'s", ")", ":", "-"]
for t in tokens:
string = string.replace(' ' + t, t)
string = string.replace("( ", "(")
@@ -335,8 +335,9 @@ def tokenize(string: str):
string = string.replace("(", "( ")
string = string.replace('gonna', 'gon na')
string = string.replace('wanna', 'wan na')
string = re.sub('\s+', ' ', string)
string = unmask_special_tokens(string, exceptions)
string = re.sub('([A-Za-z:_.]+_[0-9]+)-', r'\1 - ', string) # add space before and after hyphen, e.g. "NUMBER_0-hour"
string = re.sub('\s+', ' ', string) # remove duplicate spaces
return string.strip()
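
A standalone check of just the two added lines, showing the "NUMBER_0-hour" case the comment mentions (the surrounding mask/unmask helpers are omitted here):

import re

def split_mask_hyphen(string):
    string = re.sub('([A-Za-z:_.]+_[0-9]+)-', r'\1 - ', string)  # add space before and after hyphen
    return re.sub(r'\s+', ' ', string).strip()                   # remove duplicate spaces

print(split_mask_hyphen('a NUMBER_0-hour hike'))  # -> a NUMBER_0 - hour hike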


@@ -67,11 +67,12 @@ def generate_with_model(model, data_iterator, numericalizer, task, args,
batch_size = len(batch.example_id)
batch_prediction = [[] for _ in range(batch_size)]
batch_confidence_features = [[] for _ in range(batch_size)]
batch_example_ids = batch.example_id
example_ids += batch.example_id
example_ids += batch_example_ids
if not output_predictions_only:
batch_answer = numericalizer.reverse(batch.answer.value.data)
batch_answer = [task.postprocess_prediction(example_ids[i], batch_answer[i]) for i in range(len(batch_answer))]
batch_answer = [task.postprocess_prediction(batch_example_ids[i], batch_answer[i]) for i in range(len(batch_answer))]
answers += batch_answer
batch_context = numericalizer.reverse(batch.context.value.data)
contexts += batch_context
@@ -99,7 +100,7 @@ def generate_with_model(model, data_iterator, numericalizer, task, args,
partial_batch_prediction = numericalizer.reverse(raw_partial_batch_prediction)
# post-process predictions
for i in range(len(partial_batch_prediction)):
partial_batch_prediction[i] = task.postprocess_prediction(example_ids[(i//args.num_outputs[hyperparameter_idx]) % batch_size], partial_batch_prediction[i])
partial_batch_prediction[i] = task.postprocess_prediction(batch_example_ids[(i//args.num_outputs[hyperparameter_idx]) % batch_size], partial_batch_prediction[i])
# put them into the right array
for i in range(len(partial_batch_prediction)):
batch_prediction[(i//args.num_outputs[hyperparameter_idx]) % batch_size].append(partial_batch_prediction[i])
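
The indexing fix matters from the second batch on: example_ids keeps growing across batches, so indexing it with the within-batch position (i // num_outputs) % batch_size silently resolves to the first batch's ids. A toy run (illustrative names, not the genienlp API):

num_outputs, batch_size = 2, 2
example_ids = []  # accumulates across batches, as in the loop above
for batch_example_ids in (['a', 'b'], ['c', 'd']):
    example_ids += batch_example_ids
    flat_predictions = ['p%d' % i for i in range(num_outputs * batch_size)]
    for i, pred in enumerate(flat_predictions):
        within_batch = (i // num_outputs) % batch_size
        fixed = batch_example_ids[within_batch]   # new code: correct owner
        stale = example_ids[within_batch]         # old code: always resolves to the first batch
        print(pred, 'owner:', fixed, '| old code said:', stale)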