embeddings: allow passing full locale tags as --locale

This way, we don't need to do anything too special in Genie,
and we can call decanlp with --locale zh-tw or --locale zh-cn
if needed to distinguish
This commit is contained in:
Giovanni Campagna 2019-11-01 17:36:36 -07:00
parent d46353d352
commit d5aacba674
1 changed files with 4 additions and 2 deletions

View File

@ -69,7 +69,9 @@ class AlmondEmbeddings(torchtext.vocab.Vectors):
def load_embeddings(args, logger=_logger, load_almond_embeddings=True):
logger.info(f'Getting pretrained word vectors')
if args.locale == 'en':
language = args.locale.split('-')[0]
if language == 'en':
char_vectors = torchtext.vocab.CharNGram(cache=args.embeddings)
if args.small_glove:
glove_vectors = torchtext.vocab.GloVe(cache=args.embeddings, name="6B", dim=50)
@ -80,7 +82,7 @@ def load_embeddings(args, logger=_logger, load_almond_embeddings=True):
# Chinese word embeddings
else:
# default to fastText
vectors = [torchtext.vocab.FastText(cache=args.embeddings, language=args.locale)]
vectors = [torchtext.vocab.FastText(cache=args.embeddings, language=language)]
if load_almond_embeddings and args.almond_type_embeddings:
vectors.append(AlmondEmbeddings())