Added more logging

Author: Sina
Date:   2020-05-01 23:09:50 -07:00
parent 68ffaa1561
commit f886b6f409
2 changed files with 18 additions and 9 deletions

@@ -1,3 +1,4 @@
+import logging
 from typing import List
 from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer
 import torch
@@ -10,7 +11,9 @@ class GPT2Seq2Seq(GPT2LMHeadModel):
         self.end_token_id = end_token_id
         self.sep_token_id = sep_token_id
         self.pad_token_id = pad_token_id
+        logging.info('end_token_id = %s', self.end_token_id)
+        logging.info('sep_token_id = %s', self.sep_token_id)
+        logging.info('pad_token_id = %s', self.pad_token_id)

     def pad_to_max_length(self, input_sequences: List[List[int]]):
         """

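Note that the `logging.info` calls added to `__init__` go through the root logger, whose default level is WARNING, so they are only visible if the program sets the level to INFO or lower before the model is constructed. A minimal sketch of such a setup (the format string is an illustrative choice, not part of this commit):

```python
import logging

# Raise root-logger verbosity to INFO so the token-id messages logged in
# GPT2Seq2Seq.__init__ actually appear on stderr.
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')
```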
@@ -81,7 +81,7 @@ special_pattern_mapping = [
     SpecialTokenMap('GENERIC_ENTITY_uk.ac.cam.multiwoz.Restaurant:Restaurant_([0-9]+)', ["restaurant1", "restaurant2", "restaurant3"])  # TODO the only reason we can get away with this unnatural replacement is that the actual backward is not going to be called for this
 ]

-def create_features_from_tsv_file(file_path, tokenizer, input_column, gold_column, prompt_column, copy, thingtalk_column, sep_token,
+def create_features_from_tsv_file(file_path, tokenizer, input_column, gold_column, prompt_column, copy, thingtalk_column, sep_token_id,
                                   skip_heuristics, is_cased, model_type):
     """
     Read a tsv file (this can also be a plain text file with one example per line) and return the input features that the model needs
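For context, the docstring's "one example per line" convention amounts to tab-separated columns selected by index. A minimal sketch under that assumption (the file name and column indices are made up for illustration, not taken from this commit):

```python
# Hypothetical reader for the tsv layout described in the docstring above:
# one example per line, with tab-separated columns addressed by index.
with open('examples.tsv') as f:          # file name is illustrative
    for line in f:
        columns = line.rstrip('\n').split('\t')
        input_sequence = columns[0]      # assuming input_column == 0
        gold = columns[1]                # assuming gold_column == 1
```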
@@ -134,8 +134,8 @@ def create_features_from_tsv_file(file_path, tokenizer, input_column, gold_colum
         prompt_tokens = tokenizer.encode(prompt, add_special_tokens=False)
         if copy > 0:
             assert len(prompt_tokens) == 0
-            prompt_tokens = context_tokens[0 : min(copy, len(context_tokens)-1)] # -1 since we should not copy sep_token
-        context_tokens = input_sequence_tokens + [tokenizer.convert_tokens_to_ids(sep_token)] + prompt_tokens
+            prompt_tokens = input_sequence_tokens[0 : min(copy, len(input_sequence_tokens)-1)]
+        context_tokens = input_sequence_tokens + [sep_token_id] + prompt_tokens
         all_input_sequences.append(input_sequence)
         all_input_sequence_lengths.append(len(input_sequence_tokens))
         all_context_tokens.append(context_tokens)
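A worked example of the rewritten copy logic, using made-up token ids (including the `sep_token_id` value), showing how the copied prefix lands after the separator:

```python
# Illustrative values only; real ids come from the tokenizer.
input_sequence_tokens = [11, 22, 33, 44]
sep_token_id = 50257                      # assumed separator id
copy = 2

# Copy at most `copy` tokens, and never the whole sequence.
prompt_tokens = input_sequence_tokens[0 : min(copy, len(input_sequence_tokens) - 1)]
# prompt_tokens == [11, 22]

context_tokens = input_sequence_tokens + [sep_token_id] + prompt_tokens
# context_tokens == [11, 22, 33, 44, 50257, 11, 22]
```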
@@ -378,13 +378,15 @@ def run_generation(args):
     model.to(args.device)
     model.eval()

+    end_token_id = tokenizer.convert_tokens_to_ids(special_tokens['end_token'])
+    sep_token_id = tokenizer.convert_tokens_to_ids(special_tokens['sep_token'])
     pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
     if pad_token_id is None:
         logger.error('Your tokenizer does not have a padding token')

     if args.model_type == 'gpt2':
-        model.set_token_ids(end_token_id=tokenizer.convert_tokens_to_ids(special_tokens['end_token']),
-                            sep_token_id=tokenizer.convert_tokens_to_ids(special_tokens['sep_token']),
+        model.set_token_ids(end_token_id=end_token_id,
+                            sep_token_id=sep_token_id,
                             pad_token_id=pad_token_id)

     logger.info(args)
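This hunk hoists the two `convert_tokens_to_ids` lookups into local variables so the same ids can be reused below (in the `create_features_from_tsv_file` call and the generation loop) instead of being recomputed. For a single token string the Hugging Face tokenizer returns a plain int; a self-contained sketch (the special-token strings here are placeholders, not the ones this repo uses):

```python
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.add_tokens(['<sep>', '<end>'])                  # hypothetical special tokens
sep_token_id = tokenizer.convert_tokens_to_ids('<sep>')   # an int, usable directly
end_token_id = tokenizer.convert_tokens_to_ids('<end>')
```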
@@ -394,7 +396,7 @@ def run_generation(args):
         input_column=args.input_column, gold_column=args.gold_column, prompt_column=args.prompt_column,
         copy=args.copy,
         thingtalk_column=args.thingtalk_column,
-        sep_token=special_tokens['sep_token'], skip_heuristics=args.skip_heuristics, is_cased=args.is_cased,
+        sep_token_id=sep_token_id, skip_heuristics=args.skip_heuristics, is_cased=args.is_cased,
         model_type=args.model_type)
@@ -404,9 +406,10 @@ def run_generation(args):
     all_outputs = []

     stop_token_ids = [tokenizer.convert_tokens_to_ids(stop_token) for stop_token in args.stop_tokens]
-    end_token_id = tokenizer.convert_tokens_to_ids(special_tokens['end_token'])

+    batch_idx = 0
     for batch in tqdm(range(math.ceil(len(all_context_tokens) / args.batch_size)), desc="Batch"):
+        logging.info('') # to make kubectl properly print the tqdm progress bar
         batch_slice = (batch*args.batch_size, min((batch+1)*args.batch_size, len(all_context_tokens)))
         batch_size = batch_slice[1] - batch_slice[0]
         batch_input_sequences = all_input_sequences[batch_slice[0]: batch_slice[1]]
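The slicing arithmetic above clamps the final batch to the number of remaining examples. A small worked example with assumed sizes:

```python
import math

num_examples = 10    # stands in for len(all_context_tokens)
batch_size_arg = 4   # stands in for args.batch_size

for batch in range(math.ceil(num_examples / batch_size_arg)):
    batch_slice = (batch * batch_size_arg,
                   min((batch + 1) * batch_size_arg, num_examples))
    batch_size = batch_slice[1] - batch_slice[0]
    print(batch_slice, batch_size)
# (0, 4) 4
# (4, 8) 4
# (8, 10) 2  <- the last batch is clamped by the min(...)
```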
@@ -473,6 +476,9 @@ def run_generation(args):
             batch_outputs[(i//args.num_samples) % batch_size].append(text)

         all_outputs.extend(batch_outputs)
+        if batch_idx < 1:
+            logger.info('First batch output: %s', str(all_outputs))
+        batch_idx += 1

     # sort the results back to their original order
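The `batch_idx` counter added here exists only to log the first batch's outputs once, as a spot check that does not flood the logs. The same log-once pattern in a self-contained form (all names and data below are illustrative):

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

batches = [['a', 'b'], ['c', 'd'], ['e', 'f']]   # stand-in data

batch_idx = 0
for batch in batches:
    outputs = [s.upper() for s in batch]         # stands in for generation
    if batch_idx < 1:
        # Log only the first batch so the log stays readable.
        logger.info('First batch output: %s', outputs)
    batch_idx += 1
```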