diff --git a/tests/tests_tqdm.py b/tests/tests_tqdm.py index fe49053e..f07bdf56 100644 --- a/tests/tests_tqdm.py +++ b/tests/tests_tqdm.py @@ -754,28 +754,29 @@ def test_smoothed_dynamic_min_iters(): timer = DiscreteTimer() with closing(StringIO()) as our_file: - with tqdm(total=100, file=our_file, miniters=None, mininterval=0, + with tqdm(total=100, file=our_file, miniters=None, mininterval=1, smoothing=0.5, maxinterval=0) as t: cpu_timify(t, timer) # Increase 10 iterations at once + timer.sleep(1) t.update(10) # The next iterations should be partially skipped for _ in _range(2): + timer.sleep(1) t.update(4) for _ in _range(20): + timer.sleep(1) t.update() - out = our_file.getvalue() assert t.dynamic_miniters + out = our_file.getvalue() assert ' 0%| | 0/100 [00:00<' in out - assert '10%' in out - assert '14%' not in out - assert '18%' in out - assert '20%' not in out + assert '20%' in out + assert '23%' not in out assert '25%' in out - assert '30%' not in out - assert '32%' in out + assert '26%' not in out + assert '28%' in out def test_smoothed_dynamic_min_iters_with_min_interval(): diff --git a/tqdm/std.py b/tqdm/std.py index c910a7b2..f4e11fdc 100644 --- a/tqdm/std.py +++ b/tqdm/std.py @@ -213,6 +213,37 @@ class Bar(object): return self.colour + res + self.COLOUR_RESET if self.colour else res +class EMA(object): + """ + Exponential moving average: smoothing to give progressively lower + weights to older values. + + Parameters + ---------- + smoothing : float, optional + Smoothing factor in range [0, 1], [default: 0.3]. + Increase to give more weight to recent values. + Ranges from 0 (yields old value) to 1 (yields new value). + """ + def __init__(self, smoothing=0.3): + self.alpha = smoothing + self.last = 0 + self.calls = 0 + + def __call__(self, x=None): + """ + Parameters + ---------- + x : float + New value to include in EMA. + """ + beta = 1 - self.alpha + if x is not None: + self.last = self.alpha * x + beta * self.last + self.calls += 1 + return self.last / (1 - beta ** self.calls) if self.calls else self.last + + class tqdm(Comparable): """ Decorate an iterable object, returning an iterator which acts exactly @@ -295,26 +326,6 @@ class tqdm(Comparable): n = str(n) return f if len(f) < len(n) else n - def ema(self, x, mu=0, alpha=0.3): - """ - Exponential moving average: smoothing to give progressively lower - weights to older values. - - Parameters - ---------- - x : float - New value to include in EMA. - mu : float, optional - Previous EMA value. - alpha : float, optional - Smoothing factor in range [0, 1], [default: 0.3]. - Increase to give more weight to recent values. - Ranges from 0 (yields mu) to 1 (yields x). - """ - beta = 1 - alpha - res = alpha * x + beta * mu - return res / (1 - beta ** self.n) if self.n else res - @staticmethod def status_printer(file): """ @@ -1048,8 +1059,9 @@ class tqdm(Comparable): self.gui = gui self.dynamic_ncols = dynamic_ncols self.smoothing = smoothing - self.avg_dn = 0 - self.avg_dt = 0 + self._ema_dn = EMA(smoothing) + self._ema_dt = EMA(smoothing) + self._ema_miniters = EMA(smoothing) self.bar_format = bar_format self.postfix = None self.colour = colour @@ -1144,6 +1156,9 @@ class tqdm(Comparable): last_print_n = self.last_print_n n = self.n smoothing = self.smoothing + _ema_dn = self._ema_dn + _ema_dt = self._ema_dt + _ema_miniters = self._ema_miniters time = self._time try: @@ -1159,12 +1174,11 @@ class tqdm(Comparable): if dt >= mininterval: cur_t = time() dn = n - last_print_n - self.n = n # EMA (not just overall average) if smoothing and dt and dn: - self.avg_dn = self.ema(dn, self.avg_dn, smoothing) - self.avg_dt = self.ema(dt, self.avg_dt, smoothing) - + _ema_dn(dn) + _ema_dt(dt) + self.n = n self.refresh(lock_args=self.lock_args) # If no `miniters` was specified, adjust automatically @@ -1181,10 +1195,9 @@ class tqdm(Comparable): elif smoothing: # EMA-weight miniters to converge # towards the timeframe of mininterval - miniters = self.ema( + miniters = _ema_miniters( dn * (mininterval / dt - if mininterval and dt else 1), - miniters, smoothing) + if mininterval and dt else 1)) else: # Maximum nb of iterations between 2 prints miniters = max(miniters, dn) @@ -1242,11 +1255,9 @@ class tqdm(Comparable): cur_t = self._time() dn = self.n - self.last_print_n # >= n # elapsed = cur_t - self.start_t - # EMA (not just overall average) if self.smoothing and dt and dn: - self.avg_dn = self.ema(dn, self.avg_dn, self.smoothing) - self.avg_dt = self.ema(dt, self.avg_dt, self.smoothing) - + self._ema_dn(dn) + self._ema_dt(dt) self.refresh(lock_args=self.lock_args) # If no `miniters` was specified, adjust automatically to the @@ -1261,9 +1272,9 @@ class tqdm(Comparable): else: self.miniters = dn * self.maxinterval / dt elif self.smoothing: - self.miniters = self.ema( + self.miniters = self._ema_miniters( dn * (self.mininterval / dt if self.mininterval and dt - else 1), self.miniters, self.smoothing) + else 1)) else: self.miniters = max(self.miniters, dn) @@ -1452,7 +1463,7 @@ class tqdm(Comparable): elapsed=self._time() - self.start_t if hasattr(self, 'start_t') else 0, ncols=ncols, nrows=nrows, prefix=self.desc, ascii=self.ascii, unit=self.unit, unit_scale=self.unit_scale, - rate=self.avg_dn / self.avg_dt if self.avg_dt else None, + rate=self._ema_dn() / self._ema_dt() if self._ema_dt() else None, bar_format=self.bar_format, postfix=self.postfix, unit_divisor=self.unit_divisor, initial=self.initial, colour=self.colour)