From 3871007c72003258b5312b291848bcb4767404b3 Mon Sep 17 00:00:00 2001 From: Mark Amery Date: Sun, 20 Nov 2016 15:48:04 +0000 Subject: [PATCH] Let --data-path be specified when running download.py scripts Resolves https://github.com/explosion/spaCy/issues/637 --- spacy/de/download.py | 5 +++-- spacy/download.py | 18 +++++++++++++----- spacy/en/download.py | 7 ++++--- 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/spacy/de/download.py b/spacy/de/download.py index ba57c1d31..4f02f0474 100644 --- a/spacy/de/download.py +++ b/spacy/de/download.py @@ -4,9 +4,10 @@ from ..download import download @plac.annotations( force=("Force overwrite", "flag", "f", bool), + data_path=("Path to download model", "option", "d", str) ) -def main(data_size='all', force=False): - download('de', force) +def main(data_size='all', force=False, data_path=None): + download('de', force=force, data_path=data_path) if __name__ == '__main__': diff --git a/spacy/download.py b/spacy/download.py index 779a15f1e..9b2d65ffc 100644 --- a/spacy/download.py +++ b/spacy/download.py @@ -10,10 +10,19 @@ from . import about from . import util -def download(lang, force=False, fail_on_exist=True): +def download(lang, force=False, fail_on_exist=True, data_path=None): + if not data_path: + data_path = util.get_data_path() + + # spaCy uses pathlib, and util.get_data_path returns a pathlib.Path object, + # but sputnik (which we're using below) doesn't use pathlib and requires + # its data_path parameters to be strings, so we coerce the data_path to a + # str here. + data_path = str(data_path) + try: pkg = sputnik.package(about.__title__, about.__version__, - about.__models__.get(lang, lang)) + about.__models__.get(lang, lang), data_path) if force: shutil.rmtree(pkg.path) elif fail_on_exist: @@ -24,15 +33,14 @@ def download(lang, force=False, fail_on_exist=True): pass package = sputnik.install(about.__title__, about.__version__, - about.__models__.get(lang, lang)) + about.__models__.get(lang, lang), data_path) try: sputnik.package(about.__title__, about.__version__, - about.__models__.get(lang, lang)) + about.__models__.get(lang, lang), data_path) except (PackageNotFoundException, CompatiblePackageNotFoundException): print("Model failed to install. Please run 'python -m " "spacy.%s.download --force'." % lang, file=sys.stderr) sys.exit(1) - data_path = util.get_data_path() print("Model successfully installed to %s" % data_path, file=sys.stderr) diff --git a/spacy/en/download.py b/spacy/en/download.py index bf3815769..7a2d58234 100644 --- a/spacy/en/download.py +++ b/spacy/en/download.py @@ -7,17 +7,18 @@ from .. import about @plac.annotations( force=("Force overwrite", "flag", "f", bool), + data_path=("Path to download model", "option", "d", str) ) -def main(data_size='all', force=False): +def main(data_size='all', force=False, data_path=None): if force: sputnik.purge(about.__title__, about.__version__) if data_size in ('all', 'parser'): print("Downloading parsing model") - download('en', False) + download('en', force=False, data_path=data_path) if data_size in ('all', 'glove'): print("Downloading GloVe vectors") - download('en_glove_cc_300_1m_vectors', False) + download('en_glove_cc_300_1m_vectors', force=False, data_path=data_path) if __name__ == '__main__':