From a9e9e665f099a85db77b5d774e729cada46cd66f Mon Sep 17 00:00:00 2001 From: Casper da Costa-Luis Date: Thu, 19 Dec 2019 21:52:48 +0000 Subject: [PATCH] update tests --- tqdm/tests/tests_keras.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/tqdm/tests/tests_keras.py b/tqdm/tests/tests_keras.py index f5d38d23..80a65bfe 100644 --- a/tqdm/tests/tests_keras.py +++ b/tqdm/tests/tests_keras.py @@ -1,5 +1,4 @@ from tqdm import tqdm -from tqdm.keras import TqdmCallback from tests_tqdm import with_setup, pretest, posttest, SkipTest, StringIO, closing @@ -7,6 +6,7 @@ from tests_tqdm import with_setup, pretest, posttest, SkipTest, StringIO, closin def test_keras(): """Test tqdm.keras.TqdmCallback""" try: + from tqdm.keras import TqdmCallback import numpy as np import keras as K except ImportError: @@ -74,3 +74,18 @@ def test_keras(): res = our_file.getvalue() assert res.count("100%") >= epochs + 1 assert "{epochs}/{epochs}".format(epochs=epochs) in res + + # auto-detect epochs and batches + our_file.seek(0) + our_file.truncate() + model.fit( + x, + x, + epochs=epochs, + batch_size=batch_size, + verbose=False, + callbacks=[TqdmCallback(verbose=2, tqdm_class=Tqdm)], + ) + res = our_file.getvalue() + assert res.count("100%") >= epochs + 1 + assert "{epochs}/{epochs}".format(epochs=epochs) in res