Merge pull request #20 from stanford-oval/wip/i18n
Support word embeddings for arbitrary languages
commit 9cdb85abbb
@@ -104,7 +104,7 @@ def parse(argv):
     parser.add_argument('--intermediate_cove', action='store_true', help='whether to use the intermediate layers of contextualized word vectors (McCann et al. 2017)')
     parser.add_argument('--elmo', default=[-1], nargs='+', type=int, help='which layer(s) (0, 1, or 2) of ELMo (Peters et al. 2018) to use; -1 for none')
     parser.add_argument('--no_glove_and_char', action='store_false', dest='glove_and_char', help='turn off GloVe and CharNGram embeddings')
-    parser.add_argument('--use_fastText', action='store_true', help='use fastText embeddings for encoder')
+    parser.add_argument('--locale', default='en', help='locale to use for word embeddings')
     parser.add_argument('--retrain_encoder_embedding', default=False, action='store_true', help='whether to retrain encoder embeddings')
     parser.add_argument('--trainable_decoder_embedding', default=0, type=int, help='size of trainable portion of decoder embedding (0 or omit to disable)')
     parser.add_argument('--no_glove_decoder', action='store_false', dest='glove_decoder', help='turn off GloVe embeddings from decoder')

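Editor's sketch, not part of the diff: how the replacement flag behaves. The bare parser below is a hypothetical stand-in for parse() above; the locale value is illustrative.

    from argparse import ArgumentParser

    # '--use_fastText' is gone; '--locale' takes a BCP 47-style tag instead.
    parser = ArgumentParser()
    parser.add_argument('--locale', default='en', help='locale to use for word embeddings')
    args = parser.parse_args(['--locale', 'zh-CN'])
    print(args.locale)  # 'zh-CN'; downstream code splits off the language part
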
@@ -37,6 +37,7 @@ import sys
 from pprint import pformat
 
 from .text import torchtext
+from .utils.embeddings import load_embeddings
 
 logger = logging.getLogger(__name__)
 
@@ -45,9 +46,7 @@ def get_args(argv):
     parser = ArgumentParser(prog=argv[0])
     parser.add_argument('--seed', default=123, type=int, help='Random seed.')
     parser.add_argument('--embeddings', default='./decaNLP/.embeddings', type=str, help='where to save embeddings.')
     parser.add_argument('--small_glove', action='store_true', help='Cache glove.6B.50d')
-    parser.add_argument('--large_glove', action='store_true', help='Cache glove.840B.300d')
-    parser.add_argument('--char', action='store_true', help='Cache character embeddings')
+    parser.add_argument('--locale', default='en', help='locale to use for word embeddings')
-
     args = parser.parse_args(argv[1:])
     return args

@@ -62,10 +61,4 @@ def main(argv=sys.argv):
     torch.manual_seed(args.seed)
     torch.cuda.manual_seed(args.seed)
 
-    if args.char:
-        torchtext.vocab.CharNGram(cache=args.embeddings)
-    if args.small_glove:
-        torchtext.vocab.GloVe(cache=args.embeddings, name="6B", dim=50)
-    if args.large_glove:
-        torchtext.vocab.GloVe(cache=args.embeddings)
-
+    load_embeddings(args, load_almond_embeddings=False)

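A hedged usage note: with main() now delegating to load_embeddings, pre-downloading vectors for another language reduces to passing the locale flag. The entry-point name below is an assumption, not taken from the diff.

    # Hypothetical invocation of the embedding cache script after this change
    # (the real script or module name may differ):
    #
    #   python cache_embeddings.py --locale fa --embeddings ./decaNLP/.embeddings
    #
    # which runs the same selection logic as training:
    #   args = get_args(argv)                                 # --locale 'fa'
    #   load_embeddings(args, load_almond_embeddings=False)   # fetches cc.fa.vec
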
@@ -467,7 +467,7 @@ class GloVe(Vectors):
 
 class FastText(Vectors):
 
-    url_base = 'https://dl.fbaipublicfiles.com/fasttext/vectors-wiki/wiki.{}.vec'
+    url_base = 'https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.{}.vec'
 
     def __init__(self, language="en", **kwargs):
         url = self.url_base.format(language)

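Editor's sketch, not in the diff: how the class-level url_base template expands per language. The change swaps the smaller Wikipedia-trained vectors for the Common Crawl ones.

    old_base = 'https://dl.fbaipublicfiles.com/fasttext/vectors-wiki/wiki.{}.vec'
    new_base = 'https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.{}.vec'
    for language in ('en', 'fa', 'zh'):
        print(old_base.format(language), '->', new_base.format(language))
    # e.g. .../wiki.fa.vec -> .../cc.fa.vec
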
@@ -173,11 +173,13 @@ def load_config_json(args):
                 'max_generative_vocab', 'lower', 'cove', 'intermediate_cove', 'elmo', 'glove_and_char',
                 'use_maxmargin_loss', 'small_glove', 'almond_type_embeddings', 'almond_grammar',
                 'trainable_decoder_embedding', 'glove_decoder', 'pretrained_decoder_lm',
-                'retrain_encoder_embedding', 'question', 'use_fastText', 'use_google_translate']
+                'retrain_encoder_embedding', 'question', 'locale', 'use_google_translate']
 
     for r in retrieve:
         if r in config:
             setattr(args, r, config[r])
+        elif r == 'locale':
+            setattr(args, r, 'en')
         elif r in ('cove', 'intermediate_cove', 'use_maxmargin_loss', 'small_glove',
                    'almond_type_embbedings'):
             setattr(args, r, False)

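Hedged sketch of the backward-compatibility path added above: a config saved by an older, English-only model has no 'locale' key, so loading it defaults to 'en'. The config dict below is a made-up minimal example.

    from argparse import Namespace

    config = {'max_generative_vocab': 50000}   # hypothetical old config, no 'locale'
    args = Namespace()
    for r in ('max_generative_vocab', 'locale'):
        if r in config:
            setattr(args, r, config[r])
        elif r == 'locale':
            setattr(args, r, 'en')   # pre-locale models were English-only
    assert args.max_generative_vocab == 50000 and args.locale == 'en'
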
@@ -66,19 +66,24 @@ class AlmondEmbeddings(torchtext.vocab.Vectors):
         self.dim = dim
 
 
-def load_embeddings(args, logger=_logger):
+def load_embeddings(args, logger=_logger, load_almond_embeddings=True):
     logger.info(f'Getting pretrained word vectors')
-    final_vectors = []
-    if args.use_fastText:
-        vectors = [torchtext.vocab.FastText(cache=args.embeddings, language='fa')]
-    else:
+
+    language = args.locale.split('-')[0]
+
+    if language == 'en':
         char_vectors = torchtext.vocab.CharNGram(cache=args.embeddings)
         if args.small_glove:
             glove_vectors = torchtext.vocab.GloVe(cache=args.embeddings, name="6B", dim=50)
         else:
             glove_vectors = torchtext.vocab.GloVe(cache=args.embeddings)
         vectors = [char_vectors, glove_vectors]
-    final_vectors.extend(vectors)
-    if args.almond_type_embeddings:
-        final_vectors.append(AlmondEmbeddings())
-    return final_vectors
+    # elif args.locale == 'zh':
+    #     Chinese word embeddings
+    else:
+        # default to fastText
+        vectors = [torchtext.vocab.FastText(cache=args.embeddings, language=language)]
+
+    if load_almond_embeddings and args.almond_type_embeddings:
+        vectors.append(AlmondEmbeddings())
+    return vectors

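A hedged usage sketch of the new signature. The import path follows the one added in this commit (from .utils.embeddings import load_embeddings); the package name and field values are illustrative.

    from argparse import Namespace
    from decanlp.utils.embeddings import load_embeddings  # assumed package name

    args = Namespace(locale='de-DE', embeddings='./decaNLP/.embeddings',
                     small_glove=False, almond_type_embeddings=False)
    # 'de-DE'.split('-')[0] -> 'de': not English, so the function falls back
    # to fastText and downloads cc.de.vec into args.embeddings.
    vectors = load_embeddings(args, load_almond_embeddings=False)
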