diff --git a/genienlp/GPT2seq2seq.py b/genienlp/GPT2seq2seq.py
index 52e84b2b..d87c85b2 100644
--- a/genienlp/GPT2seq2seq.py
+++ b/genienlp/GPT2seq2seq.py
@@ -1,5 +1,4 @@
 from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer
-from torch.nn import CrossEntropyLoss
 import torch
 
 class GPT2Seq2Seq(GPT2LMHeadModel):
diff --git a/genienlp/run_generation.py b/genienlp/run_generation.py
index 565d0c04..cae37d03 100644
--- a/genienlp/run_generation.py
+++ b/genienlp/run_generation.py
@@ -15,7 +15,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" Conditional text generation with the auto-regressive models of the library (GPT/GPT-2/CTRL/Transformer-XL/XLNet)
+""" Conditional text generation with GPT-2/BART
 """
 from __future__ import absolute_import, division, print_function, unicode_literals
 
@@ -40,16 +40,10 @@ except RuntimeError:
 import torch
 import torch.nn.functional as F
 
-from transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig, CTRLConfig, BertConfig
+from transformers import GPT2Config, BartConfig
 from transformers import GPT2LMHeadModel, GPT2Tokenizer
-from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
-from transformers import XLNetLMHeadModel, XLNetTokenizer
-from transformers import TransfoXLLMHeadModel, TransfoXLTokenizer
-from transformers import CTRLLMHeadModel, CTRLTokenizer
-from transformers import XLMWithLMHeadModel, XLMTokenizer
-from transformers import BertForMaskedLM, BertTokenizer
-
+from transformers import BartForConditionalGeneration, BartTokenizer
 from .util import set_seed, get_number_of_lines, combine_files_on_disk, split_file_on_disk, get_part_path, detokenize, tokenize, lower_case, \
     top_k_top_p_filtering, SpecialTokenMap, remove_thingtalk_quotes
 from .metrics import computeBLEU
@@ -63,16 +57,11 @@ logger = logging.getLogger(__name__)
 
 MAX_LENGTH = int(10000)  # Hardcoded max length to avoid infinite loop
 
-ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig, CTRLConfig, BertConfig)), ())
+ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, BartConfig)), ())
 
 MODEL_CLASSES = {
     'gpt2': (GPT2LMHeadModel, GPT2Tokenizer),
-    'ctrl': (CTRLLMHeadModel, CTRLTokenizer),
-    'openai-gpt': (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
-    'xlnet': (XLNetLMHeadModel, XLNetTokenizer),
-    'transfo-xl': (TransfoXLLMHeadModel, TransfoXLTokenizer),
-    'xlm': (XLMWithLMHeadModel, XLMTokenizer),
-    'bert': (BertForMaskedLM, BertTokenizer),
+    'bart': (BartForConditionalGeneration, BartTokenizer)
 }
 
@@ -103,8 +92,7 @@ def apply_repetition_penalty(logits, context, repetition_penalty, prompt_token_i
 
 
 def sample_sequence(model, length, min_output_length, context, num_samples,
-                    temperature=1.0, top_k=0, top_p=1.0, repetition_penalty=1.0,
-                    is_xlnet=False, is_xlm_mlm=False, xlm_mask_token=None, xlm_lang=None, device='cpu',
+                    temperature=1.0, top_k=0, top_p=1.0, repetition_penalty=1.0, device='cpu',
                     stop_token_ids=None, pad_token_id=None, supports_past=False, prompt_token_id=None,
                     segment_token_ids=None, start_reverse_position_ids=None, output_form=None):
     """
@@ -158,30 +146,8 @@ def sample_sequence(model, length, min_output_length, context, num_samples,
     past = None
     next_token = None
     with torch.no_grad():
-        # rep_penalty = np.random.random(length) < 0.1
-        # original_rep_penalty = repetition_penalty
-        # logger.info('rep_penalty = ', rep_penalty)
         for _ in range(length):
             inputs = {'input_ids': generated, 'position_ids': position_ids[:, :next_index], 'token_type_ids': segment_ids[:, :next_index]}
-            if is_xlnet:
-                # XLNet is a direct (predict same token, not next token) and bi-directional model by default
-                # => need one additional dummy token in the input (will be masked), attention mask and target mapping (see model docstring)
-                input_ids = torch.cat((generated, torch.zeros((1, 1), dtype=torch.long, device=device)), dim=1)
-                perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float, device=device)
-                perm_mask[:, :, -1] = 1.0  # Previous tokens don't see last token
-                target_mapping = torch.zeros((1, 1, input_ids.shape[1]), dtype=torch.float, device=device)
-                target_mapping[0, 0, -1] = 1.0  # predict last token
-                inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping}
-
-            if is_xlm_mlm and xlm_mask_token:
-                # XLM MLM models are direct models (predict same token, not next token)
-                # => need one additional dummy token in the input (will be masked and guessed)
-                input_ids = torch.cat((generated, torch.full((1, 1), xlm_mask_token, dtype=torch.long, device=device)), dim=1)
-                inputs = {'input_ids': input_ids}
-
-            if xlm_lang is not None:
-                inputs["langs"] = torch.tensor([xlm_lang] * inputs["input_ids"].shape[1], device=device).view(1, -1)
-
             if supports_past:
                 inputs['past'] = past
                 if past is not None:
@@ -461,7 +427,6 @@ def parse_argv(parser):
     parser.add_argument('--thingtalk_column', type=int, default=None, help='The column in the input file which contains the ThingTalk program.')
     parser.add_argument("--output_file", type=str, help="When specified, generated text will be written in this file. Defaults to stdout.")
-    parser.add_argument("--xlm_lang", type=str, default="", help="Optional language when used with the XLM model.")
     parser.add_argument("--length", type=int, default=20, help='The generated sentences will have a maximum length of len(input) + arg.length')
     parser.add_argument("--min_output_length", type=int, default=1, help='Will prevent stop tokens from appearing in the first --min_length tokens of the generated sentences.')
     parser.add_argument("--skip_heuristics", action='store_true', help='If True, will not replace special word such as NUMBER_0 in the input.')
@@ -566,28 +531,7 @@ def run_generation(args):
         args.length = MAX_LENGTH  # avoid infinite loop
     logger.info(args)
 
-    if args.model_type in ["ctrl"]:
-        if args.temperature > 0.7:
-            logger.info('CTRL typically works better with lower temperatures (and lower top_k).')
-
-    xlm_lang = None
-    # XLM Language usage detailed in the issues #1414
-    if args.model_type in ["xlm"] and hasattr(tokenizer, 'lang2id') and hasattr(model.config, 'use_lang_emb') \
-            and model.config.use_lang_emb:
-        if args.xlm_lang:
-            language = args.xlm_lang
-        else:
-            language = None
-            while language not in tokenizer.lang2id.keys():
-                language = input("Using XLM. Select language in " + str(list(tokenizer.lang2id.keys())) + " >>> ")
-        xlm_lang = tokenizer.lang2id[language]
-
-    # XLM masked-language modeling (MLM) models need masked token (see details in sample_sequence)
-    is_xlm_mlm = args.model_type in ["xlm"] and 'mlm' in args.model_name_or_path
-    if is_xlm_mlm:
-        xlm_mask_token = tokenizer.mask_token_id
-    else:
-        xlm_mask_token = None
-
     pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
     prompt_token_id = tokenizer.convert_tokens_to_ids(args.prompt_token)
@@ -630,14 +574,10 @@ def run_generation(args):
                 top_k=args.top_k[hyperparameter_idx],
                 top_p=args.top_p[hyperparameter_idx],
                 repetition_penalty=args.repetition_penalty[hyperparameter_idx],
-                is_xlnet=bool(args.model_type == "xlnet"),
-                is_xlm_mlm=is_xlm_mlm,
-                xlm_mask_token=xlm_mask_token,
-                xlm_lang=xlm_lang,
                 device=args.device,
                 stop_token_ids=stop_token_ids,
                 pad_token_id=pad_token_id,
-                supports_past=args.model_type in ['gpt2', 'openai-gpt', 'transfo-xl', 'xlnet', 'ctrl'],
+                supports_past=args.model_type in ['gpt2'],
                 prompt_token_id=prompt_token_id,
                 segment_token_ids=[tokenizer.convert_tokens_to_ids(args.prompt_token), tokenizer.convert_tokens_to_ids(args.stop_tokens[0])] if args.model_type=='gpt2' else [0, 1],
                 start_reverse_position_ids=args.start_reverse_position_ids[hyperparameter_idx],
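Usage note (not part of the patch): with the table reduced to GPT-2 and BART, model and tokenizer classes are still looked up by model type and instantiated from a checkpoint name. A minimal sketch of how such a MODEL_CLASSES table is typically consumed, assuming the script keeps reading these values from --model_type and --model_name_or_path; the checkpoint name 'facebook/bart-large' below is only illustrative and not taken from the diff.

# Sketch only: resolve model/tokenizer classes from the trimmed MODEL_CLASSES table.
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from transformers import BartForConditionalGeneration, BartTokenizer

MODEL_CLASSES = {
    'gpt2': (GPT2LMHeadModel, GPT2Tokenizer),
    'bart': (BartForConditionalGeneration, BartTokenizer),
}

model_type = 'bart'                           # would come from args.model_type
model_name_or_path = 'facebook/bart-large'    # illustrative checkpoint; would come from args.model_name_or_path

model_class, tokenizer_class = MODEL_CLASSES[model_type]
tokenizer = tokenizer_class.from_pretrained(model_name_or_path)
model = model_class.from_pretrained(model_name_or_path)
model.eval()  # generation only, no gradient updates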