Add tqdm.py

This commit is contained in:
Noam Yorav-Raphael 2013-10-26 22:54:46 +03:00
parent edc7a4a0ad
commit 6ba85d464c
1 changed files with 100 additions and 0 deletions

100
tqdm.py Normal file
View File

@ -0,0 +1,100 @@
__all__ = ['tqdm', 'trange']
import sys
import time
def format_interval(t):
mins, s = divmod(int(t), 60)
h, m = divmod(mins, 60)
if h:
return '%d:%02d:%02d' % (h, m, s)
else:
return '%02d:%02d' % (m, s)
def format_meter(n, total, elapsed):
# n - number of finished iterations
# total - total number of iterations, or None
# elapsed - number of seconds passed since start
if n > total:
total = None
elapsed_str = format_interval(elapsed)
rate = '%5.2f' % (n / elapsed) if elapsed else '?'
if total:
frac = float(n) / total
N_BARS = 10
bar_length = int(frac*N_BARS)
bar = '#'*bar_length + '-'*(N_BARS-bar_length)
percentage = '%3d%%' % (frac * 100)
left_str = format_interval(elapsed / n * (total-n)) if n else '?'
return '|%s| %d/%d %s [elapsed: %s left: %s, %s iters/sec]' % (
bar, n, total, percentage, elapsed_str, left_str, rate)
else:
return '%d [elapsed: %s, %s iters/sec]' % (n, elapsed_str, rate)
class StatusPrinter(object):
def __init__(self):
self.last_printed_len = 0
def print_status(self, s):
sys.stdout.write('\r'+s+' '*max(self.last_printed_len-len(s), 0))
sys.stdout.flush()
self.last_printed_len = len(s)
def tqdm(iterable, desc='', total=None, leave=False, mininterval=0.5, miniters=1):
"""
Get an iterable object, and return an iterator which acts exactly like the
iterable, but prints a progress meter and updates it every time a value is
requested.
'desc' can contain a short string, describing the progress, that is added
in the beginning of the line.
'total' can give the number of expected iterations. If not given,
len(iterable) is used if it is defined.
If leave is False, tqdm deletes its traces from screen after it has finished
iterating over all elements.
If less than mininterval seconds or miniters iterations have passed since
the last progress meter update, it is not updated again.
"""
if total is None:
try:
total = len(iterable)
except TypeError:
total = None
prefix = desc+': ' if desc else ''
sp = StatusPrinter()
sp.print_status(prefix + format_meter(0, total, 0))
start_t = last_print_t = time.time()
last_print_n = 0
n = 0
for obj in iterable:
yield obj
# Now the object was created and processed, so we can print the meter.
n += 1
if n - last_print_n >= miniters:
# We check the counter first, to reduce the overhead of time.time().
cur_t = time.time()
if cur_t - last_print_t >= mininterval:
sp.print_status(prefix + format_meter(n, total, cur_t-start_t))
last_print_n = n
last_print_t = cur_t
if not leave:
sp.print_status('')
sys.stdout.write('\r')
else:
if last_print_n < n:
cur_t = time.time()
sp.print_status(prefix + format_meter(n, total, cur_t-start_t))
def trange(*args, **kwargs):
"""A shortcut for writing tqdm(xrange)"""
return tqdm(xrange(*args), **kwargs)