From e9bf3cde2bcbef580abe71670c032b9a7e88e022 Mon Sep 17 00:00:00 2001 From: Casper da Costa-Luis Date: Thu, 24 Dec 2020 14:20:39 +0000 Subject: [PATCH] better EMA accuracy for early iterations - fixes #1101 --- tests/tests_tqdm.py | 7 ++--- tqdm/std.py | 71 ++++++++++++++++++++------------------------- 2 files changed, 35 insertions(+), 43 deletions(-) diff --git a/tests/tests_tqdm.py b/tests/tests_tqdm.py index c51e2533..9de3c847 100644 --- a/tests/tests_tqdm.py +++ b/tests/tests_tqdm.py @@ -60,7 +60,7 @@ if os.name == 'nt': # List of control characters CTRLCHR = [r'\r', r'\n', r'\x1b\[A'] # Need to escape [ for regex # Regular expressions compilation -RE_rate = re.compile(r'(\d+\.\d+)it/s') +RE_rate = re.compile(r'[^\d](\d[.\d]+)it/s') RE_ctrlchr = re.compile("(%s)" % '|'.join(CTRLCHR)) # Match control chars RE_ctrlchr_excl = re.compile('|'.join(CTRLCHR)) # Match and exclude ctrl chars RE_pos = re.compile(r'([\r\n]+((pos\d+) bar:\s+\d+%|\s{3,6})?[^\r\n]*)') @@ -1017,7 +1017,6 @@ def test_smoothing(): assert '| 3/3 ' in our_file.getvalue() # -- Test smoothing - # Compile the regex to find the rate # 1st case: no smoothing (only use average) with closing(StringIO()) as our_file2: with closing(StringIO()) as our_file: @@ -1074,11 +1073,11 @@ def test_smoothing(): # 3rd case: use medium smoothing with closing(StringIO()) as our_file2: with closing(StringIO()) as our_file: - t = tqdm(_range(3), file=our_file2, smoothing=0.5, leave=True, + t = tqdm(_range(3), file=our_file2, smoothing=0.8, leave=True, miniters=1, mininterval=0) cpu_timify(t, timer) - t2 = tqdm(_range(3), file=our_file, smoothing=0.5, leave=True, + t2 = tqdm(_range(3), file=our_file, smoothing=0.8, leave=True, miniters=1, mininterval=0) cpu_timify(t2, timer) diff --git a/tqdm/std.py b/tqdm/std.py index d625e396..78b607d3 100644 --- a/tqdm/std.py +++ b/tqdm/std.py @@ -295,8 +295,7 @@ class tqdm(Comparable): n = str(n) return f if len(f) < len(n) else n - @staticmethod - def ema(x, mu=None, alpha=0.3): + def ema(self, x, mu=0, alpha=0.3): """ Exponential moving average: smoothing to give progressively lower weights to older values. @@ -312,7 +311,9 @@ class tqdm(Comparable): Increase to give more weight to recent values. Ranges from 0 (yields mu) to 1 (yields x). """ - return x if mu is None else (alpha * x) + (1 - alpha) * mu + beta = 1 - alpha + res = alpha * x + beta * mu + return res / (1 - beta ** self.n) if self.n else res @staticmethod def status_printer(file): @@ -1047,7 +1048,7 @@ class tqdm(Comparable): self.gui = gui self.dynamic_ncols = dynamic_ncols self.smoothing = smoothing - self.avg_time = None + self.avg_time = 0 self.bar_format = bar_format self.postfix = None self.colour = colour @@ -1142,7 +1143,6 @@ class tqdm(Comparable): last_print_n = self.last_print_n n = self.n smoothing = self.smoothing - avg_time = self.avg_time time = self._time try: @@ -1154,40 +1154,38 @@ class tqdm(Comparable): # check counter first to avoid calls to time() if n - last_print_n >= self.miniters: miniters = self.miniters # watch monitoring thread changes - delta_t = time() - last_print_t - if delta_t >= mininterval: + dt = time() - last_print_t + if dt >= mininterval: cur_t = time() - delta_it = n - last_print_n - # EMA (not just overall average) - if smoothing and delta_t and delta_it: - rate = delta_t / delta_it - avg_time = self.ema(rate, avg_time, smoothing) - self.avg_time = avg_time - + dn = n - last_print_n self.n = n + # EMA (not just overall average) + if smoothing and dt and dn: + self.avg_time = self.ema(dt / dn, self.avg_time, smoothing) + self.refresh(lock_args=self.lock_args) # If no `miniters` was specified, adjust automatically # to the max iteration rate seen so far between 2 prints if dynamic_miniters: - if maxinterval and delta_t >= maxinterval: + if maxinterval and dt >= maxinterval: # Adjust miniters to time interval by rule of 3 if mininterval: # Set miniters to correspond to mininterval - miniters = delta_it * mininterval / delta_t + miniters = dn * mininterval / dt else: # Set miniters to correspond to maxinterval - miniters = delta_it * maxinterval / delta_t + miniters = dn * maxinterval / dt elif smoothing: # EMA-weight miniters to converge # towards the timeframe of mininterval - rate = delta_it - if mininterval and delta_t: - rate *= mininterval / delta_t - miniters = self.ema(rate, miniters, smoothing) + miniters = self.ema( + dn * (mininterval / dt + if mininterval and dt else 1), + miniters, smoothing) else: # Maximum nb of iterations between 2 prints - miniters = max(miniters, delta_it) + miniters = max(miniters, dn) # Store old values for next call self.n = self.last_print_n = last_print_n = n @@ -1237,15 +1235,14 @@ class tqdm(Comparable): # check counter first to reduce calls to time() if self.n - self.last_print_n >= self.miniters: - delta_t = self._time() - self.last_print_t - if delta_t >= self.mininterval: + dt = self._time() - self.last_print_t + if dt >= self.mininterval: cur_t = self._time() - delta_it = self.n - self.last_print_n # >= n + dn = self.n - self.last_print_n # >= n # elapsed = cur_t - self.start_t # EMA (not just overall average) - if self.smoothing and delta_t and delta_it: - rate = delta_t / delta_it - self.avg_time = self.ema(rate, self.avg_time, self.smoothing) + if self.smoothing and dt and dn: + self.avg_time = self.ema(dt / dn, self.avg_time, self.smoothing) self.refresh(lock_args=self.lock_args) @@ -1255,21 +1252,17 @@ class tqdm(Comparable): # calls to `tqdm.update()` will only cause an update after # at least 5 more iterations. if self.dynamic_miniters: - if self.maxinterval and delta_t >= self.maxinterval: + if self.maxinterval and dt >= self.maxinterval: if self.mininterval: - self.miniters = delta_it * self.mininterval \ - / delta_t + self.miniters = dn * self.mininterval / dt else: - self.miniters = delta_it * self.maxinterval \ - / delta_t + self.miniters = dn * self.maxinterval / dt elif self.smoothing: - self.miniters = self.smoothing * delta_it * \ - (self.mininterval / delta_t - if self.mininterval and delta_t - else 1) + \ - (1 - self.smoothing) * self.miniters + self.miniters = self.ema( + dn * (self.mininterval / dt if self.mininterval and dt + else 1), self.miniters, self.smoothing) else: - self.miniters = max(self.miniters, delta_it) + self.miniters = max(self.miniters, dn) # Store old values for next call self.last_print_n = self.n