Merge pull request #97 from stanford-oval/wip/fix-para
Fix paraphrasing bugs
commit 69c5029f5f
@@ -28,8 +28,8 @@ special_pattern_mapping = [
         ['$13', 'thirteen dollars', '13 dollars', '$ 13', '$ 13.00', '13.00', '13']]),
     SpecialTokenMap('DURATION_([0-9]+)', ['5 weeks', '6 weeks'], [['5 weeks', 'five weeks'], ['6 weeks', 'six weeks']]),
     SpecialTokenMap('LOCATION_([0-9]+)', ['locatio1n', 'locatio2n'], [['locatio1n', 'locat1n'], ['locatio2n', 'locat2n']]),
-    SpecialTokenMap('QUOTED_STRING_([0-9]+)', lambda x: 'Chinese', lambda x: ['Chinese', 'chinese', 'china']), # TODO change to be more general than cuisine
-    SpecialTokenMap('GENERIC_ENTITY_uk.ac.cam.multiwoz.Restaurant:Restaurant_([0-9]+)', ["restaurant1", "restaurant2", "restaurant3"]) # TODO the only reason we can get away with this unnatural replacement is that actual backward is not going to be called for this
+    # SpecialTokenMap('QUOTED_STRING_([0-9]+)', ['Chinese', 'Italian'], [['Chinese', 'chinese', 'china'], ['Italian', 'italian']]), # TODO change to be more general than cuisine
+    # SpecialTokenMap('GENERIC_ENTITY_uk.ac.cam.multiwoz.Restaurant:Restaurant_([0-9]+)', ["restaurant1", "restaurant2", "restaurant3"]) # TODO the only reason we can get away with this unnatural replacement is that actual backward is not going to be called for this
 ]
 
 
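For context, each SpecialTokenMap above pairs a masked-token regex with natural-language surrogates that are substituted in before paraphrasing and reversed afterwards. Below is a minimal sketch of the forward direction only, using a hypothetical helper name; it is not the repo's actual SpecialTokenMap API:

```python
import re

# Surrogates for DURATION_0, DURATION_1, ..., mirroring the list above
forward_values = ['5 weeks', '6 weeks']

def mask_forward(utterance: str) -> str:
    # Swap each DURATION_<n> placeholder for a natural phrase
    # that the paraphrase model can actually rewrite
    return re.sub(r'DURATION_([0-9]+)',
                  lambda m: forward_values[int(m.group(1)) % len(forward_values)],
                  utterance)

print(mask_forward('remind me in DURATION_0'))  # remind me in 5 weeks
```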
@@ -4,16 +4,28 @@ import re
 from ...util import tokenize, lower_case, remove_thingtalk_quotes
 from ...data_utils.progbar import progress_bar
 
+special_token_pattern = re.compile("(^|(?<= ))" + "[A-Z]+_[0-9]" + "($|(?= ))")
+def find_special_tokens(s: str):
+    return list(sorted([a.group(0) for a in special_token_pattern.finditer(s)]))
+
+
 def is_subset(set1, set2):
     """
     Returns True if set1 is a subset of or equal to set2
     """
     return all([e in set2 for e in set1])
 
-def passes_heuristic_checks(row, args):
+def passes_heuristic_checks(row, args, old_query=None):
+    if 'QUOTED_STRING' in row[args.utterance_column] or (old_query is not None and 'QUOTED_STRING' in old_query):
+        # remove quoted examples
+        return False
+    if old_query is not None:
+        # check that all the special tokens in the utterance are the same before and after paraphrasing
+        if find_special_tokens(old_query) != find_special_tokens(row[args.utterance_column]):
+            return False
     all_input_columns = ' '.join([row[c] for c in args.input_columns])
-    input_special_tokens = set(re.findall('[A-Za-z:_.]+_[0-9]', all_input_columns))
-    output_special_tokens = set(re.findall('[A-Za-z:_.]+_[0-9]', row[args.thingtalk_column]))
+    input_special_tokens = set(find_special_tokens(all_input_columns))
+    output_special_tokens = set(find_special_tokens(row[args.thingtalk_column]))
     if not is_subset(output_special_tokens, input_special_tokens):
         return False
     _, quote_values = remove_thingtalk_quotes(row[args.thingtalk_column])
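The helper added above is easy to check standalone. Note that the pattern only matches a token delimited by spaces or string boundaries, which is why fused forms like "NUMBER_0-hour" are not counted (the tokenize() fix further down addresses exactly that). The example strings here are invented:

```python
import re

special_token_pattern = re.compile("(^|(?<= ))" + "[A-Z]+_[0-9]" + "($|(?= ))")

def find_special_tokens(s: str):
    return list(sorted([a.group(0) for a in special_token_pattern.finditer(s)]))

print(find_special_tokens('set an alarm at TIME_0 for NUMBER_0 hours'))
# ['NUMBER_0', 'TIME_0']
print(find_special_tokens('a NUMBER_0-hour flight'))
# []  (no space after NUMBER_0, so it is not matched)
```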
@@ -110,6 +122,7 @@ def main(args):
     seen_examples = set()
     all_thrown_away_rows = []
     for row_idx, row in enumerate(progress_bar(reader, desc='Lines')):
+        old_query = None
         output_rows = []
         thrown_away_rows = []
         if args.transformation == 'remove_thingtalk_quotes':
@@ -136,6 +149,7 @@ def main(args):
                 row[args.utterance_column] = new_query
                 output_rows.append(row)
         elif args.transformation == 'replace_queries':
+            old_query = row[args.utterance_column]
             for idx, new_query in enumerate(new_queries[row_idx]):
                 copy_row = row.copy()
                 copy_row[args.utterance_column] = new_query
@@ -152,7 +166,7 @@ def main(args):
         for o in output_rows:
             output_row = ""
             if args.remove_with_heuristics:
-                if not passes_heuristic_checks(o, args):
+                if not passes_heuristic_checks(o, args, old_query=old_query):
                     heuristic_count += 1
                     continue
             if args.remove_duplicates:
@@ -170,7 +184,7 @@ def main(args):
                 output_row += '\t'
             output_file.write(output_row + '\n')
         for o in thrown_away_rows:
-            if not args.remove_with_heuristics or (args.remove_with_heuristics and passes_heuristic_checks(o, args)):
+            if not args.remove_with_heuristics or (args.remove_with_heuristics and passes_heuristic_checks(o, args, old_query=old_query)):
                 all_thrown_away_rows.append(o)
 
     if args.thrown_away is not None:
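Taken together, the main() hunks above thread the pre-paraphrase utterance (old_query) into the heuristic checks, so a paraphrase that drops or invents a special token is discarded. A self-contained miniature of that filter, with invented utterances:

```python
import re

pattern = re.compile("(^|(?<= ))" + "[A-Z]+_[0-9]" + "($|(?= ))")

def find_special_tokens(s):
    return sorted(m.group(0) for m in pattern.finditer(s))

old_query = 'wake me up at TIME_0'
paraphrases = ['set an alarm for TIME_0', 'wake me up early']

# keep only paraphrases whose special tokens exactly match the original's
kept = [p for p in paraphrases
        if find_special_tokens(p) == find_special_tokens(old_query)]
print(kept)  # ['set an alarm for TIME_0']
```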
@@ -484,6 +484,83 @@ class NaturalSeq2Seq(BaseAlmondTask):
         return Example.from_raw(self.name + '/' + example_id, context, question, answer,
                                 preprocess=self.preprocess_field, lower=False)
 
+    def preprocess_field(self, sentence, field_name=None, answer=None):
+        if self.override_context is not None and field_name == 'context':
+            pad_feature = get_pad_feature(self.args.ned_features, self.args.ned_features_default_val, self.args.ned_features_size)
+            return self.override_context, [pad_feature] * len(self.override_context.split(' ')) if pad_feature else [], self.override_context
+        if self.override_question is not None and field_name == 'question':
+            pad_feature = get_pad_feature(self.args.ned_features, self.args.ned_features_default_val, self.args.ned_features_size)
+            return self.override_question, [pad_feature] * len(self.override_question.split(' ')) if pad_feature else [], self.override_question
+        if not sentence:
+            return '', [], ''
+
+        tokens = sentence.split(' ')
+        new_tokens = []
+        for token in tokens:
+            new_tokens.append(token)
+        tokens = new_tokens
+        new_sentence = ' '.join(tokens)
+
+        if self._almond_detokenize_sentence:
+
+            # BERT tokenizers by default add whitespace around any CJK character
+            # SPM-based tokenizers are trained on raw text and do better when they receive untokenized text
+            # In genienlp we detokenize CJK characters and leave tokenization to the model's tokenizer
+            # NOTE: input datasets for almond are usually pretokenized using genie-toolkit, which
+            # inserts whitespace around any CJK character. This detokenization ensures that SPM-based tokenizers
+            # see the text without spaces between those characters
+            new_sentence = detokenize_cjk_chars(new_sentence)
+            tokens = new_sentence.split(' ')
+
+            new_sentence = ''
+            for token in tokens:
+                if token in (',', '.', '?', '!', ':', ')', ']', '}') or token.startswith("'"):
+                    new_sentence += token
+                else:
+                    new_sentence += ' ' + token
+
+        new_sentence = new_sentence.strip()
+        new_tokens = new_sentence.split(' ')
+        new_sentence_length = len(new_tokens)
+
+        tokens_type_ids, tokens_type_probs = None, None
+
+        if 'type_id' in self.args.ned_features and field_name != 'answer':
+            tokens_type_ids = [[self.args.ned_features_default_val[0]] * self.args.ned_features_size[0] for _ in
+                               range(new_sentence_length)]
+        if 'type_prob' in self.args.ned_features and field_name != 'answer':
+            tokens_type_probs = [[self.args.ned_features_default_val[1]] * self.args.ned_features_size[1] for _ in
+                                 range(new_sentence_length)]
+
+        if self.args.do_ned and self.args.ned_retrieve_method != 'bootleg' and field_name not in self.no_feature_fields:
+            if 'type_id' in self.args.ned_features:
+                tokens_type_ids = self.find_type_ids(new_tokens, answer)
+            if 'type_prob' in self.args.ned_features:
+                tokens_type_probs = self.find_type_probs(new_tokens, self.args.ned_features_default_val[1],
+                                                         self.args.ned_features_size[1])
+
+        if self.args.verbose and self.args.do_ned:
+            print()
+            print(
+                *[f'token: {token}\ttype: {token_type}' for token, token_type in zip(new_tokens, tokens_type_ids)],
+                sep='\n')
+
+        zip_list = []
+        if tokens_type_ids:
+            assert len(tokens_type_ids) == new_sentence_length
+            zip_list.append(tokens_type_ids)
+        if tokens_type_probs:
+            assert len(tokens_type_probs) == new_sentence_length
+            zip_list.append(tokens_type_probs)
+
+        features = [Feature(*tup) for tup in zip(*zip_list)]
+
+        sentence_plus_types = ''
+        if self.args.do_ned and self.args.add_types_to_text != 'no' and len(features):
+            sentence_plus_types = self.create_sentence_plus_types_tokens(new_sentence, features, self.args.add_types_to_text)
+
+        return new_sentence, features, sentence_plus_types
+
     def get_splits(self, root, **kwargs):
         return AlmondDataset.return_splits(path=os.path.join(root, 'almond'), make_example=self._make_example, **kwargs)
 
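The detokenization loop inside preprocess_field re-attaches closing punctuation and apostrophe clitics to the preceding token after CJK detokenization. The same few lines run standalone on an invented input:

```python
tokens = "where 's the nearest cafe ?".split(' ')

new_sentence = ''
for token in tokens:
    # closing punctuation and tokens like 's / 'll glue onto the previous word
    if token in (',', '.', '?', '!', ':', ')', ']', '}') or token.startswith("'"):
        new_sentence += token
    else:
        new_sentence += ' ' + token

print(new_sentence.strip())  # where's the nearest cafe?
```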
@@ -518,7 +595,7 @@ class Paraphrase(NaturalSeq2Seq):
 
         sentence, reverse_map = input_heuristics(sentence, thingtalk=thingtalk, is_cased=True)
         # this task especially needs example ids to be unique
-        if example_id in self.reverse_maps:
+        while example_id in self.reverse_maps:
             example_id += '.'
         self.reverse_maps[example_id] = reverse_map
 
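The if→while change matters once the same id appears a third time: with a plain if, the third occurrence is suffixed only once, landing on an id that is itself already taken and silently overwriting its reverse map. A standalone illustration with invented ids:

```python
reverse_maps = {}
for example_id, reverse_map in [('ex1', 'a'), ('ex1', 'b'), ('ex1', 'c')]:
    # keep appending '.' until the id is genuinely unused;
    # an 'if' would stop after one '.' and 'ex1.' would be overwritten
    while example_id in reverse_maps:
        example_id += '.'
    reverse_maps[example_id] = reverse_map

print(reverse_maps)  # {'ex1': 'a', 'ex1.': 'b', 'ex1..': 'c'}
```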
@@ -317,7 +317,7 @@ def unmask_special_tokens(string: str, exceptions: list):
 
 def detokenize(string: str):
     string, exceptions = mask_special_tokens(string)
-    tokens = ["'d", "n't", "'ve", "'m", "'re", "'ll", ".", ",", "?", "!", "'s", ")", ":"]
+    tokens = ["'d", "n't", "'ve", "'m", "'re", "'ll", ".", ",", "?", "!", "'s", ")", ":", "-"]
     for t in tokens:
         string = string.replace(' ' + t, t)
     string = string.replace("( ", "(")
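Adding "-" to the token list makes detokenize() glue a free-standing hyphen onto the preceding token; note the loop removes only the space before each listed token. The core replacement loop, run standalone on an invented string:

```python
tokens = ["'d", "n't", "'ve", "'m", "'re", "'ll", ".", ",", "?", "!", "'s", ")", ":", "-"]

string = "do n't stop - ever !"
for t in tokens:
    # only the space *before* each token is removed
    string = string.replace(' ' + t, t)

print(string)  # don't stop- ever!
```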
@@ -335,8 +335,9 @@ def tokenize(string: str):
     string = string.replace("(", "( ")
     string = string.replace('gonna', 'gon na')
     string = string.replace('wanna', 'wan na')
     string = re.sub('\s+', ' ', string)
     string = unmask_special_tokens(string, exceptions)
+    string = re.sub('([A-Za-z:_.]+_[0-9]+)-', r'\1 - ', string) # add space before and after hyphen, e.g. "NUMBER_0-hour"
     string = re.sub('\s+', ' ', string) # remove duplicate spaces
     return string.strip()
 
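The new substitution in tokenize() un-fuses a hyphen glued to a masked token, so find_special_tokens and the heuristic checks can see it. A quick standalone check using the example from the diff's own comment:

```python
import re

string = 'a NUMBER_0-hour flight'
string = re.sub(r'([A-Za-z:_.]+_[0-9]+)-', r'\1 - ', string)
string = re.sub(r'\s+', ' ', string)  # remove duplicate spaces
print(string)  # a NUMBER_0 - hour flight
```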
@@ -67,11 +67,12 @@ def generate_with_model(model, data_iterator, numericalizer, task, args,
         batch_size = len(batch.example_id)
         batch_prediction = [[] for _ in range(batch_size)]
         batch_confidence_features = [[] for _ in range(batch_size)]
+        batch_example_ids = batch.example_id
 
-        example_ids += batch.example_id
+        example_ids += batch_example_ids
         if not output_predictions_only:
             batch_answer = numericalizer.reverse(batch.answer.value.data)
-            batch_answer = [task.postprocess_prediction(example_ids[i], batch_answer[i]) for i in range(len(batch_answer))]
+            batch_answer = [task.postprocess_prediction(batch_example_ids[i], batch_answer[i]) for i in range(len(batch_answer))]
             answers += batch_answer
             batch_context = numericalizer.reverse(batch.context.value.data)
             contexts += batch_context
@@ -99,7 +100,7 @@ def generate_with_model(model, data_iterator, numericalizer, task, args,
                 partial_batch_prediction = numericalizer.reverse(raw_partial_batch_prediction)
                 # post-process predictions
                 for i in range(len(partial_batch_prediction)):
-                    partial_batch_prediction[i] = task.postprocess_prediction(example_ids[(i//args.num_outputs[hyperparameter_idx]) % batch_size], partial_batch_prediction[i])
+                    partial_batch_prediction[i] = task.postprocess_prediction(batch_example_ids[(i//args.num_outputs[hyperparameter_idx]) % batch_size], partial_batch_prediction[i])
                 # put them into the right array
                 for i in range(len(partial_batch_prediction)):
                     batch_prediction[(i//args.num_outputs[hyperparameter_idx]) % batch_size].append(partial_batch_prediction[i])
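The example_ids→batch_example_ids changes fix an indexing bug: example_ids accumulates across batches, while (i//num_outputs) % batch_size is an offset within the current batch, so from the second batch on, predictions were post-processed with the first batch's ids. A standalone illustration of the arithmetic, with invented ids and two outputs per example:

```python
num_outputs = 2
example_ids = []  # grows across batches, as in generate_with_model

for batch_example_ids in [['a', 'b'], ['c', 'd']]:
    batch_size = len(batch_example_ids)
    example_ids += batch_example_ids
    for i in range(batch_size * num_outputs):
        idx = (i // num_outputs) % batch_size
        # buggy lookup vs. fixed lookup:
        print(example_ids[idx], batch_example_ids[idx])
# second batch prints: a c / a c / b d / b d  -- example_ids[idx] is stale
```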