mirror of https://github.com/tqdm/tqdm.git
fix EMA estimates
This commit is contained in:
parent
e7d1359027
commit
8ae4fc86d1
|
@ -754,28 +754,29 @@ def test_smoothed_dynamic_min_iters():
|
||||||
timer = DiscreteTimer()
|
timer = DiscreteTimer()
|
||||||
|
|
||||||
with closing(StringIO()) as our_file:
|
with closing(StringIO()) as our_file:
|
||||||
with tqdm(total=100, file=our_file, miniters=None, mininterval=0,
|
with tqdm(total=100, file=our_file, miniters=None, mininterval=1,
|
||||||
smoothing=0.5, maxinterval=0) as t:
|
smoothing=0.5, maxinterval=0) as t:
|
||||||
cpu_timify(t, timer)
|
cpu_timify(t, timer)
|
||||||
|
|
||||||
# Increase 10 iterations at once
|
# Increase 10 iterations at once
|
||||||
|
timer.sleep(1)
|
||||||
t.update(10)
|
t.update(10)
|
||||||
# The next iterations should be partially skipped
|
# The next iterations should be partially skipped
|
||||||
for _ in _range(2):
|
for _ in _range(2):
|
||||||
|
timer.sleep(1)
|
||||||
t.update(4)
|
t.update(4)
|
||||||
for _ in _range(20):
|
for _ in _range(20):
|
||||||
|
timer.sleep(1)
|
||||||
t.update()
|
t.update()
|
||||||
|
|
||||||
out = our_file.getvalue()
|
|
||||||
assert t.dynamic_miniters
|
assert t.dynamic_miniters
|
||||||
|
out = our_file.getvalue()
|
||||||
assert ' 0%| | 0/100 [00:00<' in out
|
assert ' 0%| | 0/100 [00:00<' in out
|
||||||
assert '10%' in out
|
assert '20%' in out
|
||||||
assert '14%' not in out
|
assert '23%' not in out
|
||||||
assert '18%' in out
|
|
||||||
assert '20%' not in out
|
|
||||||
assert '25%' in out
|
assert '25%' in out
|
||||||
assert '30%' not in out
|
assert '26%' not in out
|
||||||
assert '32%' in out
|
assert '28%' in out
|
||||||
|
|
||||||
|
|
||||||
def test_smoothed_dynamic_min_iters_with_min_interval():
|
def test_smoothed_dynamic_min_iters_with_min_interval():
|
||||||
|
|
83
tqdm/std.py
83
tqdm/std.py
|
@ -213,6 +213,37 @@ class Bar(object):
|
||||||
return self.colour + res + self.COLOUR_RESET if self.colour else res
|
return self.colour + res + self.COLOUR_RESET if self.colour else res
|
||||||
|
|
||||||
|
|
||||||
|
class EMA(object):
|
||||||
|
"""
|
||||||
|
Exponential moving average: smoothing to give progressively lower
|
||||||
|
weights to older values.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
smoothing : float, optional
|
||||||
|
Smoothing factor in range [0, 1], [default: 0.3].
|
||||||
|
Increase to give more weight to recent values.
|
||||||
|
Ranges from 0 (yields old value) to 1 (yields new value).
|
||||||
|
"""
|
||||||
|
def __init__(self, smoothing=0.3):
|
||||||
|
self.alpha = smoothing
|
||||||
|
self.last = 0
|
||||||
|
self.calls = 0
|
||||||
|
|
||||||
|
def __call__(self, x=None):
|
||||||
|
"""
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
x : float
|
||||||
|
New value to include in EMA.
|
||||||
|
"""
|
||||||
|
beta = 1 - self.alpha
|
||||||
|
if x is not None:
|
||||||
|
self.last = self.alpha * x + beta * self.last
|
||||||
|
self.calls += 1
|
||||||
|
return self.last / (1 - beta ** self.calls) if self.calls else self.last
|
||||||
|
|
||||||
|
|
||||||
class tqdm(Comparable):
|
class tqdm(Comparable):
|
||||||
"""
|
"""
|
||||||
Decorate an iterable object, returning an iterator which acts exactly
|
Decorate an iterable object, returning an iterator which acts exactly
|
||||||
|
@ -295,26 +326,6 @@ class tqdm(Comparable):
|
||||||
n = str(n)
|
n = str(n)
|
||||||
return f if len(f) < len(n) else n
|
return f if len(f) < len(n) else n
|
||||||
|
|
||||||
def ema(self, x, mu=0, alpha=0.3):
|
|
||||||
"""
|
|
||||||
Exponential moving average: smoothing to give progressively lower
|
|
||||||
weights to older values.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
x : float
|
|
||||||
New value to include in EMA.
|
|
||||||
mu : float, optional
|
|
||||||
Previous EMA value.
|
|
||||||
alpha : float, optional
|
|
||||||
Smoothing factor in range [0, 1], [default: 0.3].
|
|
||||||
Increase to give more weight to recent values.
|
|
||||||
Ranges from 0 (yields mu) to 1 (yields x).
|
|
||||||
"""
|
|
||||||
beta = 1 - alpha
|
|
||||||
res = alpha * x + beta * mu
|
|
||||||
return res / (1 - beta ** self.n) if self.n else res
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def status_printer(file):
|
def status_printer(file):
|
||||||
"""
|
"""
|
||||||
|
@ -1048,8 +1059,9 @@ class tqdm(Comparable):
|
||||||
self.gui = gui
|
self.gui = gui
|
||||||
self.dynamic_ncols = dynamic_ncols
|
self.dynamic_ncols = dynamic_ncols
|
||||||
self.smoothing = smoothing
|
self.smoothing = smoothing
|
||||||
self.avg_dn = 0
|
self._ema_dn = EMA(smoothing)
|
||||||
self.avg_dt = 0
|
self._ema_dt = EMA(smoothing)
|
||||||
|
self._ema_miniters = EMA(smoothing)
|
||||||
self.bar_format = bar_format
|
self.bar_format = bar_format
|
||||||
self.postfix = None
|
self.postfix = None
|
||||||
self.colour = colour
|
self.colour = colour
|
||||||
|
@ -1144,6 +1156,9 @@ class tqdm(Comparable):
|
||||||
last_print_n = self.last_print_n
|
last_print_n = self.last_print_n
|
||||||
n = self.n
|
n = self.n
|
||||||
smoothing = self.smoothing
|
smoothing = self.smoothing
|
||||||
|
_ema_dn = self._ema_dn
|
||||||
|
_ema_dt = self._ema_dt
|
||||||
|
_ema_miniters = self._ema_miniters
|
||||||
time = self._time
|
time = self._time
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -1159,12 +1174,11 @@ class tqdm(Comparable):
|
||||||
if dt >= mininterval:
|
if dt >= mininterval:
|
||||||
cur_t = time()
|
cur_t = time()
|
||||||
dn = n - last_print_n
|
dn = n - last_print_n
|
||||||
self.n = n
|
|
||||||
# EMA (not just overall average)
|
# EMA (not just overall average)
|
||||||
if smoothing and dt and dn:
|
if smoothing and dt and dn:
|
||||||
self.avg_dn = self.ema(dn, self.avg_dn, smoothing)
|
_ema_dn(dn)
|
||||||
self.avg_dt = self.ema(dt, self.avg_dt, smoothing)
|
_ema_dt(dt)
|
||||||
|
self.n = n
|
||||||
self.refresh(lock_args=self.lock_args)
|
self.refresh(lock_args=self.lock_args)
|
||||||
|
|
||||||
# If no `miniters` was specified, adjust automatically
|
# If no `miniters` was specified, adjust automatically
|
||||||
|
@ -1181,10 +1195,9 @@ class tqdm(Comparable):
|
||||||
elif smoothing:
|
elif smoothing:
|
||||||
# EMA-weight miniters to converge
|
# EMA-weight miniters to converge
|
||||||
# towards the timeframe of mininterval
|
# towards the timeframe of mininterval
|
||||||
miniters = self.ema(
|
miniters = _ema_miniters(
|
||||||
dn * (mininterval / dt
|
dn * (mininterval / dt
|
||||||
if mininterval and dt else 1),
|
if mininterval and dt else 1))
|
||||||
miniters, smoothing)
|
|
||||||
else:
|
else:
|
||||||
# Maximum nb of iterations between 2 prints
|
# Maximum nb of iterations between 2 prints
|
||||||
miniters = max(miniters, dn)
|
miniters = max(miniters, dn)
|
||||||
|
@ -1242,11 +1255,9 @@ class tqdm(Comparable):
|
||||||
cur_t = self._time()
|
cur_t = self._time()
|
||||||
dn = self.n - self.last_print_n # >= n
|
dn = self.n - self.last_print_n # >= n
|
||||||
# elapsed = cur_t - self.start_t
|
# elapsed = cur_t - self.start_t
|
||||||
# EMA (not just overall average)
|
|
||||||
if self.smoothing and dt and dn:
|
if self.smoothing and dt and dn:
|
||||||
self.avg_dn = self.ema(dn, self.avg_dn, self.smoothing)
|
self._ema_dn(dn)
|
||||||
self.avg_dt = self.ema(dt, self.avg_dt, self.smoothing)
|
self._ema_dt(dt)
|
||||||
|
|
||||||
self.refresh(lock_args=self.lock_args)
|
self.refresh(lock_args=self.lock_args)
|
||||||
|
|
||||||
# If no `miniters` was specified, adjust automatically to the
|
# If no `miniters` was specified, adjust automatically to the
|
||||||
|
@ -1261,9 +1272,9 @@ class tqdm(Comparable):
|
||||||
else:
|
else:
|
||||||
self.miniters = dn * self.maxinterval / dt
|
self.miniters = dn * self.maxinterval / dt
|
||||||
elif self.smoothing:
|
elif self.smoothing:
|
||||||
self.miniters = self.ema(
|
self.miniters = self._ema_miniters(
|
||||||
dn * (self.mininterval / dt if self.mininterval and dt
|
dn * (self.mininterval / dt if self.mininterval and dt
|
||||||
else 1), self.miniters, self.smoothing)
|
else 1))
|
||||||
else:
|
else:
|
||||||
self.miniters = max(self.miniters, dn)
|
self.miniters = max(self.miniters, dn)
|
||||||
|
|
||||||
|
@ -1452,7 +1463,7 @@ class tqdm(Comparable):
|
||||||
elapsed=self._time() - self.start_t if hasattr(self, 'start_t') else 0,
|
elapsed=self._time() - self.start_t if hasattr(self, 'start_t') else 0,
|
||||||
ncols=ncols, nrows=nrows, prefix=self.desc, ascii=self.ascii,
|
ncols=ncols, nrows=nrows, prefix=self.desc, ascii=self.ascii,
|
||||||
unit=self.unit, unit_scale=self.unit_scale,
|
unit=self.unit, unit_scale=self.unit_scale,
|
||||||
rate=self.avg_dn / self.avg_dt if self.avg_dt else None,
|
rate=self._ema_dn() / self._ema_dt() if self._ema_dt() else None,
|
||||||
bar_format=self.bar_format, postfix=self.postfix,
|
bar_format=self.bar_format, postfix=self.postfix,
|
||||||
unit_divisor=self.unit_divisor, initial=self.initial, colour=self.colour)
|
unit_divisor=self.unit_divisor, initial=self.initial, colour=self.colour)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue