removed unused models in generation scripts

Sina 2020-04-26 00:48:52 -07:00
parent af8d097485
commit 03e09eddc5
2 changed files with 7 additions and 68 deletions

View File

@@ -1,5 +1,4 @@
 from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer
-from torch.nn import CrossEntropyLoss
 import torch
 
 class GPT2Seq2Seq(GPT2LMHeadModel):
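For orientation, GPT2Seq2Seq adapts the decoder-only GPT2LMHeadModel for sequence-to-sequence use by treating input, prompt token, and output as one token stream. A minimal sketch of the subclassing pattern, with a hypothetical left-padding helper (the method name and body are illustrative, not this file's code):

    import torch
    from transformers import GPT2LMHeadModel

    class GPT2Seq2SeqSketch(GPT2LMHeadModel):
        # Hypothetical helper: left-pad a batch so every prompt ends at the
        # same index, keeping the next-token position aligned across rows.
        def pad_batch(self, sequences, pad_token_id):
            max_len = max(len(s) for s in sequences)
            return torch.tensor([[pad_token_id] * (max_len - len(s)) + s
                                 for s in sequences])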

View File

@@ -15,7 +15,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" Conditional text generation with the auto-regressive models of the library (GPT/GPT-2/CTRL/Transformer-XL/XLNet)
+""" Conditional text generation with GPT-2/BART
 """
 from __future__ import absolute_import, division, print_function, unicode_literals
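The rewritten docstring narrows the script to GPT-2 and BART. For reference, conditional generation with BART through the transformers generate API looks roughly like this (checkpoint name and decoding settings are placeholders):

    from transformers import BartForConditionalGeneration, BartTokenizer

    tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
    model = BartForConditionalGeneration.from_pretrained('facebook/bart-large')

    input_ids = tokenizer.encode('show me restaurants nearby', return_tensors='pt')
    # Beam search over the decoder; the encoder sees the full input once.
    output_ids = model.generate(input_ids, max_length=40, num_beams=4, early_stopping=True)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))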
@@ -40,16 +40,10 @@ except RuntimeError:
 import torch
 import torch.nn.functional as F
-from transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig, CTRLConfig, BertConfig
+from transformers import GPT2Config, BartConfig
 from transformers import GPT2LMHeadModel, GPT2Tokenizer
-from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
-from transformers import XLNetLMHeadModel, XLNetTokenizer
-from transformers import TransfoXLLMHeadModel, TransfoXLTokenizer
-from transformers import CTRLLMHeadModel, CTRLTokenizer
-from transformers import XLMWithLMHeadModel, XLMTokenizer
-from transformers import BertForMaskedLM, BertTokenizer
 from transformers import BartForConditionalGeneration, BartTokenizer
 from .util import set_seed, get_number_of_lines, combine_files_on_disk, split_file_on_disk, get_part_path, detokenize, tokenize, lower_case, \
     top_k_top_p_filtering, SpecialTokenMap, remove_thingtalk_quotes
 from .metrics import computeBLEU
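The top_k_top_p_filtering helper imported from .util conventionally implements the top-k/nucleus filtering of Holtzman et al. (2019). A single-sequence sketch of that technique (not necessarily this repository's exact implementation):

    import torch
    import torch.nn.functional as F

    def top_k_top_p_filtering_sketch(logits, top_k=0, top_p=1.0, filter_value=-float('inf')):
        # logits: 1-D tensor of next-token scores over the vocabulary.
        if top_k > 0:
            # Keep only the k highest-scoring tokens.
            kth_best = torch.topk(logits, top_k)[0][..., -1, None]
            logits[logits < kth_best] = filter_value
        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            # Drop tokens once cumulative probability exceeds top_p, shifting the
            # mask right so the single most probable token always survives.
            sorted_mask = cumulative_probs > top_p
            sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
            sorted_mask[..., 0] = False
            logits[sorted_indices[sorted_mask]] = filter_value
        return logits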
@@ -63,16 +57,11 @@ logger = logging.getLogger(__name__)
 MAX_LENGTH = int(10000)  # Hardcoded max length to avoid infinite loop
-ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig, CTRLConfig, BertConfig)), ())
+ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, BartConfig)), ())
 MODEL_CLASSES = {
     'gpt2': (GPT2LMHeadModel, GPT2Tokenizer),
-    'ctrl': (CTRLLMHeadModel, CTRLTokenizer),
-    'openai-gpt': (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
-    'xlnet': (XLNetLMHeadModel, XLNetTokenizer),
-    'transfo-xl': (TransfoXLLMHeadModel, TransfoXLTokenizer),
-    'xlm': (XLMWithLMHeadModel, XLMTokenizer),
-    'bert': (BertForMaskedLM, BertTokenizer),
     'bart': (BartForConditionalGeneration, BartTokenizer)
 }
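MODEL_CLASSES follows the usual transformers example-script pattern: the --model_type flag selects a (model, tokenizer) pair, which is then instantiated from a checkpoint. A sketch of that lookup (checkpoint names are placeholders):

    from transformers import (GPT2LMHeadModel, GPT2Tokenizer,
                              BartForConditionalGeneration, BartTokenizer)

    MODEL_CLASSES = {
        'gpt2': (GPT2LMHeadModel, GPT2Tokenizer),
        'bart': (BartForConditionalGeneration, BartTokenizer),
    }

    model_class, tokenizer_class = MODEL_CLASSES['gpt2']
    tokenizer = tokenizer_class.from_pretrained('gpt2')
    model = model_class.from_pretrained('gpt2')
    model.eval()  # inference only: disable dropout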
@@ -103,8 +92,7 @@ def apply_repetition_penalty(logits, context, repetition_penalty, prompt_token_id
 def sample_sequence(model, length, min_output_length, context, num_samples,
-                    temperature=1.0, top_k=0, top_p=1.0, repetition_penalty=1.0,
-                    is_xlnet=False, is_xlm_mlm=False, xlm_mask_token=None, xlm_lang=None, device='cpu',
+                    temperature=1.0, top_k=0, top_p=1.0, repetition_penalty=1.0, device='cpu',
                     stop_token_ids=None, pad_token_id=None, supports_past=False, prompt_token_id=None, segment_token_ids=None,
                     start_reverse_position_ids=None, output_form=None):
     """
@@ -158,30 +146,8 @@ def sample_sequence(model, length, min_output_length, context, num_samples,
     past = None
     next_token = None
     with torch.no_grad():
-        # rep_penalty = np.random.random(length) < 0.1
-        # original_rep_penalty = repetition_penalty
-        # logger.info('rep_penalty = ', rep_penalty)
         for _ in range(length):
             inputs = {'input_ids': generated, 'position_ids': position_ids[:, :next_index], 'token_type_ids': segment_ids[:, :next_index]}
-            if is_xlnet:
-                # XLNet is a direct (predict same token, not next token) and bi-directional model by default
-                # => need one additional dummy token in the input (will be masked), attention mask and target mapping (see model docstring)
-                input_ids = torch.cat((generated, torch.zeros((1, 1), dtype=torch.long, device=device)), dim=1)
-                perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float, device=device)
-                perm_mask[:, :, -1] = 1.0  # Previous tokens don't see last token
-                target_mapping = torch.zeros((1, 1, input_ids.shape[1]), dtype=torch.float, device=device)
-                target_mapping[0, 0, -1] = 1.0  # predict last token
-                inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping}
-            if is_xlm_mlm and xlm_mask_token:
-                # XLM MLM models are direct models (predict same token, not next token)
-                # => need one additional dummy token in the input (will be masked and guessed)
-                input_ids = torch.cat((generated, torch.full((1, 1), xlm_mask_token, dtype=torch.long, device=device)), dim=1)
-                inputs = {'input_ids': input_ids}
-                if xlm_lang is not None:
-                    inputs["langs"] = torch.tensor([xlm_lang] * inputs["input_ids"].shape[1], device=device).view(1, -1)
             if supports_past:
                 inputs['past'] = past
                 if past is not None:
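supports_past gates the transformers 2.x key/value cache: after the first step only the newest token is fed, and attention states are reused through past. A standalone greedy-decoding sketch of that pattern, using the 2.x tuple-return call convention this script targets:

    import torch
    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    model = GPT2LMHeadModel.from_pretrained('gpt2').eval()

    generated = tokenizer.encode('The weather today is', return_tensors='pt')
    past = None
    with torch.no_grad():
        for _ in range(10):
            # First step: the full prompt. Later steps: only the last token.
            step_input = generated if past is None else generated[:, -1:]
            logits, past = model(step_input, past=past)[:2]
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            generated = torch.cat([generated, next_token], dim=1)
    print(tokenizer.decode(generated[0]))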
@@ -461,7 +427,6 @@ def parse_argv(parser):
     parser.add_argument('--thingtalk_column', type=int, default=None,
                         help='The column in the input file which contains the ThingTalk program.')
     parser.add_argument("--output_file", type=str, help="When specified, generated text will be written in this file. Defaults to stdout.")
-    parser.add_argument("--xlm_lang", type=str, default="", help="Optional language when used with the XLM model.")
     parser.add_argument("--length", type=int, default=20, help='The generated sentences will have a maximum length of len(input) + arg.length')
     parser.add_argument("--min_output_length", type=int, default=1, help='Will prevent stop tokens from appearing in the first --min_length tokens of the generated sentences.')
     parser.add_argument("--skip_heuristics", action='store_true', help='If True, will not replace special word such as NUMBER_0 in the input.')
@@ -566,28 +531,7 @@ def run_generation(args):
         args.length = MAX_LENGTH  # avoid infinite loop
     logger.info(args)
-    if args.model_type in ["ctrl"]:
-        if args.temperature > 0.7:
-            logger.info('CTRL typically works better with lower temperatures (and lower top_k).')
-    xlm_lang = None
-    # XLM Language usage detailed in the issues #1414
-    if args.model_type in ["xlm"] and hasattr(tokenizer, 'lang2id') and hasattr(model.config, 'use_lang_emb') \
-            and model.config.use_lang_emb:
-        if args.xlm_lang:
-            language = args.xlm_lang
-        else:
-            language = None
-            while language not in tokenizer.lang2id.keys():
-                language = input("Using XLM. Select language in " + str(list(tokenizer.lang2id.keys())) + " >>> ")
-        xlm_lang = tokenizer.lang2id[language]
-    # XLM masked-language modeling (MLM) models need masked token (see details in sample_sequence)
-    is_xlm_mlm = args.model_type in ["xlm"] and 'mlm' in args.model_name_or_path
-    if is_xlm_mlm:
-        xlm_mask_token = tokenizer.mask_token_id
-    else:
-        xlm_mask_token = None
     pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
     prompt_token_id = tokenizer.convert_tokens_to_ids(args.prompt_token)
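Both ids above come from convert_tokens_to_ids. GPT-2's tokenizer ships without a pad token, so scripts in this style register their special tokens first; a sketch (the token strings are placeholders, not the real values of args.prompt_token):

    from transformers import GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    tokenizer.add_special_tokens({'pad_token': '<pad>',
                                  'additional_special_tokens': ['<paraphrase>']})
    pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
    prompt_token_id = tokenizer.convert_tokens_to_ids('<paraphrase>')
    # A model using these would also need model.resize_token_embeddings(len(tokenizer)).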
@@ -630,14 +574,10 @@
                 top_k=args.top_k[hyperparameter_idx],
                 top_p=args.top_p[hyperparameter_idx],
                 repetition_penalty=args.repetition_penalty[hyperparameter_idx],
-                is_xlnet=bool(args.model_type == "xlnet"),
-                is_xlm_mlm=is_xlm_mlm,
-                xlm_mask_token=xlm_mask_token,
-                xlm_lang=xlm_lang,
                 device=args.device,
                 stop_token_ids=stop_token_ids,
                 pad_token_id=pad_token_id,
-                supports_past=args.model_type in ['gpt2', 'openai-gpt', 'transfo-xl', 'xlnet', 'ctrl'],
+                supports_past=args.model_type in ['gpt2'],
                 prompt_token_id=prompt_token_id,
                 segment_token_ids=[tokenizer.convert_tokens_to_ids(args.prompt_token), tokenizer.convert_tokens_to_ids(args.stop_tokens[0])] if args.model_type=='gpt2' else [0, 1],
                 start_reverse_position_ids=args.start_reverse_position_ids[hyperparameter_idx],
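For GPT-2 the two segment ids are the prompt-token and stop-token ids themselves, passed as token_type_ids to mark which positions belong to the input and which to the output; for BART the plain 0/1 pair suffices because the encoder/decoder split already separates the two. A sketch of building such segment ids (all ids are made up):

    import torch

    prompt_segment_id, output_segment_id = 50257, 50258  # hypothetical token ids
    input_len, output_len = 5, 3
    token_type_ids = torch.tensor(
        [[prompt_segment_id] * input_len + [output_segment_id] * output_len])
    # Passed alongside input_ids, e.g. model(input_ids, token_type_ids=token_type_ids)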