genienlp/text/torchtext/utils.py

51 lines
1.5 KiB
Python

from six.moves import urllib
import requests
def reporthook(t):
"""https://github.com/tqdm/tqdm"""
last_b = [0]
def inner(b=1, bsize=1, tsize=None):
"""
b: int, optionala
Number of blocks just transferred [default: 1].
bsize: int, optional
Size of each block (in tqdm units) [default: 1].
tsize: int, optional
Total size (in tqdm units). If [default: None] remains unchanged.
"""
if tsize is not None:
t.total = tsize
t.update((b - last_b[0]) * bsize)
last_b[0] = b
return inner
def _drive_confirm_token(response):
    """Return the value of the last ``download_warning*`` cookie on ``response``, or None."""
    token = None
    for name, value in response.cookies.items():
        if name.startswith("download_warning"):
            token = value
    return token


def _write_chunks(response, path, chunk_size=16 * 1024):
    """Stream the body of a ``requests`` response to ``path``, overwriting it."""
    with open(path, "wb") as f:
        for chunk in response.iter_content(chunk_size):
            if chunk:  # skip empty chunks
                f.write(chunk)


def download_from_url(url, path):
    """Download ``url`` to the local file ``path``.

    URLs that do not contain ``'drive.google.com'`` are fetched with
    ``urllib.request.urlretrieve``. If that raises, the download is retried
    once with ``requests``. Google Drive URLs go through a ``requests``
    session that echoes any ``download_warning*`` cookie back as a
    ``confirm`` query parameter (logic adapted from tensor2tensor).

    Args:
        url: URL to fetch.
        path: destination file path; created or overwritten.

    Returns:
        The ``(filename, headers)`` tuple from ``urlretrieve`` when that call
        succeeds, otherwise ``None``.

    Raises:
        requests.HTTPError: if a ``requests``-based download gets an HTTP
            error status (instead of writing the error page to ``path``).
    """
    if 'drive.google.com' not in url:
        try:
            return urllib.request.urlretrieve(url, path)
        except Exception:
            # Best-effort fallback: retry the same URL once with requests.
            # ``Exception`` (not a bare except) so Ctrl-C / SystemExit still
            # propagate.
            with requests.get(url, stream=True) as res:
                res.raise_for_status()
                _write_chunks(res, path)
        # Return here so a successful fallback does not fall through into the
        # Google Drive branch below and download the file a second time.
        return None

    print('downloading from Google Drive; may take a few minutes')
    with requests.Session() as session:
        response = session.get(url, stream=True)
        # If Drive set a "download_warning*" cookie, re-request with its value
        # as ``confirm`` (presumably Drive's large-file confirmation step).
        confirm_token = _drive_confirm_token(response)
        if confirm_token:
            response.close()
            url = url + "&confirm=" + confirm_token
            response = session.get(url, stream=True)
        with response:
            response.raise_for_status()
            _write_chunks(response, path)