better EMA accuracy for early iterations

- fixes #1101
Casper da Costa-Luis 2020-12-24 14:20:39 +00:00
parent e225d9a894
commit e9bf3cde2b
2 changed files with 35 additions and 43 deletions
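
The gist of the change: `tqdm.ema()` previously seeded its moving average with the first measured value (`avg_time` started as `None` and `ema()` simply returned `x` when `mu is None`). The patch below seeds the average at 0 instead and divides the result by `1 - beta ** n`, the standard startup bias correction for an EMA, so that early readings are a properly weighted average of the samples seen so far. A minimal standalone sketch of that idea follows; the function name and structure are illustrative only and are not tqdm's API (the real change is in `tqdm.ema()` further down):

    def debiased_ema(samples, alpha=0.3):
        # Accumulate a plain EMA seeded at 0, then divide out the bias the 0 seed introduces.
        beta = 1 - alpha
        acc = 0.0
        for n, x in enumerate(samples, 1):
            acc = alpha * x + beta * acc
            yield acc / (1 - beta ** n)

    # With constant input the corrected estimate is exact from the first sample onwards
    # (up to float rounding):
    print(list(debiased_ema([2.0, 2.0, 2.0])))  # ~[2.0, 2.0, 2.0]

Note that this sketch keeps a separate raw accumulator, whereas the diff itself writes the corrected value straight back into `avg_time`.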

Changed file 1 of 2:

@@ -60,7 +60,7 @@ if os.name == 'nt':
 # List of control characters
 CTRLCHR = [r'\r', r'\n', r'\x1b\[A']  # Need to escape [ for regex
 # Regular expressions compilation
-RE_rate = re.compile(r'(\d+\.\d+)it/s')
+RE_rate = re.compile(r'[^\d](\d[.\d]+)it/s')
 RE_ctrlchr = re.compile("(%s)" % '|'.join(CTRLCHR))  # Match control chars
 RE_ctrlchr_excl = re.compile('|'.join(CTRLCHR))  # Match and exclude ctrl chars
 RE_pos = re.compile(r'([\r\n]+((pos\d+) bar:\s+\d+%|\s{3,6})?[^\r\n]*)')
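
The test regex is loosened alongside, presumably to cope with the rate strings the adjusted smoothing test now produces: the old pattern required a fractional part (`\d+\.\d+`), while the new one anchors the capture at the start of the number with a leading `[^\d]` and accepts any digit/dot run of two or more characters. A quick illustrative check (the sample strings are made up, not taken from the test suite):

    import re

    RE_rate_old = re.compile(r'(\d+\.\d+)it/s')
    RE_rate_new = re.compile(r'[^\d](\d[.\d]+)it/s')

    for line in (', 14.02it/s]', ', 14it/s]'):  # hypothetical status-line fragments
        old = RE_rate_old.search(line)
        new = RE_rate_new.search(line)
        print(old.group(1) if old else None, new.group(1) if new else None)
    # prints: 14.02 14.02
    #         None 14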
@@ -1017,7 +1017,6 @@ def test_smoothing():
         assert '| 3/3 ' in our_file.getvalue()
 
     # -- Test smoothing
-    # Compile the regex to find the rate
     # 1st case: no smoothing (only use average)
     with closing(StringIO()) as our_file2:
         with closing(StringIO()) as our_file:
@@ -1074,11 +1073,11 @@ def test_smoothing():
     # 3rd case: use medium smoothing
     with closing(StringIO()) as our_file2:
         with closing(StringIO()) as our_file:
-            t = tqdm(_range(3), file=our_file2, smoothing=0.5, leave=True,
+            t = tqdm(_range(3), file=our_file2, smoothing=0.8, leave=True,
                      miniters=1, mininterval=0)
             cpu_timify(t, timer)
 
-            t2 = tqdm(_range(3), file=our_file, smoothing=0.5, leave=True,
+            t2 = tqdm(_range(3), file=our_file, smoothing=0.8, leave=True,
                       miniters=1, mininterval=0)
             cpu_timify(t2, timer)

Changed file 2 of 2:

@@ -295,8 +295,7 @@ class tqdm(Comparable):
         n = str(n)
         return f if len(f) < len(n) else n
 
-    @staticmethod
-    def ema(x, mu=None, alpha=0.3):
+    def ema(self, x, mu=0, alpha=0.3):
         """
         Exponential moving average: smoothing to give progressively lower
         weights to older values.
@@ -312,7 +311,9 @@ class tqdm(Comparable):
             Increase to give more weight to recent values.
            Ranges from 0 (yields mu) to 1 (yields x).
         """
-        return x if mu is None else (alpha * x) + (1 - alpha) * mu
+        beta = 1 - alpha
+        res = alpha * x + beta * mu
+        return res / (1 - beta ** self.n) if self.n else res
 
     @staticmethod
     def status_printer(file):
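
One property of the new formula that is easy to check by hand: on the first measurement (`self.n == 1`) the numerator `alpha * x` is divided by `1 - beta ** 1 == alpha`, so the corrected value is exactly `x` whatever the smoothing factor; as `n` grows the divisor tends to 1 and the method behaves like a plain EMA, while `n == 0` returns the uncorrected value. This is also why `avg_time` is initialised to `0` instead of `None` further down. A tiny stand-in to see it run (the `Dummy` class is hypothetical and only mimics the `n` counter the method relies on):

    class Dummy:
        """Hypothetical stand-in exposing only the `n` counter used by ema()."""
        n = 1

        def ema(self, x, mu=0, alpha=0.3):
            # same formula as the patched tqdm.ema above
            beta = 1 - alpha
            res = alpha * x + beta * mu
            return res / (1 - beta ** self.n) if self.n else res

    d = Dummy()
    print(d.ema(0.25, alpha=0.5))  # 0.25 -> first corrected reading equals the raw measurement
    d.n = 0
    print(d.ema(0.25, alpha=0.5))  # 0.125 -> n == 0 falls back to the uncorrected EMA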
@@ -1047,7 +1048,7 @@ class tqdm(Comparable):
         self.gui = gui
         self.dynamic_ncols = dynamic_ncols
         self.smoothing = smoothing
-        self.avg_time = None
+        self.avg_time = 0
         self.bar_format = bar_format
         self.postfix = None
         self.colour = colour
@@ -1142,7 +1143,6 @@ class tqdm(Comparable):
         last_print_n = self.last_print_n
         n = self.n
         smoothing = self.smoothing
-        avg_time = self.avg_time
         time = self._time
 
         try:
@@ -1154,40 +1154,38 @@
                 # check counter first to avoid calls to time()
                 if n - last_print_n >= self.miniters:
                     miniters = self.miniters  # watch monitoring thread changes
-                    delta_t = time() - last_print_t
-                    if delta_t >= mininterval:
+                    dt = time() - last_print_t
+                    if dt >= mininterval:
                         cur_t = time()
-                        delta_it = n - last_print_n
-                        # EMA (not just overall average)
-                        if smoothing and delta_t and delta_it:
-                            rate = delta_t / delta_it
-                            avg_time = self.ema(rate, avg_time, smoothing)
-                            self.avg_time = avg_time
+                        dn = n - last_print_n
+                        self.n = n
+                        # EMA (not just overall average)
+                        if smoothing and dt and dn:
+                            self.avg_time = self.ema(dt / dn, self.avg_time, smoothing)
 
-                        self.n = n
                         self.refresh(lock_args=self.lock_args)
 
                         # If no `miniters` was specified, adjust automatically
                         # to the max iteration rate seen so far between 2 prints
                         if dynamic_miniters:
-                            if maxinterval and delta_t >= maxinterval:
+                            if maxinterval and dt >= maxinterval:
                                 # Adjust miniters to time interval by rule of 3
                                 if mininterval:
                                     # Set miniters to correspond to mininterval
-                                    miniters = delta_it * mininterval / delta_t
+                                    miniters = dn * mininterval / dt
                                 else:
                                     # Set miniters to correspond to maxinterval
-                                    miniters = delta_it * maxinterval / delta_t
+                                    miniters = dn * maxinterval / dt
                             elif smoothing:
                                 # EMA-weight miniters to converge
                                 # towards the timeframe of mininterval
-                                rate = delta_it
-                                if mininterval and delta_t:
-                                    rate *= mininterval / delta_t
-                                miniters = self.ema(rate, miniters, smoothing)
+                                miniters = self.ema(
+                                    dn * (mininterval / dt
+                                          if mininterval and dt else 1),
+                                    miniters, smoothing)
                             else:
                                 # Maximum nb of iterations between 2 prints
-                                miniters = max(miniters, delta_it)
+                                miniters = max(miniters, dn)
 
                         # Store old values for next call
                         self.n = self.last_print_n = last_print_n = n
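
In the `dynamic_miniters` branch, the rule-of-three assignments above pick `miniters` so that, at the rate just observed, roughly one refresh happens per `mininterval` (or `maxinterval`), and the `elif smoothing` branch now routes through the same `ema()` helper instead of the previous hand-rolled weighted sum. A worked example of the rule of three with made-up numbers:

    # Hypothetical figures, purely to illustrate `miniters = dn * mininterval / dt` above:
    # 1000 iterations arrived in 2 s since the last refresh, and mininterval is 0.1 s.
    dn, dt, mininterval = 1000, 2.0, 0.1
    miniters = dn * mininterval / dt
    print(miniters)  # 50.0 -> at ~500 it/s, refreshing every 50 iterations is about one refresh per 0.1 s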
@@ -1237,15 +1235,14 @@ class tqdm(Comparable):
        # check counter first to reduce calls to time()
        if self.n - self.last_print_n >= self.miniters:
-            delta_t = self._time() - self.last_print_t
-            if delta_t >= self.mininterval:
+            dt = self._time() - self.last_print_t
+            if dt >= self.mininterval:
                 cur_t = self._time()
-                delta_it = self.n - self.last_print_n  # >= n
+                dn = self.n - self.last_print_n  # >= n
                 # elapsed = cur_t - self.start_t
                 # EMA (not just overall average)
-                if self.smoothing and delta_t and delta_it:
-                    rate = delta_t / delta_it
-                    self.avg_time = self.ema(rate, self.avg_time, self.smoothing)
+                if self.smoothing and dt and dn:
+                    self.avg_time = self.ema(dt / dn, self.avg_time, self.smoothing)
                 self.refresh(lock_args=self.lock_args)
@@ -1255,21 +1252,17 @@
                # calls to `tqdm.update()` will only cause an update after
                # at least 5 more iterations.
                if self.dynamic_miniters:
-                    if self.maxinterval and delta_t >= self.maxinterval:
+                    if self.maxinterval and dt >= self.maxinterval:
                         if self.mininterval:
-                            self.miniters = delta_it * self.mininterval \
-                                / delta_t
+                            self.miniters = dn * self.mininterval / dt
                         else:
-                            self.miniters = delta_it * self.maxinterval \
-                                / delta_t
+                            self.miniters = dn * self.maxinterval / dt
                     elif self.smoothing:
-                        self.miniters = self.smoothing * delta_it * \
-                            (self.mininterval / delta_t
-                             if self.mininterval and delta_t
-                             else 1) + \
-                            (1 - self.smoothing) * self.miniters
+                        self.miniters = self.ema(
+                            dn * (self.mininterval / dt if self.mininterval and dt
+                                  else 1), self.miniters, self.smoothing)
                     else:
-                        self.miniters = max(self.miniters, delta_it)
+                        self.miniters = max(self.miniters, dn)
 
                # Store old values for next call
                self.last_print_n = self.n
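
For context, the `update()` path patched above is the one exercised by ordinary manual-update usage such as the following (standard public tqdm API; the sleep is just a stand-in for real work):

    from time import sleep
    from tqdm import tqdm

    # `smoothing` weights recent per-item times:
    # 0 -> plain overall average, 1 -> most recent interval only (default 0.3).
    with tqdm(total=100, smoothing=0.3, mininterval=0.1) as pbar:
        for _ in range(100):
            sleep(0.01)     # stand-in for real work
            pbar.update(1)  # the displayed rate comes from the EMA of dt / dn above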