diff --git a/decanlp/arguments.py b/decanlp/arguments.py index 836c8303..343a37f6 100644 --- a/decanlp/arguments.py +++ b/decanlp/arguments.py @@ -104,7 +104,7 @@ def parse(argv): parser.add_argument('--intermediate_cove', action='store_true', help='whether to use the intermediate layers of contextualized word vectors (McCann et al. 2017)') parser.add_argument('--elmo', default=[-1], nargs='+', type=int, help='which layer(s) (0, 1, or 2) of ELMo (Peters et al. 2018) to use; -1 for none ') parser.add_argument('--no_glove_and_char', action='store_false', dest='glove_and_char', help='turn off GloVe and CharNGram embeddings') - parser.add_argument('--use_fastText', action='store_true', help='use fastText embeddings for encoder') + parser.add_argument('--locale', default='en', help='locale to use for word embeddings') parser.add_argument('--retrain_encoder_embedding', default=False, action='store_true', help='whether to retrain encoder embeddings') parser.add_argument('--trainable_decoder_embedding', default=0, type=int, help='size of trainable portion of decoder embedding (0 or omit to disable)') parser.add_argument('--no_glove_decoder', action='store_false', dest='glove_decoder', help='turn off GloVe embeddings from decoder') diff --git a/decanlp/cache_embeddings.py b/decanlp/cache_embeddings.py index 70fa0198..5770953f 100644 --- a/decanlp/cache_embeddings.py +++ b/decanlp/cache_embeddings.py @@ -37,6 +37,7 @@ import sys from pprint import pformat from .text import torchtext +from .utils.embeddings import load_embeddings logger = logging.getLogger(__name__) @@ -45,9 +46,7 @@ def get_args(argv): parser = ArgumentParser(prog=argv[0]) parser.add_argument('--seed', default=123, type=int, help='Random seed.') parser.add_argument('--embeddings', default='./decaNLP/.embeddings', type=str, help='where to save embeddings.') - parser.add_argument('--small_glove', action='store_true', help='Cache glove.6B.50d') - parser.add_argument('--large_glove', action='store_true', 
help='Cache glove.840B.300d') - parser.add_argument('--char', action='store_true', help='Cache character embeddings') + parser.add_argument('--locale', default='en', help='locale to use for word embeddings') args = parser.parse_args(argv[1:]) return args @@ -62,10 +61,4 @@ def main(argv=sys.argv): torch.manual_seed(args.seed) torch.cuda.manual_seed(args.seed) - if args.char: - torchtext.vocab.CharNGram(cache=args.embeddings) - if args.small_glove: - torchtext.vocab.GloVe(cache=args.embeddings, name="6B", dim=50) - if args.large_glove: - torchtext.vocab.GloVe(cache=args.embeddings) - + load_embeddings(args, load_almond_embeddings=False) diff --git a/decanlp/text/torchtext/vocab.py b/decanlp/text/torchtext/vocab.py index 3bf10df2..3b5a5f73 100644 --- a/decanlp/text/torchtext/vocab.py +++ b/decanlp/text/torchtext/vocab.py @@ -467,7 +467,7 @@ class GloVe(Vectors): class FastText(Vectors): - url_base = 'https://dl.fbaipublicfiles.com/fasttext/vectors-wiki/wiki.{}.vec' + url_base = 'https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.{}.300.vec.gz' def __init__(self, language="en", **kwargs): url = self.url_base.format(language) diff --git a/decanlp/util.py b/decanlp/util.py index 5c82c920..c6abd684 100644 --- a/decanlp/util.py +++ b/decanlp/util.py @@ -173,11 +173,13 @@ def load_config_json(args): 'max_generative_vocab', 'lower', 'cove', 'intermediate_cove', 'elmo', 'glove_and_char', 'use_maxmargin_loss', 'small_glove', 'almond_type_embeddings', 'almond_grammar', 'trainable_decoder_embedding', 'glove_decoder', 'pretrained_decoder_lm', - 'retrain_encoder_embedding', 'question', 'use_fastText', 'use_google_translate'] + 'retrain_encoder_embedding', 'question', 'locale', 'use_google_translate'] for r in retrieve: if r in config: setattr(args, r, config[r]) + elif r == 'locale': + setattr(args, r, 'en') elif r in ('cove', 'intermediate_cove', 'use_maxmargin_loss', 'small_glove', 'almond_type_embbedings'): setattr(args, r, False) diff --git a/decanlp/utils/embeddings.py 
b/decanlp/utils/embeddings.py index beb3840a..5b3dfee6 100644 --- a/decanlp/utils/embeddings.py +++ b/decanlp/utils/embeddings.py @@ -66,19 +66,24 @@ class AlmondEmbeddings(torchtext.vocab.Vectors): self.dim = dim -def load_embeddings(args, logger=_logger): +def load_embeddings(args, logger=_logger, load_almond_embeddings=True): logger.info(f'Getting pretrained word vectors') - final_vectors = [] - if args.use_fastText: - vectors = [torchtext.vocab.FastText(cache=args.embeddings, language='fa')] - else: + + language = args.locale.split('-')[0] + + if language == 'en': char_vectors = torchtext.vocab.CharNGram(cache=args.embeddings) if args.small_glove: glove_vectors = torchtext.vocab.GloVe(cache=args.embeddings, name="6B", dim=50) else: glove_vectors = torchtext.vocab.GloVe(cache=args.embeddings) vectors = [char_vectors, glove_vectors] - final_vectors.extend(vectors) - if args.almond_type_embeddings: - final_vectors.append(AlmondEmbeddings()) - return final_vectors \ No newline at end of file + # elif args.locale == 'zh': + # Chinese word embeddings + else: + # default to fastText + vectors = [torchtext.vocab.FastText(cache=args.embeddings, language=language)] + + if load_almond_embeddings and args.almond_type_embeddings: + vectors.append(AlmondEmbeddings()) + return vectors