address mask_id issue for XLM-R

mehrad 2020-03-24 16:05:01 -07:00
parent 1e2dbce017
commit 0acf64bdc3
1 changed file with 7 additions and 1 deletion


@@ -91,19 +91,25 @@ class TransformerNumericalizer(object):
self.unk_token = self._tokenizer.unk_token
self.pad_token = self._tokenizer.pad_token
self.mask_token = self._tokenizer.mask_token
self.cls_token = self._tokenizer.cls_token
self.init_id = self._tokenizer.bos_token_id
self.eos_id = self._tokenizer.eos_token_id
self.unk_id = self._tokenizer.unk_token_id
self.pad_id = self._tokenizer.pad_token_id
self.mask_id = self._tokenizer.mask_token_id
self.cls_id = self._tokenizer.cls_token_id
self.generative_vocab_size = len(self._decoder_words)
assert self.init_id < self.generative_vocab_size
assert self.eos_id < self.generative_vocab_size
assert self.unk_id < self.generative_vocab_size
assert self.pad_id < self.generative_vocab_size
# XLM-R mask token is outside of spm dict and is added to fairseq
# see issue https://github.com/huggingface/transformers/issues/2508
if not self._pretrained_name.startswith('xlm-roberta'):
    assert self.mask_id < self.generative_vocab_size
    assert self.cls_id < self.generative_vocab_size
self.decoder_vocab = DecoderVocabulary(self._decoder_words, self._tokenizer,
                                       pad_token=self.pad_token, eos_token=self.eos_token)
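
For context (not part of the commit): a minimal sketch of why the mask_id assert is skipped for XLM-R. As the in-code comment notes, the XLM-R <mask> token is not part of the sentencepiece vocabulary and is appended at the end (the fairseq convention), so its id can exceed a generative vocabulary built from the sentencepiece pieces. The checkpoint name and the printed values below are illustrative assumptions; exact ids depend on the transformers version.

# Sketch only, not from the genienlp repo. Assumes the standard Hugging Face
# 'transformers' package and the public 'xlm-roberta-base' checkpoint.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-base')

# Ordinary special tokens (<s>, </s>, <unk>, <pad>) receive small ids ...
print(tokenizer.bos_token_id, tokenizer.eos_token_id,
      tokenizer.unk_token_id, tokenizer.pad_token_id)

# ... while <mask> sits at the very end of the vocabulary, so a check like
# `assert mask_id < generative_vocab_size` fails when the decoder vocab is
# built from the sentencepiece pieces alone.
print(tokenizer.mask_token, tokenizer.mask_token_id, tokenizer.vocab_size)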