embeddings: allow passing full locale tags as --locale
This way, we don't need to do anything too special in Genie, and we can call decanlp with --locale zh-tw or --locale zh-cn if we need to distinguish between the two.
This commit is contained in:
parent
d46353d352
commit
d5aacba674
|
@ -69,7 +69,9 @@ class AlmondEmbeddings(torchtext.vocab.Vectors):
|
|||
def load_embeddings(args, logger=_logger, load_almond_embeddings=True):
|
||||
logger.info(f'Getting pretrained word vectors')
|
||||
|
||||
if args.locale == 'en':
|
||||
language = args.locale.split('-')[0]
|
||||
|
||||
if language == 'en':
|
||||
char_vectors = torchtext.vocab.CharNGram(cache=args.embeddings)
|
||||
if args.small_glove:
|
||||
glove_vectors = torchtext.vocab.GloVe(cache=args.embeddings, name="6B", dim=50)
|
||||
|
@ -80,7 +82,7 @@ def load_embeddings(args, logger=_logger, load_almond_embeddings=True):
|
|||
# Chinese word embeddings
|
||||
else:
|
||||
# default to fastText
|
||||
vectors = [torchtext.vocab.FastText(cache=args.embeddings, language=args.locale)]
|
||||
vectors = [torchtext.vocab.FastText(cache=args.embeddings, language=language)]
|
||||
|
||||
if load_almond_embeddings and args.almond_type_embeddings:
|
||||
vectors.append(AlmondEmbeddings())
|
||||
|
|
Loading…
Reference in New Issue