from six.moves import urllib
import requests


def reporthook(t):
    """Build a ``urlretrieve``-style progress callback that updates ``t``.

    Adapted from https://github.com/tqdm/tqdm.

    Args:
        t: Progress-bar object. The returned callback assigns ``t.total``
            and calls ``t.update(n)``.

    Returns:
        The callable ``inner(b=1, bsize=1, tsize=None)``.
    """
    # One-element list acts as a mutable cell so the closure can update the
    # last-seen block count without needing ``nonlocal``.
    last_b = [0]

    def inner(b=1, bsize=1, tsize=None):
        """Advance the progress bar by the blocks received since the last call.

        b: int, optional
            Number of blocks just transferred [default: 1].
        bsize: int, optional
            Size of each block (in tqdm units) [default: 1].
        tsize: int, optional
            Total size (in tqdm units). If [default: None] remains unchanged.
        """
        if tsize is not None:
            t.total = tsize
        # ``b`` is a running block count, so only the delta is reported.
        t.update((b - last_b[0]) * bsize)
        last_b[0] = b

    return inner


def download_from_url(url, path):
    """Download file, with logic (from tensor2tensor) for Google Drive.

    Args:
        url: Address of the file to fetch.
        path: Local filesystem path the downloaded bytes are written to.

    Returns:
        The result of ``urllib.request.urlretrieve`` when that call succeeds
        for a non-Google-Drive URL; ``None`` in every other case (the
        ``requests`` fallback and the Google Drive path).
    """
    if 'drive.google.com' not in url:
        try:
            return urllib.request.urlretrieve(url, path)
        except Exception:
            # Best-effort fallback: retry the same URL through ``requests``.
            # ``Exception`` (not a bare ``except``) so that KeyboardInterrupt
            # and SystemExit still propagate.
            res = requests.get(url)
            with open(path, 'wb') as out:
                out.write(res.content)
            # The file is already written; do not fall through into the
            # Google Drive logic below and download it a second time.
            return None

    print('downloading from Google Drive; may take a few minutes')
    confirm_token = None
    with requests.Session() as session:
        response = session.get(url, stream=True)
        try:
            # A cookie whose name starts with "download_warning" carries the
            # token that must be echoed back as a ``confirm`` query parameter.
            for k, v in response.cookies.items():
                if k.startswith("download_warning"):
                    confirm_token = v

            if confirm_token:
                # Release the first (unread) streamed response before
                # issuing the confirmed request on the same session.
                response.close()
                url = url + "&confirm=" + confirm_token
                response = session.get(url, stream=True)

            chunk_size = 16 * 1024
            with open(path, "wb") as f:
                for chunk in response.iter_content(chunk_size):
                    if chunk:  # skip empty chunks
                        f.write(chunk)
        finally:
            response.close()
    return None