mirror of https://github.com/tqdm/tqdm.git
fix EMA estimates
This commit is contained in:
parent
e7d1359027
commit
8ae4fc86d1
|
@ -754,28 +754,29 @@ def test_smoothed_dynamic_min_iters():
|
|||
timer = DiscreteTimer()
|
||||
|
||||
with closing(StringIO()) as our_file:
|
||||
with tqdm(total=100, file=our_file, miniters=None, mininterval=0,
|
||||
with tqdm(total=100, file=our_file, miniters=None, mininterval=1,
|
||||
smoothing=0.5, maxinterval=0) as t:
|
||||
cpu_timify(t, timer)
|
||||
|
||||
# Increase 10 iterations at once
|
||||
timer.sleep(1)
|
||||
t.update(10)
|
||||
# The next iterations should be partially skipped
|
||||
for _ in _range(2):
|
||||
timer.sleep(1)
|
||||
t.update(4)
|
||||
for _ in _range(20):
|
||||
timer.sleep(1)
|
||||
t.update()
|
||||
|
||||
out = our_file.getvalue()
|
||||
assert t.dynamic_miniters
|
||||
out = our_file.getvalue()
|
||||
assert ' 0%| | 0/100 [00:00<' in out
|
||||
assert '10%' in out
|
||||
assert '14%' not in out
|
||||
assert '18%' in out
|
||||
assert '20%' not in out
|
||||
assert '20%' in out
|
||||
assert '23%' not in out
|
||||
assert '25%' in out
|
||||
assert '30%' not in out
|
||||
assert '32%' in out
|
||||
assert '26%' not in out
|
||||
assert '28%' in out
|
||||
|
||||
|
||||
def test_smoothed_dynamic_min_iters_with_min_interval():
|
||||
|
|
83
tqdm/std.py
83
tqdm/std.py
|
@ -213,6 +213,37 @@ class Bar(object):
|
|||
return self.colour + res + self.COLOUR_RESET if self.colour else res
|
||||
|
||||
|
||||
class EMA(object):
|
||||
"""
|
||||
Exponential moving average: smoothing to give progressively lower
|
||||
weights to older values.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
smoothing : float, optional
|
||||
Smoothing factor in range [0, 1], [default: 0.3].
|
||||
Increase to give more weight to recent values.
|
||||
Ranges from 0 (yields old value) to 1 (yields new value).
|
||||
"""
|
||||
def __init__(self, smoothing=0.3):
|
||||
self.alpha = smoothing
|
||||
self.last = 0
|
||||
self.calls = 0
|
||||
|
||||
def __call__(self, x=None):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
x : float
|
||||
New value to include in EMA.
|
||||
"""
|
||||
beta = 1 - self.alpha
|
||||
if x is not None:
|
||||
self.last = self.alpha * x + beta * self.last
|
||||
self.calls += 1
|
||||
return self.last / (1 - beta ** self.calls) if self.calls else self.last
|
||||
|
||||
|
||||
class tqdm(Comparable):
|
||||
"""
|
||||
Decorate an iterable object, returning an iterator which acts exactly
|
||||
|
@ -295,26 +326,6 @@ class tqdm(Comparable):
|
|||
n = str(n)
|
||||
return f if len(f) < len(n) else n
|
||||
|
||||
def ema(self, x, mu=0, alpha=0.3):
|
||||
"""
|
||||
Exponential moving average: smoothing to give progressively lower
|
||||
weights to older values.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : float
|
||||
New value to include in EMA.
|
||||
mu : float, optional
|
||||
Previous EMA value.
|
||||
alpha : float, optional
|
||||
Smoothing factor in range [0, 1], [default: 0.3].
|
||||
Increase to give more weight to recent values.
|
||||
Ranges from 0 (yields mu) to 1 (yields x).
|
||||
"""
|
||||
beta = 1 - alpha
|
||||
res = alpha * x + beta * mu
|
||||
return res / (1 - beta ** self.n) if self.n else res
|
||||
|
||||
@staticmethod
|
||||
def status_printer(file):
|
||||
"""
|
||||
|
@ -1048,8 +1059,9 @@ class tqdm(Comparable):
|
|||
self.gui = gui
|
||||
self.dynamic_ncols = dynamic_ncols
|
||||
self.smoothing = smoothing
|
||||
self.avg_dn = 0
|
||||
self.avg_dt = 0
|
||||
self._ema_dn = EMA(smoothing)
|
||||
self._ema_dt = EMA(smoothing)
|
||||
self._ema_miniters = EMA(smoothing)
|
||||
self.bar_format = bar_format
|
||||
self.postfix = None
|
||||
self.colour = colour
|
||||
|
@ -1144,6 +1156,9 @@ class tqdm(Comparable):
|
|||
last_print_n = self.last_print_n
|
||||
n = self.n
|
||||
smoothing = self.smoothing
|
||||
_ema_dn = self._ema_dn
|
||||
_ema_dt = self._ema_dt
|
||||
_ema_miniters = self._ema_miniters
|
||||
time = self._time
|
||||
|
||||
try:
|
||||
|
@ -1159,12 +1174,11 @@ class tqdm(Comparable):
|
|||
if dt >= mininterval:
|
||||
cur_t = time()
|
||||
dn = n - last_print_n
|
||||
self.n = n
|
||||
# EMA (not just overall average)
|
||||
if smoothing and dt and dn:
|
||||
self.avg_dn = self.ema(dn, self.avg_dn, smoothing)
|
||||
self.avg_dt = self.ema(dt, self.avg_dt, smoothing)
|
||||
|
||||
_ema_dn(dn)
|
||||
_ema_dt(dt)
|
||||
self.n = n
|
||||
self.refresh(lock_args=self.lock_args)
|
||||
|
||||
# If no `miniters` was specified, adjust automatically
|
||||
|
@ -1181,10 +1195,9 @@ class tqdm(Comparable):
|
|||
elif smoothing:
|
||||
# EMA-weight miniters to converge
|
||||
# towards the timeframe of mininterval
|
||||
miniters = self.ema(
|
||||
miniters = _ema_miniters(
|
||||
dn * (mininterval / dt
|
||||
if mininterval and dt else 1),
|
||||
miniters, smoothing)
|
||||
if mininterval and dt else 1))
|
||||
else:
|
||||
# Maximum nb of iterations between 2 prints
|
||||
miniters = max(miniters, dn)
|
||||
|
@ -1242,11 +1255,9 @@ class tqdm(Comparable):
|
|||
cur_t = self._time()
|
||||
dn = self.n - self.last_print_n # >= n
|
||||
# elapsed = cur_t - self.start_t
|
||||
# EMA (not just overall average)
|
||||
if self.smoothing and dt and dn:
|
||||
self.avg_dn = self.ema(dn, self.avg_dn, self.smoothing)
|
||||
self.avg_dt = self.ema(dt, self.avg_dt, self.smoothing)
|
||||
|
||||
self._ema_dn(dn)
|
||||
self._ema_dt(dt)
|
||||
self.refresh(lock_args=self.lock_args)
|
||||
|
||||
# If no `miniters` was specified, adjust automatically to the
|
||||
|
@ -1261,9 +1272,9 @@ class tqdm(Comparable):
|
|||
else:
|
||||
self.miniters = dn * self.maxinterval / dt
|
||||
elif self.smoothing:
|
||||
self.miniters = self.ema(
|
||||
self.miniters = self._ema_miniters(
|
||||
dn * (self.mininterval / dt if self.mininterval and dt
|
||||
else 1), self.miniters, self.smoothing)
|
||||
else 1))
|
||||
else:
|
||||
self.miniters = max(self.miniters, dn)
|
||||
|
||||
|
@ -1452,7 +1463,7 @@ class tqdm(Comparable):
|
|||
elapsed=self._time() - self.start_t if hasattr(self, 'start_t') else 0,
|
||||
ncols=ncols, nrows=nrows, prefix=self.desc, ascii=self.ascii,
|
||||
unit=self.unit, unit_scale=self.unit_scale,
|
||||
rate=self.avg_dn / self.avg_dt if self.avg_dt else None,
|
||||
rate=self._ema_dn() / self._ema_dt() if self._ema_dt() else None,
|
||||
bar_format=self.bar_format, postfix=self.postfix,
|
||||
unit_divisor=self.unit_divisor, initial=self.initial, colour=self.colour)
|
||||
|
||||
|
|
Loading…
Reference in New Issue