51 lines
1.5 KiB
Python
51 lines
1.5 KiB
Python
from six.moves import urllib
|
|
import requests
|
|
|
|
|
|
def reporthook(t):
    """Build a ``urlretrieve``-style reporthook that advances a tqdm bar.

    Adapted from https://github.com/tqdm/tqdm

    Parameters
    ----------
    t : tqdm-like
        Progress bar to drive. Only its ``total`` attribute and its
        ``update(n)`` method are used.

    Returns
    -------
    callable
        ``inner(b=1, bsize=1, tsize=None)``, suitable for the ``reporthook``
        argument of ``urllib.request.urlretrieve``.
    """
    # One-element list instead of ``nonlocal`` so the closure can mutate it
    # on Python 2 too (the file relies on ``six`` for 2/3 compatibility).
    last_b = [0]

    def inner(b=1, bsize=1, tsize=None):
        """
        b: int, optional
            Number of blocks transferred so far (cumulative) [default: 1].
        bsize: int, optional
            Size of each block (in tqdm units) [default: 1].
        tsize: int, optional
            Total size (in tqdm units). If None or -1 (size unknown),
            ``t.total`` is left unchanged [default: None].
        """
        # urlretrieve passes -1 when the server reports no size; do not
        # clobber the bar's total with that sentinel.
        if tsize not in (None, -1):
            t.total = tsize
        # ``b`` is cumulative, so advance only by the blocks since last call.
        t.update((b - last_b[0]) * bsize)
        last_b[0] = b

    return inner
|
|
|
|
|
|
def download_from_url(url, path):
    """Download ``url`` to the local file ``path``.

    Non-Google-Drive URLs are fetched with ``urllib.request.urlretrieve``,
    falling back to ``requests.get`` if that raises. Google Drive URLs use
    logic borrowed from tensor2tensor: if the first response sets a
    ``download_warning*`` cookie, the request is repeated with that cookie's
    value as the ``confirm`` query parameter (presumably Drive's
    confirmation step for large files -- confirm against current Drive
    behavior).

    Parameters
    ----------
    url : str
        Source URL.
    path : str
        Destination file path; overwritten if it exists.

    Returns
    -------
    tuple or None
        The ``(filename, headers)`` tuple from ``urlretrieve`` when that call
        succeeds; ``None`` when the ``requests`` fallback or the Google Drive
        path was used.

    Raises
    ------
    requests.HTTPError
        If a ``requests``-based download receives an HTTP error status.
    """
    if 'drive.google.com' not in url:
        try:
            return urllib.request.urlretrieve(url, path)
        except Exception:
            # urlretrieve failed for some reason; retry with requests, but
            # never write an HTTP error page to disk as if it were the file.
            res = requests.get(url)
            res.raise_for_status()
            with open(path, 'wb') as out:
                out.write(res.content)
            # Without this return, execution fell through into the Google
            # Drive branch and downloaded (and overwrote) the file again.
            return None

    print('downloading from Google Drive; may take a few minutes')

    chunk_size = 16 * 1024
    with requests.Session() as session:
        response = session.get(url, stream=True)

        confirm_token = None
        for k, v in response.cookies.items():
            if k.startswith("download_warning"):
                confirm_token = v

        if confirm_token:
            # Release the first streamed connection before re-requesting.
            response.close()
            # Only use '&' when the URL already has a query string.
            sep = '&' if '?' in url else '?'
            url = url + sep + "confirm=" + confirm_token
            response = session.get(url, stream=True)

        try:
            response.raise_for_status()
            with open(path, "wb") as f:
                for chunk in response.iter_content(chunk_size):
                    if chunk:  # skip empty chunks
                        f.write(chunk)
        finally:
            response.close()
    return None
|