address mask_id issue for XLM-R
parent 1e2dbce017
commit 0acf64bdc3
@@ -91,19 +91,25 @@ class TransformerNumericalizer(object):
         self.unk_token = self._tokenizer.unk_token
         self.pad_token = self._tokenizer.pad_token
         self.mask_token = self._tokenizer.mask_token
+        self.cls_token = self._tokenizer.cls_token

         self.init_id = self._tokenizer.bos_token_id
         self.eos_id = self._tokenizer.eos_token_id
         self.unk_id = self._tokenizer.unk_token_id
         self.pad_id = self._tokenizer.pad_token_id
         self.mask_id = self._tokenizer.mask_token_id
+        self.cls_id = self._tokenizer.cls_token_id
         self.generative_vocab_size = len(self._decoder_words)

         assert self.init_id < self.generative_vocab_size
         assert self.eos_id < self.generative_vocab_size
         assert self.unk_id < self.generative_vocab_size
         assert self.pad_id < self.generative_vocab_size
-        assert self.mask_id < self.generative_vocab_size
+        # XLM-R mask token is outside of spm dict and is added to fairseq
+        # see issue https://github.com/huggingface/transformers/issues/2508
+        if not self._pretrained_name.startswith('xlm-roberta'):
+            assert self.mask_id < self.generative_vocab_size
+        assert self.cls_id < self.generative_vocab_size

         self.decoder_vocab = DecoderVocabulary(self._decoder_words, self._tokenizer,
                                                pad_token=self.pad_token, eos_token=self.eos_token)
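For context, a quick way to reproduce the problem this commit guards against: XLM-R's <mask> token is not part of the underlying sentencepiece model; it is appended on the fairseq side, so its id sits at the very end of the vocabulary and can exceed a generative vocabulary built from the sentencepiece entries. A minimal sketch, assuming the HuggingFace transformers library and the xlm-roberta-base checkpoint (exact ids may differ across library versions):

# Sketch only; assumes `transformers` is installed and downloads xlm-roberta-base.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-base')

# <mask> is appended after the sentencepiece pieces, so its id is the
# last entry of the vocabulary rather than an id inside the spm dict.
print(tokenizer.mask_token_id)   # 250001 at the time of the linked issue
print(tokenizer.convert_ids_to_tokens(tokenizer.mask_token_id))  # '<mask>'

Because of this, `mask_id < generative_vocab_size` does not hold for XLM-R even when every other special token id does, which is why the assertion is skipped for pretrained names starting with 'xlm-roberta'.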