From bfde91fa49c5e91686081a6fb6984de537138ab6 Mon Sep 17 00:00:00 2001 From: Henning Peters Date: Sun, 18 Oct 2015 12:35:04 +0200 Subject: [PATCH] add custom download tool (uget), replace wget with uget --- requirements.txt | 1 - setup.py | 2 +- spacy/en/download.py | 37 ++++--- spacy/en/uget.py | 246 +++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 269 insertions(+), 17 deletions(-) create mode 100644 spacy/en/uget.py diff --git a/requirements.txt b/requirements.txt index dcf2c6dea..146af7e85 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,6 @@ thinc == 3.3 murmurhash == 0.24 text-unidecode numpy -wget plac six ujson diff --git a/setup.py b/setup.py index 56b403b95..8070193c2 100644 --- a/setup.py +++ b/setup.py @@ -162,7 +162,7 @@ def run_setup(exts): ext_modules=exts, license="MIT", install_requires=['numpy', 'murmurhash', 'cymem >= 1.11', 'preshed >= 0.42', - 'thinc == 3.3', "text_unidecode", 'wget', 'plac', 'six', + 'thinc == 3.3', "text_unidecode", 'plac', 'six', 'ujson', 'cloudpickle'], setup_requires=["headers_workaround"], cmdclass = {'build_ext': build_ext_subclass }, diff --git a/spacy/en/download.py b/spacy/en/download.py index 91f31565b..cfceb6590 100644 --- a/spacy/en/download.py +++ b/spacy/en/download.py @@ -1,9 +1,10 @@ from __future__ import print_function from os import path +import sys import os import tarfile import shutil -import wget +import uget import plac # TODO: Read this from the same source as the setup @@ -13,39 +14,45 @@ AWS_STORE = 'https://s3-us-west-1.amazonaws.com/media.spacynlp.com' ALL_DATA_DIR_URL = '%s/en_data_all-%s.tgz' % (AWS_STORE, VERSION) -DEST_DIR = path.join(path.dirname(__file__), 'data') +DEST_DIR = path.join(path.dirname(path.abspath(__file__)), 'data') -def download_file(url, out): - wget.download(url, out=out) - return url.rsplit('/', 1)[1] + +def download_file(url, dest_dir): + return uget.download(url, dest_dir, console=sys.stdout) def install_data(url, dest_dir): filename = download_file(url, dest_dir) - t = tarfile.open(path.join(dest_dir, filename)) + t = tarfile.open(filename) t.extractall(dest_dir) + def install_parser_model(url, dest_dir): filename = download_file(url, dest_dir) - t = tarfile.open(path.join(dest_dir, filename), mode=":gz") - t.extractall(path.dirname(__file__)) + t = tarfile.open(filename, mode=":gz") + t.extractall(dest_dir) def install_dep_vectors(url, dest_dir): - if not os.path.exists(dest_dir): - os.mkdir(dest_dir) - - filename = download_file(url, dest_dir) + download_file(url, dest_dir) -def main(data_size='all'): +@plac.annotations( + force=("Force overwrite", "flag", "f", bool), +) +def main(data_size='all', force=False): if data_size == 'all': data_url = ALL_DATA_DIR_URL elif data_size == 'small': data_url = SM_DATA_DIR_URL - if path.exists(DEST_DIR): + + if force and path.exists(DEST_DIR): shutil.rmtree(DEST_DIR) - install_data(data_url, path.dirname(DEST_DIR)) + + if not os.path.exists(DEST_DIR): + os.makedirs(DEST_DIR) + + install_data(data_url, DEST_DIR) if __name__ == '__main__': diff --git a/spacy/en/uget.py b/spacy/en/uget.py new file mode 100644 index 000000000..3cdf6c8d6 --- /dev/null +++ b/spacy/en/uget.py @@ -0,0 +1,246 @@ +import os +import time +import io +import math +import re + +try: + from urllib.parse import urlparse + from urllib.request import urlopen, Request + from urllib.error import HTTPError +except ImportError: + from urllib2 import urlopen, urlparse, Request, HTTPError + + +class UnknownContentLengthException(Exception): pass +class InvalidChecksumException(Exception): pass +class UnsupportedHTTPCodeException(Exception): pass +class InvalidOffsetException(Exception): pass +class MissingChecksumHeader(Exception): pass + + +CHUNK_SIZE = 16 * 1024 + + +class RateSampler(object): + def __init__(self, period=1): + self.rate = None + self.reset = True + self.period = period + + def __enter__(self): + if self.reset: + self.reset = False + self.start = time.time() + self.counter = 0 + + def __exit__(self, type, value, traceback): + elapsed = time.time() - self.start + if elapsed >= self.period: + self.reset = True + self.rate = float(self.counter) / elapsed + + def update(self, value): + self.counter += value + + def format(self, unit="MB"): + if self.rate is None: + return None + + divisor = {'MB': 1048576, 'kB': 1024} + return "%0.2f%s/s" % (self.rate / divisor[unit], unit) + + +class TimeEstimator(object): + def __init__(self, cooldown=1): + self.cooldown = cooldown + self.start = time.time() + self.time_left = None + + def update(self, bytes_read, total_size): + elapsed = time.time() - self.start + if elapsed > self.cooldown: + self.time_left = math.ceil(elapsed * total_size / + bytes_read - elapsed) + + def format(self): + if self.time_left is None: + return None + + res = "eta " + if self.time_left / 60 >= 1: + res += "%dm " % (self.time_left / 60) + return res + "%ds" % (self.time_left % 60) + + +def format_bytes_read(bytes_read, unit="MB"): + divisor = {'MB': 1048576, 'kB': 1024} + return "%0.2f%s" % (float(bytes_read) / divisor[unit], unit) + + +def format_percent(bytes_read, total_size): + percent = round(bytes_read * 100.0 / total_size, 2) + return "%0.2f%%" % percent + + +def get_content_range(response): + content_range = response.headers.get('Content-Range', "").strip() + if content_range: + m = re.match(r"bytes (\d+)-(\d+)/(\d+)", content_range) + if m: + return [int(v) for v in m.groups()] + + +def get_content_length(response): + if 'Content-Length' not in response.headers: + raise UnknownContentLengthException + return int(response.headers.get('Content-Length').strip()) + + +def get_url_meta(url, checksum_header=None): + class HeadRequest(Request): + def get_method(self): + return "HEAD" + + r = urlopen(HeadRequest(url)) + res = {'size': get_content_length(r)} + + if checksum_header: + value = r.headers.get(checksum_header) + if value: + res['checksum'] = value + + r.close() + return res + + +def progress(console, bytes_read, total_size, transfer_rate, eta): + fields = [ + format_bytes_read(bytes_read), + format_percent(bytes_read, total_size), + transfer_rate.format(), + eta.format(), + " " * 10, + ] + console.write("Downloaded %s\r" % " ".join(filter(None, fields))) + console.flush() + + +def read_request(request, offset=0, console=None, + progress_func=None, write_func=None): + # support partial downloads + if offset > 0: + request.add_header('Range', "bytes=%s-" % offset) + + try: + response = urlopen(request) + except HTTPError as e: + if e.code == 416: # Requested Range Not Satisfiable + raise InvalidOffsetException + + # TODO add http error handling here + raise UnsupportedHTTPCodeException(e.code) + + total_size = get_content_length(response) + offset + bytes_read = offset + + # sanity checks + if response.code == 200: # OK + assert offset == 0 + elif response.code == 206: # Partial content + range_start, range_end, range_total = get_content_range(response) + assert range_start == offset + assert range_total == total_size + assert range_end + 1 - range_start == total_size - bytes_read + else: + raise UnsupportedHTTPCodeException(response.code) + + eta = TimeEstimator() + transfer_rate = RateSampler() + + if console: + if offset > 0: + console.write("Continue downloading...\n") + else: + console.write("Downloading...\n") + + while True: + with transfer_rate: + chunk = response.read(CHUNK_SIZE) + if not chunk: + if progress_func and console: + console.write('\n') + break + + bytes_read += len(chunk) + + transfer_rate.update(len(chunk)) + eta.update(bytes_read - offset, total_size - offset) + + if progress_func and console: + progress_func(console, bytes_read, total_size, transfer_rate, eta) + + if write_func: + write_func(chunk) + + response.close() + assert bytes_read == total_size + return response + + +def download(url, path=".", + checksum=None, checksum_header=None, + headers=None, console=None): + + if os.path.isdir(path): + path = os.path.join(path, url.rsplit('/', 1)[1]) + path = os.path.abspath(path) + + with io.open(path, "a+b") as f: + size = f.tell() + + # update checksum of partially downloaded file + if checksum: + f.seek(0, os.SEEK_SET) + for chunk in iter(lambda: f.read(CHUNK_SIZE), b""): + checksum.update(chunk) + + def write(chunk): + if checksum: + checksum.update(chunk) + f.write(chunk) + + request = Request(url) + + # request headers + if headers: + for key, value in headers.items(): + request.add_header(key, value) + + try: + response = read_request(request, + offset=size, + console=console, + progress_func=progress, + write_func=write) + except InvalidOffsetException: + response = None + + if checksum: + if response: + origin_checksum = response.headers.get(checksum_header) + else: + # check whether file is already complete + meta = get_url_meta(url, checksum_header) + origin_checksum = meta.get('checksum') + + if origin_checksum is None: + raise MissingChecksumHeader + + if checksum.hexdigest() != origin_checksum: + raise InvalidChecksumException + + if console: + console.write("checksum/sha256 OK\n") + + return path