Added more logging
parent 68ffaa1561
commit f886b6f409
@@ -1,3 +1,4 @@
+import logging
 from typing import List
 from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer
 import torch
@@ -10,7 +11,9 @@ class GPT2Seq2Seq(GPT2LMHeadModel):
         self.end_token_id = end_token_id
         self.sep_token_id = sep_token_id
         self.pad_token_id = pad_token_id
 
+        logging.info('end_token_id = %s', self.end_token_id)
+        logging.info('sep_token_id = %s', self.sep_token_id)
+        logging.info('pad_token_id = %s', self.pad_token_id)
 
     def pad_to_max_length(self, input_sequences: List[List[int]]):
         """
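For reference, the logging calls added here use the standard library's lazy %s formatting, so the message is only built when an INFO record is actually emitted. A minimal runnable sketch of the same pattern; TokenIdHolder and the numeric ids below are placeholders, not the real GPT2Seq2Seq or its vocabulary:

import logging

logging.basicConfig(level=logging.INFO)

class TokenIdHolder:
    """Placeholder standing in for GPT2Seq2Seq; only the logging pattern matters here."""

    def set_token_ids(self, end_token_id, sep_token_id, pad_token_id):
        self.end_token_id = end_token_id
        self.sep_token_id = sep_token_id
        self.pad_token_id = pad_token_id
        # %s arguments are formatted lazily, only if the INFO record is emitted.
        logging.info('end_token_id = %s', self.end_token_id)
        logging.info('sep_token_id = %s', self.sep_token_id)
        logging.info('pad_token_id = %s', self.pad_token_id)

TokenIdHolder().set_token_ids(end_token_id=50258, sep_token_id=50259, pad_token_id=50257)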
@@ -81,7 +81,7 @@ special_pattern_mapping = [
     SpecialTokenMap('GENERIC_ENTITY_uk.ac.cam.multiwoz.Restaurant:Restaurant_([0-9]+)', ["restaurant1", "restaurant2", "restaurant3"]) # TODO the only reason we can get away with this unnatural replacement is that actual backward is not going to be called for this
 ]
 
-def create_features_from_tsv_file(file_path, tokenizer, input_column, gold_column, prompt_column, copy, thingtalk_column, sep_token,
+def create_features_from_tsv_file(file_path, tokenizer, input_column, gold_column, prompt_column, copy, thingtalk_column, sep_token_id,
                                   skip_heuristics, is_cased, model_type):
     """
     Read a tsv file (this includes a text file with one example per line) and returns input features that the model needs
@@ -134,8 +134,8 @@ def create_features_from_tsv_file(file_path, tokenizer, input_column, gold_colum
         prompt_tokens = tokenizer.encode(prompt, add_special_tokens=False)
         if copy > 0:
             assert len(prompt_tokens) == 0
-            prompt_tokens = context_tokens[0 : min(copy, len(context_tokens)-1)] # -1 since we should not copy sep_token
-        context_tokens = input_sequence_tokens + [tokenizer.convert_tokens_to_ids(sep_token)] + prompt_tokens
+            prompt_tokens = input_sequence_tokens[0 : min(copy, len(input_sequence_tokens)-1)]
+        context_tokens = input_sequence_tokens + [sep_token_id] + prompt_tokens
         all_input_sequences.append(input_sequence)
         all_input_sequence_lengths.append(len(input_sequence_tokens))
         all_context_tokens.append(context_tokens)
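The rewritten lines build the prompt from a prefix of the input tokens and splice the separator in by id rather than re-encoding the sep token; the -1 in the slice keeps the final input position (presumably a special token) out of the copied prefix. A small runnable sketch of that arithmetic; the token ids and copy length below are made up for illustration:

# Made-up example values; the real ids come from the tokenizer.
input_sequence_tokens = [101, 102, 103, 104]
sep_token_id = 999
copy = 2

# Copy the first `copy` input tokens as the prompt, never the final position.
prompt_tokens = input_sequence_tokens[0 : min(copy, len(input_sequence_tokens) - 1)]

# Context fed to the model: input, separator id, then the copied prefix.
context_tokens = input_sequence_tokens + [sep_token_id] + prompt_tokens

assert prompt_tokens == [101, 102]
assert context_tokens == [101, 102, 103, 104, 999, 101, 102]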
@@ -378,13 +378,15 @@ def run_generation(args):
     model.to(args.device)
     model.eval()
 
+    end_token_id = tokenizer.convert_tokens_to_ids(special_tokens['end_token'])
+    sep_token_id = tokenizer.convert_tokens_to_ids(special_tokens['sep_token'])
     pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
     if pad_token_id is None:
         logger.error('Your tokenizer does not have a padding token')
 
     if args.model_type == 'gpt2':
-        model.set_token_ids(end_token_id=tokenizer.convert_tokens_to_ids(special_tokens['end_token']),
-                            sep_token_id=tokenizer.convert_tokens_to_ids(special_tokens['sep_token']),
+        model.set_token_ids(end_token_id=end_token_id,
+                            sep_token_id=sep_token_id,
                             pad_token_id=pad_token_id)
 
     logger.info(args)
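With this change the three special-token ids are resolved once up front and reused by both set_token_ids and create_features_from_tsv_file. A rough sketch of the lookup-and-check pattern; the dict-backed convert_tokens_to_ids and the token strings are stand-ins, not the project's actual tokenizer or special_tokens values:

import logging

logger = logging.getLogger(__name__)

# Stand-in vocabulary; the real code resolves tokens via tokenizer.convert_tokens_to_ids.
vocab = {'<end>': 50258, '<sep>': 50259}

def convert_tokens_to_ids(token):
    return vocab.get(token)  # None when the token is unknown or was never set

special_tokens = {'end_token': '<end>', 'sep_token': '<sep>'}
pad_token = None  # GPT-2 tokenizers usually have no pad token unless one was added

end_token_id = convert_tokens_to_ids(special_tokens['end_token'])
sep_token_id = convert_tokens_to_ids(special_tokens['sep_token'])
pad_token_id = convert_tokens_to_ids(pad_token)

if pad_token_id is None:
    logger.error('Your tokenizer does not have a padding token')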
@@ -394,7 +396,7 @@ def run_generation(args):
         input_column=args.input_column, gold_column=args.gold_column, prompt_column=args.prompt_column,
         copy=args.copy,
         thingtalk_column=args.thingtalk_column,
-        sep_token=special_tokens['sep_token'], skip_heuristics=args.skip_heuristics, is_cased=args.is_cased,
+        sep_token_id=sep_token_id, skip_heuristics=args.skip_heuristics, is_cased=args.is_cased,
         model_type=args.model_type)
 
 
@@ -404,9 +406,10 @@ def run_generation(args):
     all_outputs = []
 
     stop_token_ids = [tokenizer.convert_tokens_to_ids(stop_token) for stop_token in args.stop_tokens]
-    end_token_id = tokenizer.convert_tokens_to_ids(special_tokens['end_token'])
+
+    batch_idx = 0
     for batch in tqdm(range(math.ceil(len(all_context_tokens) / args.batch_size)), desc="Batch"):
         logging.info('') # to make kubectl properly print tqdm progress bar
         batch_slice = (batch*args.batch_size, min((batch+1)*args.batch_size, len(all_context_tokens)))
         batch_size = batch_slice[1] - batch_slice[0]
         batch_input_sequences = all_input_sequences[batch_slice[0]: batch_slice[1]]
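The loop walks the examples in fixed-size slices and lets the final batch be smaller. A quick runnable sketch of the slice arithmetic on toy data (the list contents and batch size are arbitrary):

import math

all_context_tokens = [[1], [2], [3], [4], [5]]  # toy data: 5 examples
batch_size_arg = 2                              # stands in for args.batch_size

for batch in range(math.ceil(len(all_context_tokens) / batch_size_arg)):
    # Half-open slice bounds; min() keeps the last batch from running past the end.
    batch_slice = (batch * batch_size_arg,
                   min((batch + 1) * batch_size_arg, len(all_context_tokens)))
    batch_size = batch_slice[1] - batch_slice[0]
    print(batch, batch_slice, batch_size)
# Prints: 0 (0, 2) 2 / 1 (2, 4) 2 / 2 (4, 5) 1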
@@ -473,6 +476,9 @@ def run_generation(args):
             batch_outputs[(i//args.num_samples) % batch_size].append(text)
 
         all_outputs.extend(batch_outputs)
+        if batch_idx < 1:
+            logger.info('First batch output: %s', str(all_outputs))
+        batch_idx += 1
 
 
     # sort the results back to their original order
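The batch_idx counter added in this hunk exists only so the decoded output of the first batch is logged once for inspection; later batches skip the call. A minimal sketch of that guard, with placeholder batch contents:

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

all_outputs = []
batch_idx = 0
for batch_outputs in (["first decoded text"], ["second"], ["third"]):  # placeholder batches
    all_outputs.extend(batch_outputs)
    if batch_idx < 1:  # log only once, right after the first batch
        logger.info('First batch output: %s', str(all_outputs))
    batch_idx += 1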