fix EMA estimates

This commit is contained in:
Casper da Costa-Luis 2020-12-24 18:56:30 +00:00
parent e7d1359027
commit 8ae4fc86d1
No known key found for this signature in database
GPG Key ID: 986B408043AE090D
2 changed files with 56 additions and 44 deletions

View File

@ -754,28 +754,29 @@ def test_smoothed_dynamic_min_iters():
timer = DiscreteTimer() timer = DiscreteTimer()
with closing(StringIO()) as our_file: with closing(StringIO()) as our_file:
with tqdm(total=100, file=our_file, miniters=None, mininterval=0, with tqdm(total=100, file=our_file, miniters=None, mininterval=1,
smoothing=0.5, maxinterval=0) as t: smoothing=0.5, maxinterval=0) as t:
cpu_timify(t, timer) cpu_timify(t, timer)
# Increase 10 iterations at once # Increase 10 iterations at once
timer.sleep(1)
t.update(10) t.update(10)
# The next iterations should be partially skipped # The next iterations should be partially skipped
for _ in _range(2): for _ in _range(2):
timer.sleep(1)
t.update(4) t.update(4)
for _ in _range(20): for _ in _range(20):
timer.sleep(1)
t.update() t.update()
out = our_file.getvalue()
assert t.dynamic_miniters assert t.dynamic_miniters
out = our_file.getvalue()
assert ' 0%| | 0/100 [00:00<' in out assert ' 0%| | 0/100 [00:00<' in out
assert '10%' in out assert '20%' in out
assert '14%' not in out assert '23%' not in out
assert '18%' in out
assert '20%' not in out
assert '25%' in out assert '25%' in out
assert '30%' not in out assert '26%' not in out
assert '32%' in out assert '28%' in out
def test_smoothed_dynamic_min_iters_with_min_interval(): def test_smoothed_dynamic_min_iters_with_min_interval():

View File

@ -213,6 +213,37 @@ class Bar(object):
return self.colour + res + self.COLOUR_RESET if self.colour else res return self.colour + res + self.COLOUR_RESET if self.colour else res
class EMA(object):
    """
    Exponential moving average: smoothing to give progressively lower
    weights to older values.

    The running value starts at 0, so the raw average is biased towards 0
    for the first few updates; `__call__` divides by
    `1 - (1 - smoothing) ** calls` to compensate for that.

    Parameters
    ----------
    smoothing  : float, optional
        Smoothing factor in range [0, 1], [default: 0.3].
        Increase to give more weight to recent values.
        Ranges from 0 (yields old value) to 1 (yields new value).
    """
    def __init__(self, smoothing=0.3):
        self.alpha = smoothing
        self.last = 0  # raw (uncorrected) running average
        self.calls = 0  # number of values folded into `last` so far

    def __call__(self, x=None):
        """
        Optionally include a new value, then return the current estimate.

        Parameters
        ----------
        x  : float, optional
            New value to include in EMA. If `None`, nothing is updated and
            the current estimate is returned.

        Returns
        -------
        out  : float
            Bias-corrected average, or the raw running value when no value
            has been included yet or when the correction factor is zero.
        """
        beta = 1 - self.alpha
        if x is not None:
            self.last = self.alpha * x + beta * self.last
            self.calls += 1
        if not self.calls:
            return self.last
        norm = 1 - beta ** self.calls
        # With smoothing == 0 the correction factor is exactly 0; dividing
        # would raise ZeroDivisionError, so yield the (unchanged) old value.
        return self.last / norm if norm else self.last
class tqdm(Comparable): class tqdm(Comparable):
""" """
Decorate an iterable object, returning an iterator which acts exactly Decorate an iterable object, returning an iterator which acts exactly
@ -295,26 +326,6 @@ class tqdm(Comparable):
n = str(n) n = str(n)
return f if len(f) < len(n) else n return f if len(f) < len(n) else n
def ema(self, x, mu=0, alpha=0.3):
    """
    Exponential moving average: smoothing to give progressively lower
    weights to older values.

    Parameters
    ----------
    x  : float
        New value to include in EMA.
    mu  : float, optional
        Previous EMA value.
    alpha  : float, optional
        Smoothing factor in range [0, 1], [default: 0.3].
        Increase to give more weight to recent values.
        Ranges from 0 (yields mu) to 1 (yields x).

    Returns
    -------
    out  : float
        The blended value, divided by `1 - (1 - alpha) ** self.n` when
        `self.n` is truthy and that factor is non-zero.
    """
    beta = 1 - alpha
    res = alpha * x + beta * mu
    if not self.n:
        return res
    # NOTE(review): the exponent is `self.n`, not a count of how many times
    # this average has been updated -- confirm that is the intended
    # bias-correction term.
    norm = 1 - beta ** self.n
    # alpha == 0 makes the factor exactly 0; return the blended value
    # (equal to `mu` in that case) rather than raise ZeroDivisionError.
    return res / norm if norm else res
@staticmethod @staticmethod
def status_printer(file): def status_printer(file):
""" """
@ -1048,8 +1059,9 @@ class tqdm(Comparable):
self.gui = gui self.gui = gui
self.dynamic_ncols = dynamic_ncols self.dynamic_ncols = dynamic_ncols
self.smoothing = smoothing self.smoothing = smoothing
self.avg_dn = 0 self._ema_dn = EMA(smoothing)
self.avg_dt = 0 self._ema_dt = EMA(smoothing)
self._ema_miniters = EMA(smoothing)
self.bar_format = bar_format self.bar_format = bar_format
self.postfix = None self.postfix = None
self.colour = colour self.colour = colour
@ -1144,6 +1156,9 @@ class tqdm(Comparable):
last_print_n = self.last_print_n last_print_n = self.last_print_n
n = self.n n = self.n
smoothing = self.smoothing smoothing = self.smoothing
_ema_dn = self._ema_dn
_ema_dt = self._ema_dt
_ema_miniters = self._ema_miniters
time = self._time time = self._time
try: try:
@ -1159,12 +1174,11 @@ class tqdm(Comparable):
if dt >= mininterval: if dt >= mininterval:
cur_t = time() cur_t = time()
dn = n - last_print_n dn = n - last_print_n
self.n = n
# EMA (not just overall average) # EMA (not just overall average)
if smoothing and dt and dn: if smoothing and dt and dn:
self.avg_dn = self.ema(dn, self.avg_dn, smoothing) _ema_dn(dn)
self.avg_dt = self.ema(dt, self.avg_dt, smoothing) _ema_dt(dt)
self.n = n
self.refresh(lock_args=self.lock_args) self.refresh(lock_args=self.lock_args)
# If no `miniters` was specified, adjust automatically # If no `miniters` was specified, adjust automatically
@ -1181,10 +1195,9 @@ class tqdm(Comparable):
elif smoothing: elif smoothing:
# EMA-weight miniters to converge # EMA-weight miniters to converge
# towards the timeframe of mininterval # towards the timeframe of mininterval
miniters = self.ema( miniters = _ema_miniters(
dn * (mininterval / dt dn * (mininterval / dt
if mininterval and dt else 1), if mininterval and dt else 1))
miniters, smoothing)
else: else:
# Maximum nb of iterations between 2 prints # Maximum nb of iterations between 2 prints
miniters = max(miniters, dn) miniters = max(miniters, dn)
@ -1242,11 +1255,9 @@ class tqdm(Comparable):
cur_t = self._time() cur_t = self._time()
dn = self.n - self.last_print_n # >= n dn = self.n - self.last_print_n # >= n
# elapsed = cur_t - self.start_t # elapsed = cur_t - self.start_t
# EMA (not just overall average)
if self.smoothing and dt and dn: if self.smoothing and dt and dn:
self.avg_dn = self.ema(dn, self.avg_dn, self.smoothing) self._ema_dn(dn)
self.avg_dt = self.ema(dt, self.avg_dt, self.smoothing) self._ema_dt(dt)
self.refresh(lock_args=self.lock_args) self.refresh(lock_args=self.lock_args)
# If no `miniters` was specified, adjust automatically to the # If no `miniters` was specified, adjust automatically to the
@ -1261,9 +1272,9 @@ class tqdm(Comparable):
else: else:
self.miniters = dn * self.maxinterval / dt self.miniters = dn * self.maxinterval / dt
elif self.smoothing: elif self.smoothing:
self.miniters = self.ema( self.miniters = self._ema_miniters(
dn * (self.mininterval / dt if self.mininterval and dt dn * (self.mininterval / dt if self.mininterval and dt
else 1), self.miniters, self.smoothing) else 1))
else: else:
self.miniters = max(self.miniters, dn) self.miniters = max(self.miniters, dn)
@ -1452,7 +1463,7 @@ class tqdm(Comparable):
elapsed=self._time() - self.start_t if hasattr(self, 'start_t') else 0, elapsed=self._time() - self.start_t if hasattr(self, 'start_t') else 0,
ncols=ncols, nrows=nrows, prefix=self.desc, ascii=self.ascii, ncols=ncols, nrows=nrows, prefix=self.desc, ascii=self.ascii,
unit=self.unit, unit_scale=self.unit_scale, unit=self.unit, unit_scale=self.unit_scale,
rate=self.avg_dn / self.avg_dt if self.avg_dt else None, rate=self._ema_dn() / self._ema_dt() if self._ema_dt() else None,
bar_format=self.bar_format, postfix=self.postfix, bar_format=self.bar_format, postfix=self.postfix,
unit_divisor=self.unit_divisor, initial=self.initial, colour=self.colour) unit_divisor=self.unit_divisor, initial=self.initial, colour=self.colour)