diff --git a/decanlp/utils/embeddings.py b/decanlp/utils/embeddings.py index 1ce32b67..5b3dfee6 100644 --- a/decanlp/utils/embeddings.py +++ b/decanlp/utils/embeddings.py @@ -69,7 +69,9 @@ class AlmondEmbeddings(torchtext.vocab.Vectors): def load_embeddings(args, logger=_logger, load_almond_embeddings=True): logger.info(f'Getting pretrained word vectors') - if args.locale == 'en': + language = args.locale.split('-')[0] + + if language == 'en': char_vectors = torchtext.vocab.CharNGram(cache=args.embeddings) if args.small_glove: glove_vectors = torchtext.vocab.GloVe(cache=args.embeddings, name="6B", dim=50) @@ -80,7 +82,7 @@ def load_embeddings(args, logger=_logger, load_almond_embeddings=True): # Chinese word embeddings else: # default to fastText - vectors = [torchtext.vocab.FastText(cache=args.embeddings, language=args.locale)] + vectors = [torchtext.vocab.FastText(cache=args.embeddings, language=language)] if load_almond_embeddings and args.almond_type_embeddings: vectors.append(AlmondEmbeddings())