diff --git a/tests/tests_tqdm.py b/tests/tests_tqdm.py index f07bdf56..91b65ade 100644 --- a/tests/tests_tqdm.py +++ b/tests/tests_tqdm.py @@ -15,7 +15,7 @@ from warnings import catch_warnings, simplefilter from tqdm import tqdm from tqdm import trange from tqdm import TqdmDeprecationWarning, TqdmWarning -from tqdm.std import Bar +from tqdm.std import Bar, EMA from tqdm.contrib import DummyTqdmFile try: @@ -1004,6 +1004,16 @@ def test_close(): t.close() +def test_ema(): + """Test exponential weighted average""" + ema = EMA(0.01) + assert round(ema(10), 2) == 10 + assert round(ema(1), 2) == 5.48 + assert round(ema(), 2) == 5.48 + assert round(ema(1), 2) == 3.97 + assert round(ema(1), 2) == 3.22 + + def test_smoothing(): """Test exponential weighted average smoothing""" timer = DiscreteTimer()