mirror of https://github.com/explosion/spaCy.git
add custom download tool (uget), replace wget with uget
This commit is contained in:
parent
9839cd2c0b
commit
bfde91fa49
|
@ -6,7 +6,6 @@ thinc == 3.3
|
||||||
murmurhash == 0.24
|
murmurhash == 0.24
|
||||||
text-unidecode
|
text-unidecode
|
||||||
numpy
|
numpy
|
||||||
wget
|
|
||||||
plac
|
plac
|
||||||
six
|
six
|
||||||
ujson
|
ujson
|
||||||
|
|
2
setup.py
2
setup.py
|
@ -162,7 +162,7 @@ def run_setup(exts):
|
||||||
ext_modules=exts,
|
ext_modules=exts,
|
||||||
license="MIT",
|
license="MIT",
|
||||||
install_requires=['numpy', 'murmurhash', 'cymem >= 1.11', 'preshed >= 0.42',
|
install_requires=['numpy', 'murmurhash', 'cymem >= 1.11', 'preshed >= 0.42',
|
||||||
'thinc == 3.3', "text_unidecode", 'wget', 'plac', 'six',
|
'thinc == 3.3', "text_unidecode", 'plac', 'six',
|
||||||
'ujson', 'cloudpickle'],
|
'ujson', 'cloudpickle'],
|
||||||
setup_requires=["headers_workaround"],
|
setup_requires=["headers_workaround"],
|
||||||
cmdclass = {'build_ext': build_ext_subclass },
|
cmdclass = {'build_ext': build_ext_subclass },
|
||||||
|
|
|
@ -1,9 +1,10 @@
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
from os import path
|
from os import path
|
||||||
|
import sys
|
||||||
import os
|
import os
|
||||||
import tarfile
|
import tarfile
|
||||||
import shutil
|
import shutil
|
||||||
import wget
|
import uget
|
||||||
import plac
|
import plac
|
||||||
|
|
||||||
# TODO: Read this from the same source as the setup
|
# TODO: Read this from the same source as the setup
|
||||||
|
@ -13,39 +14,45 @@ AWS_STORE = 'https://s3-us-west-1.amazonaws.com/media.spacynlp.com'
|
||||||
|
|
||||||
ALL_DATA_DIR_URL = '%s/en_data_all-%s.tgz' % (AWS_STORE, VERSION)
|
ALL_DATA_DIR_URL = '%s/en_data_all-%s.tgz' % (AWS_STORE, VERSION)
|
||||||
|
|
||||||
DEST_DIR = path.join(path.dirname(__file__), 'data')
|
DEST_DIR = path.join(path.dirname(path.abspath(__file__)), 'data')
|
||||||
|
|
||||||
def download_file(url, out):
|
|
||||||
wget.download(url, out=out)
|
def download_file(url, dest_dir):
|
||||||
return url.rsplit('/', 1)[1]
|
return uget.download(url, dest_dir, console=sys.stdout)
|
||||||
|
|
||||||
|
|
||||||
def install_data(url, dest_dir):
|
def install_data(url, dest_dir):
|
||||||
filename = download_file(url, dest_dir)
|
filename = download_file(url, dest_dir)
|
||||||
t = tarfile.open(path.join(dest_dir, filename))
|
t = tarfile.open(filename)
|
||||||
t.extractall(dest_dir)
|
t.extractall(dest_dir)
|
||||||
|
|
||||||
|
|
||||||
def install_parser_model(url, dest_dir):
|
def install_parser_model(url, dest_dir):
|
||||||
filename = download_file(url, dest_dir)
|
filename = download_file(url, dest_dir)
|
||||||
t = tarfile.open(path.join(dest_dir, filename), mode=":gz")
|
t = tarfile.open(filename, mode=":gz")
|
||||||
t.extractall(path.dirname(__file__))
|
t.extractall(dest_dir)
|
||||||
|
|
||||||
|
|
||||||
def install_dep_vectors(url, dest_dir):
|
def install_dep_vectors(url, dest_dir):
|
||||||
if not os.path.exists(dest_dir):
|
download_file(url, dest_dir)
|
||||||
os.mkdir(dest_dir)
|
|
||||||
|
|
||||||
filename = download_file(url, dest_dir)
|
|
||||||
|
|
||||||
|
|
||||||
def main(data_size='all'):
|
@plac.annotations(
|
||||||
|
force=("Force overwrite", "flag", "f", bool),
|
||||||
|
)
|
||||||
|
def main(data_size='all', force=False):
|
||||||
if data_size == 'all':
|
if data_size == 'all':
|
||||||
data_url = ALL_DATA_DIR_URL
|
data_url = ALL_DATA_DIR_URL
|
||||||
elif data_size == 'small':
|
elif data_size == 'small':
|
||||||
data_url = SM_DATA_DIR_URL
|
data_url = SM_DATA_DIR_URL
|
||||||
if path.exists(DEST_DIR):
|
|
||||||
|
if force and path.exists(DEST_DIR):
|
||||||
shutil.rmtree(DEST_DIR)
|
shutil.rmtree(DEST_DIR)
|
||||||
install_data(data_url, path.dirname(DEST_DIR))
|
|
||||||
|
if not os.path.exists(DEST_DIR):
|
||||||
|
os.makedirs(DEST_DIR)
|
||||||
|
|
||||||
|
install_data(data_url, DEST_DIR)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -0,0 +1,246 @@
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import io
|
||||||
|
import math
|
||||||
|
import re
|
||||||
|
|
||||||
|
try:
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
from urllib.request import urlopen, Request
|
||||||
|
from urllib.error import HTTPError
|
||||||
|
except ImportError:
|
||||||
|
from urllib2 import urlopen, urlparse, Request, HTTPError
|
||||||
|
|
||||||
|
|
||||||
|
class UnknownContentLengthException(Exception): pass
|
||||||
|
class InvalidChecksumException(Exception): pass
|
||||||
|
class UnsupportedHTTPCodeException(Exception): pass
|
||||||
|
class InvalidOffsetException(Exception): pass
|
||||||
|
class MissingChecksumHeader(Exception): pass
|
||||||
|
|
||||||
|
|
||||||
|
CHUNK_SIZE = 16 * 1024
|
||||||
|
|
||||||
|
|
||||||
|
class RateSampler(object):
|
||||||
|
def __init__(self, period=1):
|
||||||
|
self.rate = None
|
||||||
|
self.reset = True
|
||||||
|
self.period = period
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
if self.reset:
|
||||||
|
self.reset = False
|
||||||
|
self.start = time.time()
|
||||||
|
self.counter = 0
|
||||||
|
|
||||||
|
def __exit__(self, type, value, traceback):
|
||||||
|
elapsed = time.time() - self.start
|
||||||
|
if elapsed >= self.period:
|
||||||
|
self.reset = True
|
||||||
|
self.rate = float(self.counter) / elapsed
|
||||||
|
|
||||||
|
def update(self, value):
|
||||||
|
self.counter += value
|
||||||
|
|
||||||
|
def format(self, unit="MB"):
|
||||||
|
if self.rate is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
divisor = {'MB': 1048576, 'kB': 1024}
|
||||||
|
return "%0.2f%s/s" % (self.rate / divisor[unit], unit)
|
||||||
|
|
||||||
|
|
||||||
|
class TimeEstimator(object):
|
||||||
|
def __init__(self, cooldown=1):
|
||||||
|
self.cooldown = cooldown
|
||||||
|
self.start = time.time()
|
||||||
|
self.time_left = None
|
||||||
|
|
||||||
|
def update(self, bytes_read, total_size):
|
||||||
|
elapsed = time.time() - self.start
|
||||||
|
if elapsed > self.cooldown:
|
||||||
|
self.time_left = math.ceil(elapsed * total_size /
|
||||||
|
bytes_read - elapsed)
|
||||||
|
|
||||||
|
def format(self):
|
||||||
|
if self.time_left is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
res = "eta "
|
||||||
|
if self.time_left / 60 >= 1:
|
||||||
|
res += "%dm " % (self.time_left / 60)
|
||||||
|
return res + "%ds" % (self.time_left % 60)
|
||||||
|
|
||||||
|
|
||||||
|
def format_bytes_read(bytes_read, unit="MB"):
|
||||||
|
divisor = {'MB': 1048576, 'kB': 1024}
|
||||||
|
return "%0.2f%s" % (float(bytes_read) / divisor[unit], unit)
|
||||||
|
|
||||||
|
|
||||||
|
def format_percent(bytes_read, total_size):
|
||||||
|
percent = round(bytes_read * 100.0 / total_size, 2)
|
||||||
|
return "%0.2f%%" % percent
|
||||||
|
|
||||||
|
|
||||||
|
def get_content_range(response):
|
||||||
|
content_range = response.headers.get('Content-Range', "").strip()
|
||||||
|
if content_range:
|
||||||
|
m = re.match(r"bytes (\d+)-(\d+)/(\d+)", content_range)
|
||||||
|
if m:
|
||||||
|
return [int(v) for v in m.groups()]
|
||||||
|
|
||||||
|
|
||||||
|
def get_content_length(response):
|
||||||
|
if 'Content-Length' not in response.headers:
|
||||||
|
raise UnknownContentLengthException
|
||||||
|
return int(response.headers.get('Content-Length').strip())
|
||||||
|
|
||||||
|
|
||||||
|
def get_url_meta(url, checksum_header=None):
|
||||||
|
class HeadRequest(Request):
|
||||||
|
def get_method(self):
|
||||||
|
return "HEAD"
|
||||||
|
|
||||||
|
r = urlopen(HeadRequest(url))
|
||||||
|
res = {'size': get_content_length(r)}
|
||||||
|
|
||||||
|
if checksum_header:
|
||||||
|
value = r.headers.get(checksum_header)
|
||||||
|
if value:
|
||||||
|
res['checksum'] = value
|
||||||
|
|
||||||
|
r.close()
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def progress(console, bytes_read, total_size, transfer_rate, eta):
|
||||||
|
fields = [
|
||||||
|
format_bytes_read(bytes_read),
|
||||||
|
format_percent(bytes_read, total_size),
|
||||||
|
transfer_rate.format(),
|
||||||
|
eta.format(),
|
||||||
|
" " * 10,
|
||||||
|
]
|
||||||
|
console.write("Downloaded %s\r" % " ".join(filter(None, fields)))
|
||||||
|
console.flush()
|
||||||
|
|
||||||
|
|
||||||
|
def read_request(request, offset=0, console=None,
|
||||||
|
progress_func=None, write_func=None):
|
||||||
|
# support partial downloads
|
||||||
|
if offset > 0:
|
||||||
|
request.add_header('Range', "bytes=%s-" % offset)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = urlopen(request)
|
||||||
|
except HTTPError as e:
|
||||||
|
if e.code == 416: # Requested Range Not Satisfiable
|
||||||
|
raise InvalidOffsetException
|
||||||
|
|
||||||
|
# TODO add http error handling here
|
||||||
|
raise UnsupportedHTTPCodeException(e.code)
|
||||||
|
|
||||||
|
total_size = get_content_length(response) + offset
|
||||||
|
bytes_read = offset
|
||||||
|
|
||||||
|
# sanity checks
|
||||||
|
if response.code == 200: # OK
|
||||||
|
assert offset == 0
|
||||||
|
elif response.code == 206: # Partial content
|
||||||
|
range_start, range_end, range_total = get_content_range(response)
|
||||||
|
assert range_start == offset
|
||||||
|
assert range_total == total_size
|
||||||
|
assert range_end + 1 - range_start == total_size - bytes_read
|
||||||
|
else:
|
||||||
|
raise UnsupportedHTTPCodeException(response.code)
|
||||||
|
|
||||||
|
eta = TimeEstimator()
|
||||||
|
transfer_rate = RateSampler()
|
||||||
|
|
||||||
|
if console:
|
||||||
|
if offset > 0:
|
||||||
|
console.write("Continue downloading...\n")
|
||||||
|
else:
|
||||||
|
console.write("Downloading...\n")
|
||||||
|
|
||||||
|
while True:
|
||||||
|
with transfer_rate:
|
||||||
|
chunk = response.read(CHUNK_SIZE)
|
||||||
|
if not chunk:
|
||||||
|
if progress_func and console:
|
||||||
|
console.write('\n')
|
||||||
|
break
|
||||||
|
|
||||||
|
bytes_read += len(chunk)
|
||||||
|
|
||||||
|
transfer_rate.update(len(chunk))
|
||||||
|
eta.update(bytes_read - offset, total_size - offset)
|
||||||
|
|
||||||
|
if progress_func and console:
|
||||||
|
progress_func(console, bytes_read, total_size, transfer_rate, eta)
|
||||||
|
|
||||||
|
if write_func:
|
||||||
|
write_func(chunk)
|
||||||
|
|
||||||
|
response.close()
|
||||||
|
assert bytes_read == total_size
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
def download(url, path=".",
|
||||||
|
checksum=None, checksum_header=None,
|
||||||
|
headers=None, console=None):
|
||||||
|
|
||||||
|
if os.path.isdir(path):
|
||||||
|
path = os.path.join(path, url.rsplit('/', 1)[1])
|
||||||
|
path = os.path.abspath(path)
|
||||||
|
|
||||||
|
with io.open(path, "a+b") as f:
|
||||||
|
size = f.tell()
|
||||||
|
|
||||||
|
# update checksum of partially downloaded file
|
||||||
|
if checksum:
|
||||||
|
f.seek(0, os.SEEK_SET)
|
||||||
|
for chunk in iter(lambda: f.read(CHUNK_SIZE), b""):
|
||||||
|
checksum.update(chunk)
|
||||||
|
|
||||||
|
def write(chunk):
|
||||||
|
if checksum:
|
||||||
|
checksum.update(chunk)
|
||||||
|
f.write(chunk)
|
||||||
|
|
||||||
|
request = Request(url)
|
||||||
|
|
||||||
|
# request headers
|
||||||
|
if headers:
|
||||||
|
for key, value in headers.items():
|
||||||
|
request.add_header(key, value)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = read_request(request,
|
||||||
|
offset=size,
|
||||||
|
console=console,
|
||||||
|
progress_func=progress,
|
||||||
|
write_func=write)
|
||||||
|
except InvalidOffsetException:
|
||||||
|
response = None
|
||||||
|
|
||||||
|
if checksum:
|
||||||
|
if response:
|
||||||
|
origin_checksum = response.headers.get(checksum_header)
|
||||||
|
else:
|
||||||
|
# check whether file is already complete
|
||||||
|
meta = get_url_meta(url, checksum_header)
|
||||||
|
origin_checksum = meta.get('checksum')
|
||||||
|
|
||||||
|
if origin_checksum is None:
|
||||||
|
raise MissingChecksumHeader
|
||||||
|
|
||||||
|
if checksum.hexdigest() != origin_checksum:
|
||||||
|
raise InvalidChecksumException
|
||||||
|
|
||||||
|
if console:
|
||||||
|
console.write("checksum/sha256 OK\n")
|
||||||
|
|
||||||
|
return path
|
Loading…
Reference in New Issue