removed unused models in generation scripts
parent af8d097485
commit 03e09eddc5
@@ -1,5 +1,4 @@
 from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer
-from torch.nn import CrossEntropyLoss
 import torch

 class GPT2Seq2Seq(GPT2LMHeadModel):
@@ -15,7 +15,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" Conditional text generation with the auto-regressive models of the library (GPT/GPT-2/CTRL/Transformer-XL/XLNet)
+""" Conditional text generation with GPT-2/BART
 """
 from __future__ import absolute_import, division, print_function, unicode_literals

@@ -40,16 +40,10 @@ except RuntimeError:
 import torch
 import torch.nn.functional as F

-from transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig, CTRLConfig, BertConfig
+from transformers import GPT2Config, BartConfig
 from transformers import GPT2LMHeadModel, GPT2Tokenizer
-from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
-from transformers import XLNetLMHeadModel, XLNetTokenizer
-from transformers import TransfoXLLMHeadModel, TransfoXLTokenizer
-from transformers import CTRLLMHeadModel, CTRLTokenizer
-from transformers import XLMWithLMHeadModel, XLMTokenizer
-from transformers import BertForMaskedLM, BertTokenizer

 from transformers import BartForConditionalGeneration, BartTokenizer
 from .util import set_seed, get_number_of_lines, combine_files_on_disk, split_file_on_disk, get_part_path, detokenize, tokenize, lower_case, \
     top_k_top_p_filtering, SpecialTokenMap, remove_thingtalk_quotes
 from .metrics import computeBLEU
@@ -63,16 +57,11 @@ logger = logging.getLogger(__name__)

 MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop

-ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig, CTRLConfig, BertConfig)), ())
+ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, BartConfig)), ())

 MODEL_CLASSES = {
     'gpt2': (GPT2LMHeadModel, GPT2Tokenizer),
-    'ctrl': (CTRLLMHeadModel, CTRLTokenizer),
-    'openai-gpt': (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
-    'xlnet': (XLNetLMHeadModel, XLNetTokenizer),
-    'transfo-xl': (TransfoXLLMHeadModel, TransfoXLTokenizer),
-    'xlm': (XLMWithLMHeadModel, XLMTokenizer),
-    'bert': (BertForMaskedLM, BertTokenizer),
+    'bart': (BartForConditionalGeneration, BartTokenizer)
 }

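With the registry pared down, loading dispatches between just the two remaining model families. For context, here is a minimal sketch of how a `(model_class, tokenizer_class)` table like `MODEL_CLASSES` is typically consumed; the actual loading code sits outside this diff, and `load_model` is a hypothetical helper name:

```python
from transformers import (GPT2LMHeadModel, GPT2Tokenizer,
                          BartForConditionalGeneration, BartTokenizer)

# Mirror of the trimmed registry above.
MODEL_CLASSES = {
    'gpt2': (GPT2LMHeadModel, GPT2Tokenizer),
    'bart': (BartForConditionalGeneration, BartTokenizer),
}

def load_model(model_type, model_name_or_path):
    # Look up the pair for the requested family, then load both halves
    # from the same pretrained checkpoint name or path.
    model_class, tokenizer_class = MODEL_CLASSES[model_type]
    tokenizer = tokenizer_class.from_pretrained(model_name_or_path)
    model = model_class.from_pretrained(model_name_or_path)
    return model, tokenizer

# e.g.:
# model, tokenizer = load_model('gpt2', 'gpt2')
# model, tokenizer = load_model('bart', 'facebook/bart-large')
```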
@@ -103,8 +92,7 @@ def apply_repetition_penalty(logits, context, repetition_penalty, prompt_token_i


 def sample_sequence(model, length, min_output_length, context, num_samples,
-                    temperature=1.0, top_k=0, top_p=1.0, repetition_penalty=1.0,
-                    is_xlnet=False, is_xlm_mlm=False, xlm_mask_token=None, xlm_lang=None, device='cpu',
+                    temperature=1.0, top_k=0, top_p=1.0, repetition_penalty=1.0, device='cpu',
                     stop_token_ids=None, pad_token_id=None, supports_past=False, prompt_token_id=None, segment_token_ids=None,
                     start_reverse_position_ids=None, output_form=None):
     """
@@ -158,30 +146,8 @@ def sample_sequence(model, length, min_output_length, context, num_samples,
     past = None
     next_token = None
     with torch.no_grad():
-        # rep_penalty = np.random.random(length) < 0.1
-        # original_rep_penalty = repetition_penalty
-        # logger.info('rep_penalty = ', rep_penalty)
         for _ in range(length):
             inputs = {'input_ids': generated, 'position_ids': position_ids[:, :next_index], 'token_type_ids': segment_ids[:, :next_index]}
-            if is_xlnet:
-                # XLNet is a direct (predict same token, not next token) and bi-directional model by default
-                # => need one additional dummy token in the input (will be masked), attention mask and target mapping (see model docstring)
-                input_ids = torch.cat((generated, torch.zeros((1, 1), dtype=torch.long, device=device)), dim=1)
-                perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float, device=device)
-                perm_mask[:, :, -1] = 1.0  # Previous tokens don't see last token
-                target_mapping = torch.zeros((1, 1, input_ids.shape[1]), dtype=torch.float, device=device)
-                target_mapping[0, 0, -1] = 1.0  # predict last token
-                inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping}
-
-            if is_xlm_mlm and xlm_mask_token:
-                # XLM MLM models are direct models (predict same token, not next token)
-                # => need one additional dummy token in the input (will be masked and guessed)
-                input_ids = torch.cat((generated, torch.full((1, 1), xlm_mask_token, dtype=torch.long, device=device)), dim=1)
-                inputs = {'input_ids': input_ids}
-
-            if xlm_lang is not None:
-                inputs["langs"] = torch.tensor([xlm_lang] * inputs["input_ids"].shape[1], device=device).view(1, -1)
-
             if supports_past:
                 inputs['past'] = past
                 if past is not None:
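The deleted branches above were the only XLNet/XLM-specific input builders; what survives is GPT-2's key/value cache, gated by `supports_past`. For context, a self-contained sketch of that incremental-decoding pattern, written against the tuple-style return of the transformers 2.x era this script targets (greedy decoding and the prompt are illustrative only; the real loop adds sampling, segment ids, and position bookkeeping):

```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2').eval()

generated = tokenizer.encode('The quick brown', return_tensors='pt')  # hypothetical prompt
past = None
with torch.no_grad():
    for _ in range(20):
        # First step feeds the whole prompt; every later step feeds only the
        # newest token, because the cache already holds keys/values for the rest.
        input_ids = generated if past is None else generated[:, -1:]
        logits, past = model(input_ids, past=past)  # transformers 2.x returned a (logits, past) tuple
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # greedy for simplicity
        generated = torch.cat((generated, next_token), dim=1)

print(tokenizer.decode(generated[0]))
```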
@@ -461,7 +427,6 @@ def parse_argv(parser):
     parser.add_argument('--thingtalk_column', type=int, default=None,
                         help='The column in the input file which contains the ThingTalk program.')
     parser.add_argument("--output_file", type=str, help="When specified, generated text will be written in this file. Defaults to stdout.")
-    parser.add_argument("--xlm_lang", type=str, default="", help="Optional language when used with the XLM model.")
     parser.add_argument("--length", type=int, default=20, help='The generated sentences will have a maximum length of len(input) + arg.length')
     parser.add_argument("--min_output_length", type=int, default=1, help='Will prevent stop tokens from appearing in the first --min_length tokens of the generated sentences.')
     parser.add_argument("--skip_heuristics", action='store_true', help='If True, will not replace special word such as NUMBER_0 in the input.')
@@ -566,28 +531,7 @@ def run_generation(args):
         args.length = MAX_LENGTH  # avoid infinite loop

     logger.info(args)
-    if args.model_type in ["ctrl"]:
-        if args.temperature > 0.7:
-            logger.info('CTRL typically works better with lower temperatures (and lower top_k).')

-    xlm_lang = None
-    # XLM Language usage detailed in the issues #1414
-    if args.model_type in ["xlm"] and hasattr(tokenizer, 'lang2id') and hasattr(model.config, 'use_lang_emb') \
-            and model.config.use_lang_emb:
-        if args.xlm_lang:
-            language = args.xlm_lang
-        else:
-            language = None
-            while language not in tokenizer.lang2id.keys():
-                language = input("Using XLM. Select language in " + str(list(tokenizer.lang2id.keys())) + " >>> ")
-        xlm_lang = tokenizer.lang2id[language]
-
-    # XLM masked-language modeling (MLM) models need masked token (see details in sample_sequence)
-    is_xlm_mlm = args.model_type in ["xlm"] and 'mlm' in args.model_name_or_path
-    if is_xlm_mlm:
-        xlm_mask_token = tokenizer.mask_token_id
-    else:
-        xlm_mask_token = None

     pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
     prompt_token_id = tokenizer.convert_tokens_to_ids(args.prompt_token)
@@ -630,14 +574,10 @@ def run_generation(args):
             top_k=args.top_k[hyperparameter_idx],
             top_p=args.top_p[hyperparameter_idx],
             repetition_penalty=args.repetition_penalty[hyperparameter_idx],
-            is_xlnet=bool(args.model_type == "xlnet"),
-            is_xlm_mlm=is_xlm_mlm,
-            xlm_mask_token=xlm_mask_token,
-            xlm_lang=xlm_lang,
             device=args.device,
             stop_token_ids=stop_token_ids,
             pad_token_id=pad_token_id,
-            supports_past=args.model_type in ['gpt2', 'openai-gpt', 'transfo-xl', 'xlnet', 'ctrl'],
+            supports_past=args.model_type in ['gpt2'],
             prompt_token_id=prompt_token_id,
             segment_token_ids=[tokenizer.convert_tokens_to_ids(args.prompt_token), tokenizer.convert_tokens_to_ids(args.stop_tokens[0])] if args.model_type=='gpt2' else [0, 1],
             start_reverse_position_ids=args.start_reverse_position_ids[hyperparameter_idx],
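The sampling knobs that survive the cleanup (`temperature`, `top_k`, `top_p`, `repetition_penalty`) are the standard ones behind the imported `top_k_top_p_filtering` helper. For reference, a self-contained sketch of top-k/top-p (nucleus) filtering in the usual style; this is the standard implementation, not a copy of the one in `.util`:

```python
import torch
import torch.nn.functional as F

def top_k_top_p_filter(logits, top_k=0, top_p=1.0, filter_value=-float('inf')):
    """Filter a batch of next-token logits in place and return them."""
    if top_k > 0:
        # Mask everything below the k-th largest logit.
        kth_best = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_best] = filter_value
    if top_p < 1.0:
        # Keep the smallest set of tokens whose cumulative probability exceeds top_p.
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_mask = cumulative_probs > top_p
        sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()  # shift right: always keep the top token
        sorted_mask[..., 0] = False
        mask = sorted_mask.scatter(-1, sorted_indices, sorted_mask)  # unsort back to vocab order
        logits[mask] = filter_value
    return logits

# Typical use inside a sampling loop:
# logits = logits[:, -1, :] / temperature
# probs = F.softmax(top_k_top_p_filter(logits, top_k=0, top_p=0.9), dim=-1)
# next_token = torch.multinomial(probs, num_samples=1)
```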