Fix loading of GloVe vectors, to address Issue #541

This commit is contained in:
Matthew Honnibal 2016-10-20 18:27:48 +02:00
parent 06b83d8f40
commit 5ec32f5d97
3 changed files with 31 additions and 20 deletions

View File

@ -13,7 +13,6 @@ except NameError:
basestring = str basestring = str
set_lang_class(en.English.lang, en.English) set_lang_class(en.English.lang, en.English)
set_lang_class(de.German.lang, de.German) set_lang_class(de.German.lang, de.German)
set_lang_class(zh.Chinese.lang, zh.Chinese) set_lang_class(zh.Chinese.lang, zh.Chinese)
@ -21,13 +20,19 @@ set_lang_class(zh.Chinese.lang, zh.Chinese)
def load(name, **overrides): def load(name, **overrides):
target_name, target_version = util.split_data_name(name) target_name, target_version = util.split_data_name(name)
path = overrides.get('path', util.get_data_path()) data_path = overrides.get('path', util.get_data_path())
path = util.match_best_version(target_name, target_version, path) if target_name == 'en' and 'add_vectors' not in overrides:
if 'vectors' in overrides:
if isinstance(overrides.get('vectors'), basestring): vec_path = util.match_best_version(overrides['vectors'], None, data_path)
vectors_path = util.match_best_version(overrides.get('vectors'), None, path) if vec_path is None:
overrides['vectors'] = lambda nlp: nlp.vocab.load_vectors_from_bin_loc( raise IOError(
vectors_path / 'vocab' / 'vec.bin') 'Could not load data pack %s from %s' % (overrides['vectors'], data_path))
else:
vec_path = util.match_best_version('en_glove_cc_300_1m_vectors', None, data_path)
if vec_path is not None:
vec_path = vec_path / 'vocab' / 'vec.bin'
overrides['add_vectors'] = lambda vocab: vocab.load_vectors_from_bin_loc(vec_path)
path = util.match_best_version(target_name, target_version, data_path)
cls = get_lang_class(target_name) cls = get_lang_class(target_name)
return cls(path=path, **overrides) return cls(path=path, **overrides)

View File

@ -53,7 +53,11 @@ class BaseDefaults(object):
@classmethod @classmethod
def add_vectors(cls, nlp=None): def add_vectors(cls, nlp=None):
return True if nlp is None or nlp.path is None:
return False
else:
vec_path = nlp.path / 'vocab' / 'vec.bin'
return lambda vocab: vocab.load_vectors_from_bin_loc(vec_path)
@classmethod @classmethod
def create_tokenizer(cls, nlp=None): def create_tokenizer(cls, nlp=None):
@ -246,6 +250,11 @@ class Language(object):
self.vocab = self.Defaults.create_vocab(self) \ self.vocab = self.Defaults.create_vocab(self) \
if 'vocab' not in overrides \ if 'vocab' not in overrides \
else overrides['vocab'] else overrides['vocab']
add_vectors = self.Defaults.add_vectors(self) \
if 'add_vectors' not in overrides \
else overrides['add_vectors']
if add_vectors:
add_vectors(self.vocab)
self.tokenizer = self.Defaults.create_tokenizer(self) \ self.tokenizer = self.Defaults.create_tokenizer(self) \
if 'tokenizer' not in overrides \ if 'tokenizer' not in overrides \
else overrides['tokenizer'] else overrides['tokenizer']

View File

@ -49,9 +49,13 @@ cdef class Vocab:
'''A map container for a language's LexemeC structs. '''A map container for a language's LexemeC structs.
''' '''
@classmethod @classmethod
def load(cls, path, lex_attr_getters=None, vectors=True, lemmatizer=True, def load(cls, path, lex_attr_getters=None, lemmatizer=True,
tag_map=True, serializer_freqs=True, oov_prob=True, **deprecated_kwargs): tag_map=True, serializer_freqs=True, oov_prob=True, **deprecated_kwargs):
util.check_renamed_kwargs({'get_lex_attr': 'lex_attr_getters'}, deprecated_kwargs) util.check_renamed_kwargs({'get_lex_attr': 'lex_attr_getters'}, deprecated_kwargs)
if 'vectors' in deprecated_kwargs:
raise AttributeError(
"vectors argument to Vocab.load() deprecated. "
"Install vectors after loading.")
if tag_map is True and (path / 'vocab' / 'tag_map.json').exists(): if tag_map is True and (path / 'vocab' / 'tag_map.json').exists():
with (path / 'vocab' / 'tag_map.json').open() as file_: with (path / 'vocab' / 'tag_map.json').open() as file_:
tag_map = json.load(file_) tag_map = json.load(file_)
@ -73,15 +77,6 @@ cdef class Vocab:
with (path / 'vocab' / 'strings.json').open() as file_: with (path / 'vocab' / 'strings.json').open() as file_:
self.strings.load(file_) self.strings.load(file_)
self.load_lexemes(path / 'vocab' / 'lexemes.bin') self.load_lexemes(path / 'vocab' / 'lexemes.bin')
if vectors is True:
vec_path = path / 'vocab' / 'vec.bin'
if vec_path.exists():
vectors = lambda self_: self_.load_vectors_from_bin_loc(vec_path)
else:
vectors = lambda self_: 0
if vectors:
self.vectors_length = vectors(self)
return self return self
def __init__(self, lex_attr_getters=None, tag_map=None, lemmatizer=None, def __init__(self, lex_attr_getters=None, tag_map=None, lemmatizer=None,
@ -387,10 +382,11 @@ cdef class Vocab:
vec_len, len(pieces)) vec_len, len(pieces))
orth = self.strings[word_str] orth = self.strings[word_str]
lexeme = <LexemeC*><void*>self.get_by_orth(self.mem, orth) lexeme = <LexemeC*><void*>self.get_by_orth(self.mem, orth)
lexeme.vector = <float*>self.mem.alloc(self.vectors_length, sizeof(float)) lexeme.vector = <float*>self.mem.alloc(vec_len, sizeof(float))
for i, val_str in enumerate(pieces): for i, val_str in enumerate(pieces):
lexeme.vector[i] = float(val_str) lexeme.vector[i] = float(val_str)
self.vectors_length = vec_len
return vec_len return vec_len
def load_vectors_from_bin_loc(self, loc): def load_vectors_from_bin_loc(self, loc):
@ -438,6 +434,7 @@ cdef class Vocab:
lex.l2_norm = math.sqrt(lex.l2_norm) lex.l2_norm = math.sqrt(lex.l2_norm)
else: else:
lex.vector = EMPTY_VEC lex.vector = EMPTY_VEC
self.vectors_length = vec_len
return vec_len return vec_len