diff --git a/Pipfile b/Pipfile index 9f9f1a49..701ad827 100644 --- a/Pipfile +++ b/Pipfile @@ -16,7 +16,7 @@ tensorboardX = "==2.0.*" requests = "~=2.22" transformers = "~=2.3" radam = {git = "https://github.com/LiyuanLucasLiu/RAdam"} -sentencepiece = ">=0.1.83" +sentencepiece = ">=0.1.83,<0.2.0" [requires] python_version = "3.7" diff --git a/genienlp/data_utils/embeddings.py b/genienlp/data_utils/embeddings.py index 1fda67b2..4c391efe 100644 --- a/genienlp/data_utils/embeddings.py +++ b/genienlp/data_utils/embeddings.py @@ -160,6 +160,7 @@ class PretrainedLMEmbedding(torch.nn.Module): def _is_bert(embedding_name): return embedding_name.startswith('bert-') + def _is_xlmr(embedding_name): return embedding_name.startswith('xlm-roberta-') diff --git a/genienlp/data_utils/numericalizer/masked_bert_tokenizer.py b/genienlp/data_utils/numericalizer/masked_tokenizer.py similarity index 100% rename from genienlp/data_utils/numericalizer/masked_bert_tokenizer.py rename to genienlp/data_utils/numericalizer/masked_tokenizer.py diff --git a/genienlp/data_utils/numericalizer/transformer.py b/genienlp/data_utils/numericalizer/transformer.py index 3feb5cbc..b47b0406 100644 --- a/genienlp/data_utils/numericalizer/transformer.py +++ b/genienlp/data_utils/numericalizer/transformer.py @@ -31,7 +31,7 @@ import os import torch from .decoder_vocab import DecoderVocabulary -from .masked_bert_tokenizer import MaskedBertTokenizer, MaskedXLMRobertaTokenizer +from .masked_tokenizer import MaskedBertTokenizer, MaskedXLMRobertaTokenizer from .sequential_field import SequentialField from transformers.tokenization_xlnet import SPIECE_UNDERLINE diff --git a/genienlp/tasks/base_dataset.py b/genienlp/tasks/base_dataset.py index d5dabeb2..511812dc 100644 --- a/genienlp/tasks/base_dataset.py +++ b/genienlp/tasks/base_dataset.py @@ -1,7 +1,7 @@ import os import zipfile import tarfile -from six.moves import urllib +import urllib import requests import torch.utils.data diff --git a/setup.py b/setup.py index d5b4404c..561985ab 100644 --- a/setup.py +++ b/setup.py @@ -55,6 +55,7 @@ setuptools.setup( 'pyrouge>=0.1.3', 'sacrebleu~=1.0', 'requests~=2.22', - 'transformers~=2.3' + 'transformers~=2.3', + 'sentencepiece>=0.1.83,<0.2.0' ] )