Merge pull request #20 from stanford-oval/wip/i18n

Support word embeddings for arbitrary languages
This commit is contained in:
Giovanni Campagna 2019-11-03 22:10:16 -08:00 committed by GitHub
commit 9cdb85abbb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 22 additions and 22 deletions

View File

@@ -104,7 +104,7 @@ def parse(argv):
parser.add_argument('--intermediate_cove', action='store_true', help='whether to use the intermediate layers of contextualized word vectors (McCann et al. 2017)')
parser.add_argument('--elmo', default=[-1], nargs='+', type=int, help='which layer(s) (0, 1, or 2) of ELMo (Peters et al. 2018) to use; -1 for none ')
parser.add_argument('--no_glove_and_char', action='store_false', dest='glove_and_char', help='turn off GloVe and CharNGram embeddings')
parser.add_argument('--use_fastText', action='store_true', help='use fastText embeddings for encoder')
parser.add_argument('--locale', default='en', help='locale to use for word embeddings')
parser.add_argument('--retrain_encoder_embedding', default=False, action='store_true', help='whether to retrain encoder embeddings')
parser.add_argument('--trainable_decoder_embedding', default=0, type=int, help='size of trainable portion of decoder embedding (0 or omit to disable)')
parser.add_argument('--no_glove_decoder', action='store_false', dest='glove_decoder', help='turn off GloVe embeddings from decoder')

View File

@@ -37,6 +37,7 @@ import sys
from pprint import pformat
from .text import torchtext
from .utils.embeddings import load_embeddings
logger = logging.getLogger(__name__)
@@ -45,9 +46,7 @@ def get_args(argv):
parser = ArgumentParser(prog=argv[0])
parser.add_argument('--seed', default=123, type=int, help='Random seed.')
parser.add_argument('--embeddings', default='./decaNLP/.embeddings', type=str, help='where to save embeddings.')
parser.add_argument('--small_glove', action='store_true', help='Cache glove.6B.50d')
parser.add_argument('--large_glove', action='store_true', help='Cache glove.840B.300d')
parser.add_argument('--char', action='store_true', help='Cache character embeddings')
parser.add_argument('--locale', default='en', help='locale to use for word embeddings')
args = parser.parse_args(argv[1:])
return args
@@ -62,10 +61,4 @@ def main(argv=sys.argv):
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
if args.char:
torchtext.vocab.CharNGram(cache=args.embeddings)
if args.small_glove:
torchtext.vocab.GloVe(cache=args.embeddings, name="6B", dim=50)
if args.large_glove:
torchtext.vocab.GloVe(cache=args.embeddings)
load_embeddings(args, load_almond_embeddings=False)

View File

@@ -467,7 +467,7 @@ class GloVe(Vectors):
class FastText(Vectors):
url_base = 'https://dl.fbaipublicfiles.com/fasttext/vectors-wiki/wiki.{}.vec'
url_base = 'https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.{}.vec'
def __init__(self, language="en", **kwargs):
url = self.url_base.format(language)

View File

@@ -173,11 +173,13 @@ def load_config_json(args):
'max_generative_vocab', 'lower', 'cove', 'intermediate_cove', 'elmo', 'glove_and_char',
'use_maxmargin_loss', 'small_glove', 'almond_type_embeddings', 'almond_grammar',
'trainable_decoder_embedding', 'glove_decoder', 'pretrained_decoder_lm',
'retrain_encoder_embedding', 'question', 'use_fastText', 'use_google_translate']
'retrain_encoder_embedding', 'question', 'locale', 'use_google_translate']
for r in retrieve:
if r in config:
setattr(args, r, config[r])
elif r == 'locale':
setattr(args, r, 'en')
elif r in ('cove', 'intermediate_cove', 'use_maxmargin_loss', 'small_glove',
'almond_type_embbedings'):
setattr(args, r, False)

View File

@@ -66,19 +66,24 @@ class AlmondEmbeddings(torchtext.vocab.Vectors):
self.dim = dim
def load_embeddings(args, logger=_logger):
def load_embeddings(args, logger=_logger, load_almond_embeddings=True):
logger.info(f'Getting pretrained word vectors')
final_vectors = []
if args.use_fastText:
vectors = [torchtext.vocab.FastText(cache=args.embeddings, language='fa')]
else:
language = args.locale.split('-')[0]
if language == 'en':
char_vectors = torchtext.vocab.CharNGram(cache=args.embeddings)
if args.small_glove:
glove_vectors = torchtext.vocab.GloVe(cache=args.embeddings, name="6B", dim=50)
else:
glove_vectors = torchtext.vocab.GloVe(cache=args.embeddings)
vectors = [char_vectors, glove_vectors]
final_vectors.extend(vectors)
if args.almond_type_embeddings:
final_vectors.append(AlmondEmbeddings())
return final_vectors
# elif args.locale == 'zh':
# Chinese word embeddings
else:
# default to fastText
vectors = [torchtext.vocab.FastText(cache=args.embeddings, language=language)]
if load_almond_embeddings and args.almond_type_embeddings:
vectors.append(AlmondEmbeddings())
return vectors