From 5ec32f5d972674fdd1c6c163b2476992e036043e Mon Sep 17 00:00:00 2001
From: Matthew Honnibal
Date: Thu, 20 Oct 2016 18:27:48 +0200
Subject: [PATCH] Fix loading of GloVe vectors, to address Issue #541

---
 spacy/__init__.py | 21 +++++++++++++--------
 spacy/language.py | 11 ++++++++++-
 spacy/vocab.pyx   | 19 ++++++++-----------
 3 files changed, 31 insertions(+), 20 deletions(-)

diff --git a/spacy/__init__.py b/spacy/__init__.py
index 662f7f82c..a99d1cd2f 100644
--- a/spacy/__init__.py
+++ b/spacy/__init__.py
@@ -13,7 +13,6 @@ except NameError:
     basestring = str
 
 
-
 set_lang_class(en.English.lang, en.English)
 set_lang_class(de.German.lang, de.German)
 set_lang_class(zh.Chinese.lang, zh.Chinese)
@@ -21,13 +20,19 @@ set_lang_class(zh.Chinese.lang, zh.Chinese)
 
 def load(name, **overrides):
     target_name, target_version = util.split_data_name(name)
-    path = overrides.get('path', util.get_data_path())
-    path = util.match_best_version(target_name, target_version, path)
+    data_path = overrides.get('path', util.get_data_path())
+    if target_name == 'en' and 'add_vectors' not in overrides:
+        if 'vectors' in overrides:
+            vec_path = util.match_best_version(overrides['vectors'], None, data_path)
+            if vec_path is None:
+                raise IOError(
+                    'Could not load data pack %s from %s' % (overrides['vectors'], data_path))
 
-    if isinstance(overrides.get('vectors'), basestring):
-        vectors_path = util.match_best_version(overrides.get('vectors'), None, path)
-        overrides['vectors'] = lambda nlp: nlp.vocab.load_vectors_from_bin_loc(
-            vectors_path / 'vocab' / 'vec.bin')
-
+        else:
+            vec_path = util.match_best_version('en_glove_cc_300_1m_vectors', None, data_path)
+        if vec_path is not None:
+            vec_path = vec_path / 'vocab' / 'vec.bin'
+            overrides['add_vectors'] = lambda vocab: vocab.load_vectors_from_bin_loc(vec_path)
+    path = util.match_best_version(target_name, target_version, data_path)
     cls = get_lang_class(target_name)
     return cls(path=path, **overrides)
diff --git a/spacy/language.py b/spacy/language.py
index 721d86bec..028c0860f 100644
--- a/spacy/language.py
+++ b/spacy/language.py
@@ -53,7 +53,11 @@ class BaseDefaults(object):
 
     @classmethod
     def add_vectors(cls, nlp=None):
-        return True
+        if nlp is None or nlp.path is None:
+            return False
+        else:
+            vec_path = nlp.path / 'vocab' / 'vec.bin'
+            return lambda vocab: vocab.load_vectors_from_bin_loc(vec_path)
 
     @classmethod
     def create_tokenizer(cls, nlp=None):
@@ -246,6 +250,11 @@ class Language(object):
         self.vocab = self.Defaults.create_vocab(self) \
                      if 'vocab' not in overrides \
                      else overrides['vocab']
+        add_vectors = self.Defaults.add_vectors(self) \
+                      if 'add_vectors' not in overrides \
+                      else overrides['add_vectors']
+        if add_vectors:
+            add_vectors(self.vocab)
         self.tokenizer = self.Defaults.create_tokenizer(self) \
                          if 'tokenizer' not in overrides \
                          else overrides['tokenizer']
diff --git a/spacy/vocab.pyx b/spacy/vocab.pyx
index 38582eacb..493338079 100644
--- a/spacy/vocab.pyx
+++ b/spacy/vocab.pyx
@@ -49,9 +49,13 @@ cdef class Vocab:
     '''A map container for a language's LexemeC structs.
     '''
     @classmethod
-    def load(cls, path, lex_attr_getters=None, vectors=True, lemmatizer=True,
+    def load(cls, path, lex_attr_getters=None, lemmatizer=True,
              tag_map=True, serializer_freqs=True, oov_prob=True, **deprecated_kwargs):
         util.check_renamed_kwargs({'get_lex_attr': 'lex_attr_getters'}, deprecated_kwargs)
+        if 'vectors' in deprecated_kwargs:
+            raise AttributeError(
+                "vectors argument to Vocab.load() deprecated. "
+                "Install vectors after loading.")
         if tag_map is True and (path / 'vocab' / 'tag_map.json').exists():
             with (path / 'vocab' / 'tag_map.json').open() as file_:
                 tag_map = json.load(file_)
@@ -73,15 +77,6 @@ cdef class Vocab:
         with (path / 'vocab' / 'strings.json').open() as file_:
             self.strings.load(file_)
         self.load_lexemes(path / 'vocab' / 'lexemes.bin')
-
-        if vectors is True:
-            vec_path = path / 'vocab' / 'vec.bin'
-            if vec_path.exists():
-                vectors = lambda self_: self_.load_vectors_from_bin_loc(vec_path)
-            else:
-                vectors = lambda self_: 0
-        if vectors:
-            self.vectors_length = vectors(self)
         return self
 
     def __init__(self, lex_attr_getters=None, tag_map=None, lemmatizer=None,
@@ -387,10 +382,11 @@ cdef class Vocab:
                                        vec_len, len(pieces))
             orth = self.strings[word_str]
             lexeme = <LexemeC*><void*>self.get_by_orth(self.mem, orth)
-            lexeme.vector = <float*>self.mem.alloc(self.vectors_length, sizeof(float))
+            lexeme.vector = <float*>self.mem.alloc(vec_len, sizeof(float))
             for i, val_str in enumerate(pieces):
                 lexeme.vector[i] = float(val_str)
 
+        self.vectors_length = vec_len
         return vec_len
 
     def load_vectors_from_bin_loc(self, loc):
@@ -438,6 +434,7 @@ cdef class Vocab:
                 lex.l2_norm = math.sqrt(lex.l2_norm)
             else:
                 lex.vector = EMPTY_VEC
+        self.vectors_length = vec_len
         return vec_len