update tests

This commit is contained in:
Casper da Costa-Luis 2019-12-19 21:52:48 +00:00
parent e5863ffa56
commit a9e9e665f0
No known key found for this signature in database
GPG Key ID: 986B408043AE090D
1 changed files with 16 additions and 1 deletions

View File

@ -1,5 +1,4 @@
from tqdm import tqdm
from tqdm.keras import TqdmCallback
from tests_tqdm import with_setup, pretest, posttest, SkipTest, StringIO, closing
@ -7,6 +6,7 @@ from tests_tqdm import with_setup, pretest, posttest, SkipTest, StringIO, closin
def test_keras():
"""Test tqdm.keras.TqdmCallback"""
try:
from tqdm.keras import TqdmCallback
import numpy as np
import keras as K
except ImportError:
@ -74,3 +74,18 @@ def test_keras():
res = our_file.getvalue()
assert res.count("100%") >= epochs + 1
assert "{epochs}/{epochs}".format(epochs=epochs) in res
# auto-detect epochs and batches
our_file.seek(0)
our_file.truncate()
model.fit(
x,
x,
epochs=epochs,
batch_size=batch_size,
verbose=False,
callbacks=[TqdmCallback(verbose=2, tqdm_class=Tqdm)],
)
res = our_file.getvalue()
assert res.count("100%") >= epochs + 1
assert "{epochs}/{epochs}".format(epochs=epochs) in res