From 0acf64bdc3f21c95d8aa166388669483148d00f4 Mon Sep 17 00:00:00 2001
From: mehrad
Date: Tue, 24 Mar 2020 16:05:01 -0700
Subject: [PATCH] address mask_id issue for XLM-R

---
 genienlp/data_utils/numericalizer/transformer.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/genienlp/data_utils/numericalizer/transformer.py b/genienlp/data_utils/numericalizer/transformer.py
index 3d00a3b5..7ca36254 100644
--- a/genienlp/data_utils/numericalizer/transformer.py
+++ b/genienlp/data_utils/numericalizer/transformer.py
@@ -91,19 +91,25 @@ class TransformerNumericalizer(object):
         self.unk_token = self._tokenizer.unk_token
         self.pad_token = self._tokenizer.pad_token
         self.mask_token = self._tokenizer.mask_token
+        self.cls_token = self._tokenizer.cls_token
 
         self.init_id = self._tokenizer.bos_token_id
         self.eos_id = self._tokenizer.eos_token_id
         self.unk_id = self._tokenizer.unk_token_id
         self.pad_id = self._tokenizer.pad_token_id
         self.mask_id = self._tokenizer.mask_token_id
+        self.cls_id = self._tokenizer.cls_token_id
         self.generative_vocab_size = len(self._decoder_words)
 
         assert self.init_id < self.generative_vocab_size
         assert self.eos_id < self.generative_vocab_size
         assert self.unk_id < self.generative_vocab_size
         assert self.pad_id < self.generative_vocab_size
-        assert self.mask_id < self.generative_vocab_size
+        # XLM-R's mask token is outside the spm dict and is appended by fairseq
+        # see https://github.com/huggingface/transformers/issues/2508
+        if not self._pretrained_name.startswith('xlm-roberta'):
+            assert self.mask_id < self.generative_vocab_size
+        assert self.cls_id < self.generative_vocab_size
 
         self.decoder_vocab = DecoderVocabulary(self._decoder_words, self._tokenizer,
                                                pad_token=self.pad_token, eos_token=self.eos_token)
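
Note (illustration only, not part of the patch): a minimal sketch of the behavior the new guard works around, assuming the HuggingFace transformers library and the xlm-roberta-base checkpoint. XLM-R's <mask> token is not a sentencepiece piece; it is appended after the spm vocabulary (see transformers issue 2508 linked above), so its id is the last id of the tokenizer and can exceed a generative vocabulary built from decoder words only.

    # sketch: where XLM-R's mask id lands (assumes transformers is installed)
    from transformers import XLMRobertaTokenizer

    tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-base')

    # <mask> is appended after the sentencepiece pieces, so in practice it takes
    # the highest id in the tokenizer rather than an id inside the spm range
    print(tokenizer.mask_token, tokenizer.mask_token_id)  # '<mask>' with the largest tokenizer id
    print(len(tokenizer))                                  # total vocab size, including the appended token

    # hence the patch only asserts mask_id < generative_vocab_size for
    # non xlm-roberta models; for XLM-R the mask id sits outside that range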