mirror of https://github.com/explosion/spaCy.git
Fix loading of GloVe vectors, to address Issue #541
parent 06b83d8f40
commit 5ec32f5d97
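This commit changes how the English GloVe vectors are located and installed: spacy.load() now resolves the vectors data pack and hands the vocab an installer callable, rather than Vocab.load() reading vec.bin itself. A minimal usage sketch, assuming the spaCy version this commit targets (circa 1.0) with the 'en' model and the en_glove_cc_300_1m_vectors pack installed under the data path:

    import spacy

    # After this commit, loading English looks for the GloVe data pack
    # ('en_glove_cc_300_1m_vectors') by default and installs it into the vocab.
    nlp = spacy.load('en')

    # Naming the pack explicitly; per the load() change below, a missing pack
    # now raises IOError instead of failing silently.
    nlp = spacy.load('en', vectors='en_glove_cc_300_1m_vectors')

    doc = nlp(u'The quick brown fox')
    print(nlp.vocab.vectors_length)   # width of the installed vectors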
@@ -13,7 +13,6 @@ except NameError:
     basestring = str
 
 
-
 set_lang_class(en.English.lang, en.English)
 set_lang_class(de.German.lang, de.German)
 set_lang_class(zh.Chinese.lang, zh.Chinese)
@@ -21,13 +20,19 @@ set_lang_class(zh.Chinese.lang, zh.Chinese)
 
 def load(name, **overrides):
     target_name, target_version = util.split_data_name(name)
-    path = overrides.get('path', util.get_data_path())
-    path = util.match_best_version(target_name, target_version, path)
-
-    if isinstance(overrides.get('vectors'), basestring):
-        vectors_path = util.match_best_version(overrides.get('vectors'), None, path)
-        overrides['vectors'] = lambda nlp: nlp.vocab.load_vectors_from_bin_loc(
-            vectors_path / 'vocab' / 'vec.bin')
+    data_path = overrides.get('path', util.get_data_path())
+    if target_name == 'en' and 'add_vectors' not in overrides:
+        if 'vectors' in overrides:
+            vec_path = util.match_best_version(overrides['vectors'], None, data_path)
+            if vec_path is None:
+                raise IOError(
+                    'Could not load data pack %s from %s' % (overrides['vectors'], data_path))
+        else:
+            vec_path = util.match_best_version('en_glove_cc_300_1m_vectors', None, data_path)
+        if vec_path is not None:
+            vec_path = vec_path / 'vocab' / 'vec.bin'
+            overrides['add_vectors'] = lambda vocab: vocab.load_vectors_from_bin_loc(vec_path)
+    path = util.match_best_version(target_name, target_version, data_path)
     cls = get_lang_class(target_name)
     return cls(path=path, **overrides)
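The reworked load() turns a 'vectors' override (the name of an installed data pack) into an 'add_vectors' callable, which is what the Language constructor actually runs against the vocab. A sketch of supplying that callable directly, bypassing the pack lookup; the vectors location here is hypothetical:

    import pathlib
    import spacy

    # Hypothetical location of a binary vectors file, laid out the way the
    # loader expects (<pack>/vocab/vec.bin).
    vec_path = pathlib.Path('/tmp/my_vectors/vocab/vec.bin')

    # 'add_vectors' receives the Vocab and installs vectors into it; passing it
    # explicitly makes load() skip the default GloVe data-pack lookup.
    nlp = spacy.load('en',
                     add_vectors=lambda vocab: vocab.load_vectors_from_bin_loc(vec_path))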
@@ -53,7 +53,11 @@ class BaseDefaults(object):
 
     @classmethod
     def add_vectors(cls, nlp=None):
-        return True
+        if nlp is None or nlp.path is None:
+            return False
+        else:
+            vec_path = nlp.path / 'vocab' / 'vec.bin'
+            return lambda vocab: vocab.load_vectors_from_bin_loc(vec_path)
 
     @classmethod
     def create_tokenizer(cls, nlp=None):
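BaseDefaults.add_vectors now returns either False (no model path, nothing to install) or a deferred loader bound to <path>/vocab/vec.bin; the caller decides whether to invoke it. A standalone sketch of that contract, using a stand-in vocab rather than spaCy's:

    from pathlib import Path

    class FakeVocab(object):
        # Stand-in for spacy.vocab.Vocab; just records what would be loaded.
        def load_vectors_from_bin_loc(self, loc):
            print('would load vectors from %s' % loc)
            return 300   # pretend vector width

    def add_vectors(model_path):
        # Mirrors the new BaseDefaults.add_vectors: False when there is no
        # data path, otherwise a callable bound to <path>/vocab/vec.bin.
        if model_path is None:
            return False
        vec_path = Path(model_path) / 'vocab' / 'vec.bin'
        return lambda vocab: vocab.load_vectors_from_bin_loc(vec_path)

    loader = add_vectors('/data/en')   # hypothetical model directory
    if loader:
        loader(FakeVocab())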
@@ -246,6 +250,11 @@ class Language(object):
         self.vocab = self.Defaults.create_vocab(self) \
                      if 'vocab' not in overrides \
                      else overrides['vocab']
+        add_vectors = self.Defaults.add_vectors(self) \
+                      if 'add_vectors' not in overrides \
+                      else overrides['add_vectors']
+        if add_vectors:
+            add_vectors(self.vocab)
         self.tokenizer = self.Defaults.create_tokenizer(self) \
                          if 'tokenizer' not in overrides \
                          else overrides['tokenizer']
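The constructor resolves add_vectors the same way as vocab and tokenizer: take the override if one was passed, otherwise ask Defaults, and only call it when it is truthy. One consequence of this logic is that a falsy override disables vector installation entirely; a sketch:

    import spacy

    # An explicit 'add_vectors' override wins over Defaults.add_vectors, and a
    # falsy value means no vectors are installed into the vocab.
    nlp_without_vectors = spacy.load('en', add_vectors=False)

    # With no override, Defaults.add_vectors() supplies a loader for
    # <model path>/vocab/vec.bin whenever the model ships one.
    nlp_with_defaults = spacy.load('en')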
@@ -49,9 +49,13 @@ cdef class Vocab:
     '''A map container for a language's LexemeC structs.
     '''
     @classmethod
-    def load(cls, path, lex_attr_getters=None, vectors=True, lemmatizer=True,
+    def load(cls, path, lex_attr_getters=None, lemmatizer=True,
              tag_map=True, serializer_freqs=True, oov_prob=True, **deprecated_kwargs):
         util.check_renamed_kwargs({'get_lex_attr': 'lex_attr_getters'}, deprecated_kwargs)
+        if 'vectors' in deprecated_kwargs:
+            raise AttributeError(
+                "vectors argument to Vocab.load() deprecated. "
+                "Install vectors after loading.")
         if tag_map is True and (path / 'vocab' / 'tag_map.json').exists():
             with (path / 'vocab' / 'tag_map.json').open() as file_:
                 tag_map = json.load(file_)
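Vocab.load() no longer takes a vectors argument; passing one raises AttributeError, and vectors are installed on the vocab after loading. A sketch of the before/after calling pattern, with a hypothetical model directory; extra keyword arguments the real defaults pass are omitted here:

    from pathlib import Path
    from spacy.vocab import Vocab

    data_dir = Path('/data/en')   # hypothetical installed model directory

    # Old style, now rejected with AttributeError:
    #   vocab = Vocab.load(data_dir, vectors=True)

    # New style: load first, then install vectors explicitly.
    vocab = Vocab.load(data_dir)
    vec_len = vocab.load_vectors_from_bin_loc(data_dir / 'vocab' / 'vec.bin')
    print(vec_len, vocab.vectors_length)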
@@ -73,15 +77,6 @@ cdef class Vocab:
         with (path / 'vocab' / 'strings.json').open() as file_:
             self.strings.load(file_)
         self.load_lexemes(path / 'vocab' / 'lexemes.bin')
-
-        if vectors is True:
-            vec_path = path / 'vocab' / 'vec.bin'
-            if vec_path.exists():
-                vectors = lambda self_: self_.load_vectors_from_bin_loc(vec_path)
-            else:
-                vectors = lambda self_: 0
-        if vectors:
-            self.vectors_length = vectors(self)
         return self
 
     def __init__(self, lex_attr_getters=None, tag_map=None, lemmatizer=None,
@@ -387,10 +382,11 @@ cdef class Vocab:
                     vec_len, len(pieces))
             orth = self.strings[word_str]
             lexeme = <LexemeC*><void*>self.get_by_orth(self.mem, orth)
-            lexeme.vector = <float*>self.mem.alloc(self.vectors_length, sizeof(float))
+            lexeme.vector = <float*>self.mem.alloc(vec_len, sizeof(float))
             for i, val_str in enumerate(pieces):
                 lexeme.vector[i] = float(val_str)
+        self.vectors_length = vec_len
         return vec_len
 
     def load_vectors_from_bin_loc(self, loc):
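The bug fixed here: each lexeme's buffer was sized from self.vectors_length, which is still 0 while the first text-format vector file is being read, instead of from vec_len, the width measured from the file itself. The fix allocates vec_len floats per lexeme and records vectors_length once, after the file is read. A small standalone illustration of how vec_len falls out of a GloVe-style text line:

    # One row of a GloVe-style text vectors file: token, then its components.
    line = u'fox 0.12 -0.03 0.45 0.27'

    pieces = line.split()
    word_str = pieces.pop(0)   # 'fox'
    vec_len = len(pieces)      # 4; every later row must have the same width

    vector = [float(val_str) for val_str in pieces]
    print(word_str, vec_len, vector)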
@@ -438,6 +434,7 @@ cdef class Vocab:
                 lex.l2_norm = math.sqrt(lex.l2_norm)
             else:
                 lex.vector = EMPTY_VEC
+        self.vectors_length = vec_len
         return vec_len
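With the binary loader also recording vectors_length, a freshly loaded pipeline can be checked for working vectors, which is essentially what Issue #541 reported as broken. A hedged verification sketch, assuming the English model and GloVe pack are installed:

    import spacy

    nlp = spacy.load('en')
    print(nlp.vocab.vectors_length)    # > 0 once vectors are installed

    doc = nlp(u'dog cat banana')
    print(doc[0].similarity(doc[1]))   # meaningful only when vectors are present
    print(doc[0].vector.shape)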