diff --git a/spacy/__init__.py b/spacy/__init__.py index 556027a42..191d5970c 100644 --- a/spacy/__init__.py +++ b/spacy/__init__.py @@ -2,6 +2,7 @@ from . import util from .en import English -def load(name, via=None): +def load(name, via=None, vectors_name=None): package = util.get_package_by_name(name, via=via) - return English(package=package) + vectors_package = util.get_package_by_name(vectors_name, via=via) + return English(package=package, vectors_package=vectors_package) diff --git a/spacy/language.py b/spacy/language.py index ae8aa4560..157f7d040 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -153,7 +153,7 @@ class Language(object): return {0: {'PER': True, 'LOC': True, 'ORG': True, 'MISC': True}} @classmethod - def default_vocab(cls, package, get_lex_attr=None): + def default_vocab(cls, package, get_lex_attr=None, vectors_package=None): if get_lex_attr is None: if package.has_file('vocab', 'oov_prob'): with package.open(('vocab', 'oov_prob')) as file_: @@ -162,7 +162,8 @@ class Language(object): else: get_lex_attr = cls.default_lex_attrs() if hasattr(package, 'dir_path'): - return Vocab.from_package(package, get_lex_attr=get_lex_attr) + return Vocab.from_package(package, get_lex_attr=get_lex_attr, + vectors_package=vectors_package) else: return Vocab.load(package, get_lex_attr) @@ -198,7 +199,8 @@ class Language(object): matcher=None, serializer=None, load_vectors=True, - package=None): + package=None, + vectors_package=None): """ a model can be specified: @@ -228,7 +230,7 @@ class Language(object): warn("load_vectors is deprecated", DeprecationWarning) if vocab in (None, True): - vocab = self.default_vocab(package) + vocab = self.default_vocab(package, vectors_package=vectors_package) self.vocab = vocab if tokenizer in (None, True): tokenizer = Tokenizer.from_package(package, self.vocab) diff --git a/spacy/vocab.pyx b/spacy/vocab.pyx index a0a07f305..de4909f30 100644 --- a/spacy/vocab.pyx +++ b/spacy/vocab.pyx @@ -52,7 +52,7 @@ cdef class Vocab: return cls.from_package(get_package(data_dir), get_lex_attr=get_lex_attr) @classmethod - def from_package(cls, package, get_lex_attr=None): + def from_package(cls, package, get_lex_attr=None, vectors_package=None): tag_map = package.load_json(('vocab', 'tag_map.json'), default={}) lemmatizer = Lemmatizer.from_package(package) @@ -66,7 +66,10 @@ cdef class Vocab: self.strings.load(file_) self.load_lexemes(package.file_path('vocab', 'lexemes.bin')) - if package.has_file('vocab', 'vec.bin'): + if vectors_package and vectors_package.has_file('vocab', 'vec.bin'): + self.vectors_length = self.load_vectors_from_bin_loc( + vectors_package.file_path('vocab', 'vec.bin')) + elif package.has_file('vocab', 'vec.bin'): self.vectors_length = self.load_vectors_from_bin_loc( package.file_path('vocab', 'vec.bin')) return self