From f886b6f4096bf8c6283e82ce3c48dda533c39e06 Mon Sep 17 00:00:00 2001
From: Sina
Date: Fri, 1 May 2020 23:09:50 -0700
Subject: [PATCH] Added more logging

---
 genienlp/GPT2Seq2Seq.py    |  5 ++++-
 genienlp/run_generation.py | 22 ++++++++++++++--------
 2 files changed, 18 insertions(+), 9 deletions(-)

diff --git a/genienlp/GPT2Seq2Seq.py b/genienlp/GPT2Seq2Seq.py
index b7a39592..b52b7bdd 100644
--- a/genienlp/GPT2Seq2Seq.py
+++ b/genienlp/GPT2Seq2Seq.py
@@ -1,3 +1,4 @@
+import logging
 from typing import List
 from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer
 import torch
@@ -10,7 +11,9 @@ class GPT2Seq2Seq(GPT2LMHeadModel):
         self.end_token_id = end_token_id
         self.sep_token_id = sep_token_id
         self.pad_token_id = pad_token_id
-
+        logging.info('end_token_id = %s', self.end_token_id)
+        logging.info('sep_token_id = %s', self.sep_token_id)
+        logging.info('pad_token_id = %s', self.pad_token_id)
 
     def pad_to_max_length(self, input_sequences: List[List[int]]):
         """
diff --git a/genienlp/run_generation.py b/genienlp/run_generation.py
index 2537c69b..8ea282a8 100644
--- a/genienlp/run_generation.py
+++ b/genienlp/run_generation.py
@@ -81,7 +81,7 @@ special_pattern_mapping = [
     SpecialTokenMap('GENERIC_ENTITY_uk.ac.cam.multiwoz.Restaurant:Restaurant_([0-9]+)', ["restaurant1", "restaurant2", "restaurant3"]) # TODO the only reason we can get away with this unnatural replacement is that actual backward is not going to be called for this
 ]
 
-def create_features_from_tsv_file(file_path, tokenizer, input_column, gold_column, prompt_column, copy, thingtalk_column, sep_token,
+def create_features_from_tsv_file(file_path, tokenizer, input_column, gold_column, prompt_column, copy, thingtalk_column, sep_token_id,
                                   skip_heuristics, is_cased, model_type):
     """
     Read a tsv file (this includes a text file with one example per line) and returns input features that the model needs
@@ -134,8 +134,8 @@ def create_features_from_tsv_file(file_path, tokenizer, input_column, gold_colum
         prompt_tokens = tokenizer.encode(prompt, add_special_tokens=False)
         if copy > 0:
             assert len(prompt_tokens) == 0
-            prompt_tokens = context_tokens[0 : min(copy, len(context_tokens)-1)] # -1 since we should not copy sep_token
-            context_tokens = input_sequence_tokens + [tokenizer.convert_tokens_to_ids(sep_token)] + prompt_tokens
+            prompt_tokens = input_sequence_tokens[0 : min(copy, len(input_sequence_tokens)-1)]
+            context_tokens = input_sequence_tokens + [sep_token_id] + prompt_tokens
         all_input_sequences.append(input_sequence)
         all_input_sequence_lengths.append(len(input_sequence_tokens))
         all_context_tokens.append(context_tokens)
@@ -378,13 +378,15 @@ def run_generation(args):
     model.to(args.device)
     model.eval()
 
+    end_token_id = tokenizer.convert_tokens_to_ids(special_tokens['end_token'])
+    sep_token_id = tokenizer.convert_tokens_to_ids(special_tokens['sep_token'])
     pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
     if pad_token_id is None:
         logger.error('Your tokenizer does not have a padding token')
 
     if args.model_type == 'gpt2':
-        model.set_token_ids(end_token_id=tokenizer.convert_tokens_to_ids(special_tokens['end_token']),
-                            sep_token_id=tokenizer.convert_tokens_to_ids(special_tokens['sep_token']),
+        model.set_token_ids(end_token_id=end_token_id,
+                            sep_token_id=sep_token_id,
                             pad_token_id=pad_token_id)
 
     logger.info(args)
@@ -394,7 +396,7 @@ def run_generation(args):
                                        input_column=args.input_column, gold_column=args.gold_column, prompt_column=args.prompt_column,
                                        copy=args.copy, thingtalk_column=args.thingtalk_column,
-                                       sep_token=special_tokens['sep_token'], skip_heuristics=args.skip_heuristics, is_cased=args.is_cased,
+                                       sep_token_id=sep_token_id, skip_heuristics=args.skip_heuristics, is_cased=args.is_cased,
                                        model_type=args.model_type)
@@ -404,9 +406,10 @@ def run_generation(args):
     all_outputs = []
 
     stop_token_ids = [tokenizer.convert_tokens_to_ids(stop_token) for stop_token in args.stop_tokens]
-    end_token_id = tokenizer.convert_tokens_to_ids(special_tokens['end_token'])
-
+
+    batch_idx = 0
     for batch in tqdm(range(math.ceil(len(all_context_tokens) / args.batch_size)), desc="Batch"):
+        logging.info('')  # to make kubectl properly print tqdm progress bar
         batch_slice = (batch*args.batch_size, min((batch+1)*args.batch_size, len(all_context_tokens)))
         batch_size = batch_slice[1] - batch_slice[0]
         batch_input_sequences = all_input_sequences[batch_slice[0]: batch_slice[1]]
@@ -473,6 +476,9 @@ def run_generation(args):
                 batch_outputs[(i//args.num_samples) % batch_size].append(text)
 
         all_outputs.extend(batch_outputs)
+        if batch_idx < 1:
+            logger.info('First batch output: %s', str(all_outputs))
+        batch_idx += 1
 
     # sort the results back to their original order
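
Note (illustration, not part of the patch): a minimal sketch of how the updated copy/prompt logic in create_features_from_tsv_file assembles context_tokens after this change, assuming plain integer token ids. The SEP_ID value and the sample ids below are made-up placeholders, not real GPT-2 vocabulary ids.

    # Sketch of the post-patch behavior: copy up to `copy` tokens from the start
    # of the input (never the whole input) and append them after the separator id.
    from typing import List

    SEP_ID = 50258  # hypothetical separator token id, for illustration only


    def build_context_tokens(input_sequence_tokens: List[int], copy: int) -> List[int]:
        prompt_tokens: List[int] = []
        if copy > 0:
            # at most len(input) - 1 tokens, mirroring the slice in the patch
            prompt_tokens = input_sequence_tokens[0: min(copy, len(input_sequence_tokens) - 1)]
        return input_sequence_tokens + [SEP_ID] + prompt_tokens


    if __name__ == '__main__':
        print(build_context_tokens([11, 12, 13, 14], copy=2))
        # -> [11, 12, 13, 14, 50258, 11, 12]

Before the patch, the prompt was sliced from context_tokens and the separator was looked up via tokenizer.convert_tokens_to_ids(sep_token) at every call; afterwards the id is computed once in run_generation and passed in as sep_token_id.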