address mask_id issue for XLM-R
parent 1e2dbce017
commit 0acf64bdc3
@@ -91,19 +91,25 @@ class TransformerNumericalizer(object):
         self.unk_token = self._tokenizer.unk_token
         self.pad_token = self._tokenizer.pad_token
         self.mask_token = self._tokenizer.mask_token
         self.cls_token = self._tokenizer.cls_token

         self.init_id = self._tokenizer.bos_token_id
         self.eos_id = self._tokenizer.eos_token_id
         self.unk_id = self._tokenizer.unk_token_id
         self.pad_id = self._tokenizer.pad_token_id
         self.mask_id = self._tokenizer.mask_token_id
         self.cls_id = self._tokenizer.cls_token_id
         self.generative_vocab_size = len(self._decoder_words)

         assert self.init_id < self.generative_vocab_size
         assert self.eos_id < self.generative_vocab_size
         assert self.unk_id < self.generative_vocab_size
         assert self.pad_id < self.generative_vocab_size
-        assert self.mask_id < self.generative_vocab_size
+        # the XLM-R mask token is outside of the spm dict and is added separately in fairseq;
+        # see https://github.com/huggingface/transformers/issues/2508
+        if not self._pretrained_name.startswith('xlm-roberta'):
+            assert self.mask_id < self.generative_vocab_size
         assert self.cls_id < self.generative_vocab_size

         self.decoder_vocab = DecoderVocabulary(self._decoder_words, self._tokenizer,
                                                pad_token=self.pad_token, eos_token=self.eos_token)
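Why the guard is needed: for XLM-R, fairseq appends the <mask> token after the sentencepiece entries, so its id falls outside any vocabulary built from the spm dict alone. Below is a minimal sketch of the symptom, not part of this commit; it assumes the HuggingFace transformers package and the xlm-roberta-base checkpoint, and the ids shown in comments are what recent transformers versions report.

# Minimal sketch (assumption: `transformers` installed, 'xlm-roberta-base'
# checkpoint available). Demonstrates why the mask_id assert is skipped
# for XLM-R models.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-base')

# fairseq appends <mask> after the sentencepiece entries, so its id sits at
# the very end of the vocabulary rather than inside the spm range.
print(tokenizer.mask_token, tokenizer.mask_token_id)   # '<mask>' 250001
print(tokenizer.vocab_size)                            # 250002

# A generative vocabulary built only from decoder words drawn from the spm
# dict is smaller than mask_token_id, so the unconditional assert would fail.

Hence the commit guards the mask_id assert behind the 'xlm-roberta' prefix check rather than dropping it for every model.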