simplified generation code using GPT2Seq2Seq
This commit is contained in:
parent
03e09eddc5
commit
e7e6e3a1c4
@@ -1,16 +1,57 @@
from typing import List

from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer
import torch


class GPT2Seq2Seq(GPT2LMHeadModel):
    def __init__(self, config):
        super().__init__(config)
        self.sep_token = 50258
        self.end_token = 50259
        self.pad_token = 50257


    def pad_to_max_length(self, input_sequences: List[List[int]]):
        """
        Adds pad tokens before the sep_token
        """
        max_length = len(input_sequences[0])  # input is sorted by length
        copy_input_sequences = []
        for i in range(len(input_sequences)):
            sep_token_index = input_sequences[i].index(self.sep_token)
            copy_input_sequences.append(input_sequences[i][:sep_token_index] +
                                        [self.pad_token] * (max_length - len(input_sequences[i])) +
                                        input_sequences[i][sep_token_index:])

        return copy_input_sequences


    def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty):
        """ Repetition penalty from CTRL (https://arxiv.org/abs/1909.05858), but much faster on GPU.
        """
        if repetition_penalty == 1.0:
            return lprobs
        # m marks every vocabulary entry that already occurs in prev_output_tokens
        m = torch.scatter(input=torch.zeros_like(lprobs), dim=1, index=prev_output_tokens, value=1)
        m[:, self.sep_token] = 0  # never penalize the sep token
        m[:, self.pad_token] = 0  # never penalize the pad token
        # logger.info('m = ', m.shape)
        need_change = m * lprobs
        need_divide = need_change > 0
        need_multiply = need_change < 0
        lprobs = need_divide * lprobs / repetition_penalty + need_multiply * lprobs * repetition_penalty + (1 - m) * lprobs

        # old, slow implementation
        # if repetition_penalty != 1.0:
        #     for i in range(context.shape[0]):
        #         for previous_token in set(generated[i].tolist()):
        #             if lprobs[i, previous_token] > 0:
        #                 lprobs[i, previous_token] /= repetition_penalty
        #             else:
        #                 lprobs[i, previous_token] *= repetition_penalty


    def prepare_inputs_for_generation(self, input_ids, past, **kwargs):
        sep_token_position = (input_ids == self.sep_token).to(torch.long)
        assert (torch.sum(sep_token_position, dim=1) == 1).all(), 'All input_ids must contain exactly one start_token'
        assert (torch.sum(sep_token_position, dim=1) == 1).all(), 'All input_ids must contain exactly one start_token. sep_token_position = %s' % str(sep_token_position)
        token_type_ids = torch.cumsum(sep_token_position, dim=1) - sep_token_position
        attention_mask = (input_ids != self.pad_token).to(torch.long)  # 0 means mask, 1 means no mask
        position_ids = (torch.cumsum(attention_mask, dim=1) - 1) * (1 - token_type_ids) + (torch.cumsum(token_type_ids, dim=1) - 1) * token_type_ids
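For intuition, here is a minimal, self-contained sketch of what pad_to_max_length does to a batch, using plain lists and the hard-coded ids above (pad=50257, sep=50258); the helper name and the example values are illustrative, not part of the repository:

PAD, SEP = 50257, 50258   # self.pad_token and self.sep_token above

def pad_before_sep(seqs):
    # mirror of GPT2Seq2Seq.pad_to_max_length: the batch is sorted longest-first,
    # and padding is inserted immediately before each sequence's sep token
    max_len = len(seqs[0])
    return [s[:s.index(SEP)] + [PAD] * (max_len - len(s)) + s[s.index(SEP):] for s in seqs]

batch = [[11, 12, 13, SEP, 21], [11, 12, SEP]]
print(pad_before_sep(batch))
# [[11, 12, 13, 50258, 21], [11, 12, 50257, 50257, 50258]]

Keeping the pads to the left of the sep token matters because prepare_inputs_for_generation derives token_type_ids and position_ids from the position of that single sep token and excludes pad positions from the attention mask.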
@@ -20,10 +20,9 @@ def chunks(lst, n):
def generate_summaries(
    examples: list, out_file: str, model_name: str, batch_size: int = 8, device: str = DEFAULT_DEVICE
):
    fout = Path(out_file).open("w")
    # b = BartSystem.load_from_checkpoint('./workdir/models/bart-large-mw6/checkpointepoch=1.ckpt')
    # b.model.save_pretrained('./workdir/models/bart-large-mw6/')
    # b.tokenizer.save_pretrained('./workdir/models/bart-large-mw6/')
    # b = BartSystem.load_from_checkpoint('./workdir/models/bart-large-2to1/checkpointcheckpoint_ckpt_epoch_1.ckpt')
    # b.model.save_pretrained('./workdir/models/bart-large-2to1/')
    # b.tokenizer.save_pretrained('./workdir/models/bart-large-2to1/')
    model = BartForConditionalGeneration.from_pretrained(model_name).to(device)
    model.eval()
    model = model.to(device)
@@ -32,6 +31,7 @@ def generate_summaries(
    max_length = 140
    min_length = 1

    fout = Path(out_file).open("w")
    for batch in tqdm(list(chunks(examples, batch_size))):
        dct = tokenizer.batch_encode_plus(batch, max_length=1024, return_tensors="pt", pad_to_max_length=True)
        # bad = ['which', 'Which', 'restaurant', 'restaurants']
@@ -39,7 +39,7 @@ def generate_summaries(
        summaries = model.generate(
            input_ids=dct["input_ids"].to(device),
            attention_mask=dct["attention_mask"].to(device),
            num_beams=1,
            num_beams=16,
            do_sample=False,
            temperature=1,
            length_penalty=1,
@@ -48,7 +48,7 @@ def generate_summaries(
            no_repeat_ngram_size=3,
            early_stopping=True,
            decoder_start_token_id=model.config.eos_token_id,
            num_return_sequences=1
            num_return_sequences=4
            # bad_words_ids=bad
        )
        # print(bad)
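The batching loop above relies on the chunks() helper named in the hunk header, whose body is not part of this diff; a typical definition, shown here only for context and under that assumption, slices the examples into batch_size-sized groups:

def chunks(lst, n):
    # yield successive n-sized slices of lst
    for i in range(0, len(lst), n):
        yield lst[i:i + n]

print(list(chunks([1, 2, 3, 4, 5], 2)))   # [[1, 2], [3, 4], [5]]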
@@ -38,16 +38,15 @@ except RuntimeError:
    pass

import torch
import torch.nn.functional as F

from transformers import GPT2Config, BartConfig

from transformers import GPT2LMHeadModel, GPT2Tokenizer
from transformers import GPT2Tokenizer
from transformers import BartForConditionalGeneration, BartTokenizer
from .util import set_seed, get_number_of_lines, combine_files_on_disk, split_file_on_disk, get_part_path, detokenize, tokenize, lower_case, \
                  top_k_top_p_filtering, SpecialTokenMap, remove_thingtalk_quotes
                  SpecialTokenMap, remove_thingtalk_quotes
from .metrics import computeBLEU
# from .models.common import BeamHypotheses
from .GPT2Seq2Seq import GPT2Seq2Seq


logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
@@ -55,159 +54,16 @@ logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(messa
                    level = logging.INFO)
logger = logging.getLogger(__name__)

MAX_LENGTH = int(10000)  # Hardcoded max length to avoid infinite loop
MAX_LENGTH = int(1000)  # Hardcoded max length to avoid infinite loop

ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, BartConfig)), ())

MODEL_CLASSES = {
    'gpt2': (GPT2LMHeadModel, GPT2Tokenizer),
    'gpt2': (GPT2Seq2Seq, GPT2Tokenizer),
    'bart': (BartForConditionalGeneration, BartTokenizer)
}


def apply_repetition_penalty(logits, context, repetition_penalty, prompt_token_id, pad_token_id):
    """ Repetition penalty from CTRL (https://arxiv.org/abs/1909.05858), but much faster on GPU.
        We penalize only the tokens that appear in the context, not in the generated text.
    """
    if repetition_penalty == 1.0:
        return logits
    # m marks every vocabulary entry that occurs in the context
    m = torch.scatter(input=torch.zeros_like(logits), dim=1, index=context, value=1)
    m[:, prompt_token_id] = 0  # never penalize the prompt token
    m[:, pad_token_id] = 0  # never penalize the pad token
    # logger.info('m = ', m.shape)
    need_change = m * logits
    need_divide = need_change > 0
    need_multiply = need_change < 0
    logits = need_divide * logits / repetition_penalty + need_multiply * logits * repetition_penalty + (1 - m) * logits

    # Old, slow implementation
    # if repetition_penalty != 1.0:
    #     for i in range(context.shape[0]):
    #         for _ in set(generated[i].tolist()):
    #             if logits[i, _] > 0:
    #                 logits[i, _] /= repetition_penalty
    #             else:
    #                 logits[i, _] *= repetition_penalty
    return logits

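As a toy illustration of the scatter-based penalty used in apply_repetition_penalty above (and in GPT2Seq2Seq.enforce_repetition_penalty_), with made-up logits and a four-token vocabulary:

import torch

logits = torch.tensor([[2.0, -1.0, 0.5, 3.0]])   # one row, vocab of 4 (illustrative values)
context = torch.tensor([[0, 2]])                 # tokens 0 and 2 appear in the context
repetition_penalty = 2.0

# m marks the vocabulary entries that occur in the context
m = torch.scatter(input=torch.zeros_like(logits), dim=1, index=context, value=1)
need_change = m * logits
penalized = (need_change > 0) * logits / repetition_penalty \
            + (need_change < 0) * logits * repetition_penalty \
            + (1 - m) * logits
print(penalized)   # tensor([[1.0000, -1.0000, 0.2500, 3.0000]])

Positive logits of already-seen tokens are divided by the penalty, negative ones are multiplied by it, and everything else passes through unchanged, which is exactly what the vectorized expression computes without the per-token Python loop of the old implementation.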
def sample_sequence(model, length, min_output_length, context, num_samples,
                    temperature=1.0, top_k=0, top_p=1.0, repetition_penalty=1.0, device='cpu',
                    stop_token_ids=None, pad_token_id=None, supports_past=False, prompt_token_id=None, segment_token_ids=None,
                    start_reverse_position_ids=None, output_form=None):
    """
    Generates sequence of tokens for the batch of input contexts.
    Inputs:
        context: a list of token_ids, sorted by length from longest to shortest
        num_samples: the number of sequences to output for each input context
        length: the maximum length of generation in addition to the original sentence's length
        stop_token_ids: generation of each sequence will stop if we generate any of these tokens
        supports_past: set to True if the model accepts the 'past' input for more efficient generation. For example, GPT-2/Transfo-XL/XLNet/CTRL do
        segment_token_ids: a list of two integers that indicate the tokens we should use for each of the two segments
    """
    max_length = len(context[0])  # context is sorted by length from longest to shortest
    min_length = len(context[-1])

    # should not change the elements of context since it will change them outside this function as well.
    padded_context = []
    for i in range(len(context)):
        padded_context.append(context[i] + [pad_token_id] * (max_length - len(context[i])))  # pad to max_length

    next_index = min_length
    length = max_length + (max_length - min_length) + length  # generate till max_length, then generate another max_length+length tokens
    max_index = length + next_index

    segment_ids = []
    position_ids = []
    for i in range(len(context)):
        prompt_token_position = context[i].index(prompt_token_id)
        p = list(range(prompt_token_position + 1))
        segment_ids.append([segment_token_ids[0]] * len(p) + [segment_token_ids[1]] * (max_index - len(p)))
        if start_reverse_position_ids is None:
            position_ids.append(p + list(range(max_index - len(p))))
        else:
            position_ids.append(p + list(reversed(range(start_reverse_position_ids + len(p)))) + [0] * (max_index - start_reverse_position_ids - 2 * len(p)))

    position_ids = torch.tensor(position_ids, dtype=torch.long, device=device)
    position_ids = position_ids.repeat(num_samples, 1)
    segment_ids = torch.tensor(segment_ids, dtype=torch.long, device=device)
    segment_ids = segment_ids.repeat(num_samples, 1)

    # logger.info('context = ', context)
    # logger.info('position_ids = ', position_ids)
    # logger.info('segment_ids = ', segment_ids)

    context = torch.tensor(padded_context, dtype=torch.long, device=device)
    context = context.repeat(num_samples, 1)
    generated = context[:, :next_index]
    generated_length = torch.zeros((context.shape[0], 1), dtype=torch.long, device=device)
    should_finish = None
    generated_logits = None
    past = None
    next_token = None
    with torch.no_grad():
        for _ in range(length):
            inputs = {'input_ids': generated, 'position_ids': position_ids[:, :next_index], 'token_type_ids': segment_ids[:, :next_index]}
            if supports_past:
                inputs['past'] = past
                if past is not None:
                    inputs['input_ids'] = next_token
                    inputs['position_ids'] = position_ids[:, next_index - 1]
                    inputs['token_type_ids'] = segment_ids[:, next_index - 1]

            outputs = model(**inputs)
            original_next_token_logits = outputs[0][:, -1, :]
            next_token_logits = original_next_token_logits / (temperature if temperature > 0 else 1.)
            past = outputs[1]

            next_token_logits = apply_repetition_penalty(next_token_logits, context, repetition_penalty,
                                                         prompt_token_id=prompt_token_id, pad_token_id=pad_token_id)

            if next_index < context.shape[1]:
                m = (context[:, next_index:next_index + 1] != pad_token_id).long()  # m==0 is where next_token should be kept
            else:
                m = torch.zeros(1, device=device)

            # prevent stop_tokens if generated_length < min_output_length
            should_remove_stop_tokens = (generated_length < min_output_length)
            next_token_logits[:, stop_token_ids] = next_token_logits[:, stop_token_ids].masked_fill(should_remove_stop_tokens, -float('Inf'))
            # logger.info('after ', next_token_logits[:, stop_token_ids])
            generated_length = generated_length + (1 - m)

            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)

            if temperature == 0:  # greedy sampling
                next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(-1)
            else:
                next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)

            if output_form == 'logprob':
                generated_token_logit = F.log_softmax(original_next_token_logits, dim=-1).gather(1, next_token)
            else:
                assert output_form == 'logit'
                generated_token_logit = original_next_token_logits.gather(1, next_token)

            # throw away the tokens that we already have from the context
            if next_index < context.shape[1]:
                next_token = m * context[:, next_index:next_index + 1] + (1 - m) * next_token
                generated_token_logit = (1 - m) * generated_token_logit

            for stop_token_id in stop_token_ids:
                if should_finish is None:
                    should_finish = ((next_token == stop_token_id) & (1 - m).bool())
                else:
                    should_finish = should_finish | ((next_token == stop_token_id) & (1 - m).bool())
            next_index += 1
            generated = torch.cat((generated, next_token), dim=1)
            if generated_logits is None:
                generated_logits = generated_token_logit
            else:
                generated_logits = torch.cat((generated_logits, generated_token_logit), dim=1)
            if should_finish.all():
                break
    return generated, generated_logits


special_pattern_mapping = [
    SpecialTokenMap('PHONE_NUMBER_([0-9]+)', ['888-8888', '777-8888']),
    SpecialTokenMap('NUMBER_([0-9]+)', ['2', '3'], [['2', 'two'], ['3', 'three']]),
@@ -225,7 +81,7 @@ special_pattern_mapping = [
    SpecialTokenMap('GENERIC_ENTITY_uk.ac.cam.multiwoz.Restaurant:Restaurant_([0-9]+)', ["restaurant1", "restaurant2", "restaurant3"])  # TODO the only reason we can get away with this unnatural replacement is that actual backward is not going to be called for this
]

def create_features_from_tsv_file(file_path, tokenizer, input_column, gold_column, prompt_column, copy, thingtalk_column, prompt_token,
def create_features_from_tsv_file(file_path, tokenizer, input_column, gold_column, prompt_column, copy, thingtalk_column, sep_token,
                                  skip_heuristics, is_cased):
    """
    Read a tsv file (this includes a text file with one example per line) and returns input features that the model needs
@@ -267,7 +123,7 @@ def create_features_from_tsv_file(file_path, tokenizer, input_column, gold_colum
            input_sequence, reverse_map = input_heuristics(input_sequence, thingtalk, is_cased)
            # logger.info('input_sequence = %s', input_sequence)
            reverse_maps.append(reverse_map)
        input_sequence += prompt_token
        input_sequence += sep_token
        prompt = ''  # includes the first few tokens of the output
        if prompt_column is not None and len(row) > prompt_column:
            prompt = row[prompt_column]
@@ -278,6 +134,7 @@ def create_features_from_tsv_file(file_path, tokenizer, input_column, gold_colum
        prompt_tokens = tokenizer.encode(prompt, add_special_tokens=False)
        context_tokens = input_sequence_tokens + prompt_tokens
        if copy > 0:
            assert prompt == ''
            context_tokens.extend(context_tokens[0 : min(copy, len(context_tokens) - 1)])  # -1 since we should not copy prompt_token
        all_input_sequences.append(input_sequence)
        all_input_sequence_lengths.append(len(input_sequence_tokens))
@@ -289,7 +146,6 @@ def create_features_from_tsv_file(file_path, tokenizer, input_column, gold_colum

    return all_input_sequences, all_input_sequence_lengths, all_context_tokens, all_context_lengths, all_golds, reverse_maps


def is_question(sentence: str):
    question_words = ['which', 'what', 'where', 'how', 'who', 'when', 'is', 'are', 'am', \
                      'can', 'could', 'would', 'will', 'have', 'did', 'do', 'does', 'no is', 'yes is']
@@ -435,20 +291,16 @@ def parse_argv(parser):
    parser.add_argument("--metric_reduction", type=str, choices=['average', 'max'], default='average',
                        help="How we should calculate metrics where there are multiple generations per example.")

    # These can be used for improving the quality of the output
    parser.add_argument("--num_samples", type=int, default=1)
    parser.add_argument("--selection_criterion", type=str, choices=['none', 'average_logit', 'average_logprob', 'bleu'], default='none',
                        help='Select one of --num_sample outputs that maximizes this criterion')

    # These are generation hyperparameters. Each one can be a list of values, in which case we generate num_samples outputs for each set of hyperparameters.
    parser.add_argument("--start_reverse_position_ids", type=int, nargs='+', default=[None],
                        help='If provided, position ids will be the number of tokens left in generation and will start from len(input) + args.start_reverse_position_ids')
    parser.add_argument("--temperature", type=float, nargs='+', default=[1.0],
                        help="temperature of 0 implies greedy sampling")
    parser.add_argument("--repetition_penalty", type=float, nargs='+', default=[1.0],
                        help="primarily useful for CTRL model; in that case, use 1.2")
    parser.add_argument("--top_k", type=int, nargs='+', default=[0], help='0 disables top-k filtering')
    parser.add_argument("--top_p", type=float, nargs='+', default=[0.9], help='1.0 disables top-p filtering')
    parser.add_argument("--num_beams", type=int, nargs='+', default=[1], help='1 disables beam search')

    parser.add_argument("--copy", type=int, default=0,
                        help='Number of tokens that will be copied at the beginning of generation. Helps preserve the original meaning of the input sequence.')
@@ -456,7 +308,7 @@ def parse_argv(parser):
                        help="Avoid using CUDA when available")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")
    parser.add_argument('--prompt_token', type=str, default='<paraphrase>',
    parser.add_argument('--sep_token', type=str, default='<paraphrase>',
                        help="Token after which text generation starts. We add this to the end of all inputs.")
    parser.add_argument('--stop_tokens', type=str, nargs='+', default=['</paraphrase>'],
                        help="Token at which text generation is stopped. The first element of the list is used as segment id as well.")
@@ -466,7 +318,7 @@ def parse_argv(parser):
def main(args):
    if args.prompt_column is not None and args.copy is not None and args.copy != 0:
        raise ValueError('Cannot copy from the input and use prompt at the same time. Disable either --copy or --prompt_column.')
    hyperparameters = ['temperature', 'top_k', 'top_p', 'repetition_penalty', 'start_reverse_position_ids']
    hyperparameters = ['temperature', 'top_k', 'top_p', 'repetition_penalty', 'num_beams']
    max_hyperparameter_len = max([len(getattr(args, h)) for h in hyperparameters])
    valid_len = [1, max_hyperparameter_len]
    for h in hyperparameters:
@@ -534,7 +386,7 @@ def run_generation(args):


    pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
    prompt_token_id = tokenizer.convert_tokens_to_ids(args.prompt_token)
    sep_token_id = tokenizer.convert_tokens_to_ids(args.sep_token)
    if pad_token_id is None:
        logger.error('Your tokenizer does not have a padding token')

@@ -543,7 +395,7 @@ def run_generation(args):
        input_column=args.input_column, gold_column=args.gold_column, prompt_column=args.prompt_column,
        copy=args.copy,
        thingtalk_column=args.thingtalk_column,
        prompt_token=args.prompt_token, skip_heuristics=args.skip_heuristics, is_cased=args.is_cased)
        sep_token=args.sep_token, skip_heuristics=args.skip_heuristics, is_cased=args.is_cased)


    # sort contexts based on their context length so that fewer generated tokens are thrown away and generation can be done faster
@@ -561,38 +413,32 @@ def run_generation(args):
        batch_context_tokens = all_context_tokens[batch_slice[0]: batch_slice[1]]
        batch_reverse_maps = reverse_maps[batch_slice[0]: batch_slice[1]]

        batch_context_tensor = input_tensor = torch.tensor(model.pad_to_max_length(batch_context_tokens), dtype=torch.long, device=args.device)

        batch_outputs = [[] for _ in range(batch_size)]
        batch_criterion = [[] for _ in range(batch_size)]
        for hyperparameter_idx in range(len(args.temperature)):
            out, out_logits = sample_sequence(
                model=model,
                context=batch_context_tokens,
                num_samples=args.num_samples,
                length=args.length,
                min_output_length=args.min_output_length,
                temperature=args.temperature[hyperparameter_idx],
                top_k=args.top_k[hyperparameter_idx],
                top_p=args.top_p[hyperparameter_idx],
                repetition_penalty=args.repetition_penalty[hyperparameter_idx],
                device=args.device,
                stop_token_ids=stop_token_ids,
                pad_token_id=pad_token_id,
                supports_past=args.model_type in ['gpt2'],
                prompt_token_id=prompt_token_id,
                segment_token_ids=[tokenizer.convert_tokens_to_ids(args.prompt_token), tokenizer.convert_tokens_to_ids(args.stop_tokens[0])] if args.model_type == 'gpt2' else [0, 1],
                start_reverse_position_ids=args.start_reverse_position_ids[hyperparameter_idx],
                output_form='logit' if args.selection_criterion == 'average_logit' else 'logprob'
            )
            out = model.generate(input_ids=batch_context_tensor,
                                 min_length=args.min_output_length,
                                 max_length=batch_context_tensor.shape[1] + args.length,
                                 num_beams=args.num_beams[hyperparameter_idx],
                                 top_k=args.top_k[hyperparameter_idx],
                                 top_p=args.top_p[hyperparameter_idx],
                                 early_stopping=True,
                                 num_return_sequences=args.num_samples,
                                 repetition_penalty=args.repetition_penalty[hyperparameter_idx],
                                 do_sample=args.temperature[hyperparameter_idx] != 0,
                                 temperature=args.temperature[hyperparameter_idx] if args.temperature[hyperparameter_idx] > 0 else 1.0,  # if temperature==0, we do not sample
                                 eos_token_id=stop_token_ids[0],
                                 pad_token_id=pad_token_id
                                 )

            out = out[:, :].tolist()
            out_logits = out_logits[:, :].tolist()
            for i, o in enumerate(out):
                o_logits = out_logits[i]
                # logger.info('all output tokens: %s', o)
                # logger.info('all output tokens: %s', str(o))
                # logger.info('all output tokens detokenized: %s', str(tokenizer.decode(o, clean_up_tokenization_spaces=True, skip_special_tokens=False)))
                o = o[batch_input_sequence_lengths[i % batch_size]:]
                # logger.info('original context tokens: %s', str(batch_context_tokens[i % batch_size]))
                # logger.info('original input sequence: %s', str(batch_input_sequences[i % batch_size]))
                o = [x for x in o if x != pad_token_id][batch_input_sequence_lengths[(i // args.num_samples) % batch_size]:]
                # logger.info('original context tokens: %s', str(batch_context_tokens[(i // args.num_samples) % batch_size]))
                # logger.info('original input sequence: %s', str(batch_input_sequences[(i // args.num_samples) % batch_size]))

                if args.stop_tokens is not None:
                    min_index = len(o)
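The new index arithmetic reflects how generate() orders its output when num_return_sequences is used: the args.num_samples sequences for each input come back consecutively, so row i belongs to input (i // args.num_samples) % batch_size rather than i % batch_size. A tiny sketch of that grouping, with hypothetical batch_size and num_samples values:

batch_size, num_samples = 2, 3
# rows come back as [in0_s0, in0_s1, in0_s2, in1_s0, in1_s1, in1_s2]
for i in range(batch_size * num_samples):
    print('row', i, '-> input', (i // num_samples) % batch_size)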
@@ -605,9 +451,6 @@ def run_generation(args):
                    if min_index < len(o) and o[min_index] == tokenizer.convert_tokens_to_ids('?'):
                        # always include the question mark
                        min_index = min_index + 1
                    if min_index < len(o) and o[min_index] == tokenizer.convert_tokens_to_ids(args.stop_tokens[0]):
                        # include </paraphrase> in logit calculation
                        o_logits = o_logits[:len(o_logits) - (len(o) - min_index - 1)]
                    o = o[:min_index]

                text = tokenizer.decode(o, clean_up_tokenization_spaces=True, skip_special_tokens=False)
@@ -617,34 +460,12 @@ def run_generation(args):
                text = re.sub('\s\s+', ' ', text)  # remove duplicate white spaces
                text = text.strip()
                if not args.skip_heuristics:
                    text = output_heuristics(text, batch_reverse_maps[i % batch_size])
                batch_outputs[i % batch_size].append(text)
                    text = output_heuristics(text, batch_reverse_maps[(i // args.num_samples) % batch_size])
                batch_outputs[(i // args.num_samples) % batch_size].append(text)

                if args.selection_criterion == 'bleu':
                    # computeBLEU always converts to lower case first, so do not worry about lower/upper case here
                    criterion = computeBLEU([text], [[batch_input_sequences[i % batch_size]]])
                else:
                    criterion = np.mean(o_logits)
                batch_criterion[i % batch_size].append(criterion)
                # logger.info('generated tokens: %s', str(o))
                # logger.info('o_logits = %s', str(o_logits))
                # logger.info('generated criterion: %.2f', criterion)
                # logger.info('text = %s', text)
                # logger.info('-'*10)
        all_outputs.extend(batch_outputs)


        if args.selection_criterion == 'none':
            all_outputs.extend(batch_outputs)
        else:
            for idx, example in enumerate(batch_outputs):
                logger.info('input sequence: %s', str(batch_input_sequences[idx % batch_size]))
                c, example = tuple(zip(*sorted(list(zip(batch_criterion[idx], example)), reverse=True)))
                logger.info(example)
                logger.info(c)
                logger.info('-'*10)
                selection = example[0]
                all_outputs.append([selection])

        # sort the results back to their original order
        _, all_outputs = tuple(zip(*sorted(list(zip(original_order, all_outputs)))))

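The sort-back line above undoes the earlier length-based sorting of the inputs; a small standalone illustration of the same zip/sort idiom, with toy values:

original_order = [2, 0, 1]             # original position of each (length-sorted) example
all_outputs = [['c'], ['a'], ['b']]    # outputs, still in length-sorted order
_, restored = tuple(zip(*sorted(list(zip(original_order, all_outputs)))))
print(list(restored))                  # [['a'], ['b'], ['c']]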
@@ -653,8 +474,8 @@ def run_generation(args):
    if args.output_file is not None:
        with open(args.output_file, 'w') as output_file:
            if args.output_file is not None:
                for _ in all_outputs:
                    for text in _:
                for output in all_outputs:
                    for text in output:
                        output_file.write(text + '\n')
    else:
        print(json.dumps(all_outputs, indent=2))