mirror of https://github.com/explosion/spaCy.git
Update spacy.load() helper functions
This commit is contained in:
parent
152dc018a6
commit
dd6dc4c120
|
@ -99,28 +99,46 @@ def load_model(name, **overrides):
|
|||
if not data_path or not data_path.exists():
|
||||
raise IOError("Can't find spaCy data path: %s" % path2str(data_path))
|
||||
if isinstance(name, basestring_):
|
||||
if (data_path / name).exists(): # in data dir or shortcut
|
||||
spec = importlib.util.spec_from_file_location('model', data_path / name)
|
||||
cls = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(cls)
|
||||
return cls.load(**overrides)
|
||||
if name in set([d.name for d in data_path.iterdir()]): # in data dir / shortcut
|
||||
return load_model_from_link(name, **overrides)
|
||||
if is_package(name): # installed as package
|
||||
cls = importlib.import_module(name)
|
||||
return cls.load(**overrides)
|
||||
return load_model_from_package(name, **overrides)
|
||||
if Path(name).exists(): # path to model data directory
|
||||
model_path = Path(name)
|
||||
meta = get_package_meta(model_path)
|
||||
cls = get_lang_class(meta['lang'])
|
||||
nlp = cls(pipeline=meta.get('pipeline', True), meta=meta)
|
||||
return nlp.from_disk(model_path, **overrides)
|
||||
return load_model_from_path(Path(name), **overrides)
|
||||
elif hasattr(name, 'exists'): # Path or Path-like to model data
|
||||
meta = get_package_meta(name)
|
||||
cls = get_lang_class(meta['lang'])
|
||||
nlp = cls(pipeline=meta.get('pipeline', True), meta=meta)
|
||||
return nlp.from_disk(name, **overrides)
|
||||
return load_model_from_path(name, **overrides)
|
||||
raise IOError("Can't find model '%s'" % name)
|
||||
|
||||
|
||||
def load_model_from_link(name, **overrides):
|
||||
"""Load a model from a shortcut link, or directory in spaCy data path."""
|
||||
spec = importlib.util.spec_from_file_location('model', get_data_path() / name)
|
||||
try:
|
||||
cls = importlib.util.module_from_spec(spec)
|
||||
except AttributeError:
|
||||
raise IOError(
|
||||
"Cant' load '%s'. If you're using a shortcut link, make sure it "
|
||||
"points to a valid model package (not just a data directory)." % name)
|
||||
spec.loader.exec_module(cls)
|
||||
return cls.load(**overrides)
|
||||
|
||||
|
||||
def load_model_from_package(name, **overrides):
|
||||
"""Load a model from an installed package."""
|
||||
cls = importlib.import_module(name)
|
||||
return cls.load(**overrides)
|
||||
|
||||
|
||||
def load_model_from_path(model_path, meta=False, **overrides):
|
||||
"""Load a model from a data directory path. Creates Language class with
|
||||
pipeline from meta.json and then calls from_disk() with path."""
|
||||
if not meta:
|
||||
meta = get_model_meta(model_path)
|
||||
cls = get_lang_class(meta['lang'])
|
||||
nlp = cls(pipeline=meta.get('pipeline', True), meta=meta, **overrides)
|
||||
return nlp.from_disk(model_path)
|
||||
|
||||
|
||||
def load_model_from_init_py(init_file, **overrides):
|
||||
"""Helper function to use in the `load()` method of a model package's
|
||||
__init__.py.
|
||||
|
@ -135,9 +153,7 @@ def load_model_from_init_py(init_file, **overrides):
|
|||
data_path = model_path / data_dir
|
||||
if not model_path.exists():
|
||||
raise ValueError("Can't find model directory: %s" % path2str(data_path))
|
||||
cls = get_lang_class(meta['lang'])
|
||||
nlp = cls(pipeline=meta.get('pipeline', True), meta=meta)
|
||||
return nlp.from_disk(data_path, **overrides)
|
||||
return load_model_from_path(data_path, meta, **overrides)
|
||||
|
||||
|
||||
def get_model_meta(path):
|
||||
|
|
Loading…
Reference in New Issue