tests: test kwargs for `keras`, `dask`

This commit is contained in:
Casper da Costa-Luis 2021-03-05 14:13:28 +00:00
parent 2a82405da5
commit 299789b910
No known key found for this signature in database
GPG Key ID: 986B408043AE090D
2 changed files with 57 additions and 67 deletions

View File

@ -13,7 +13,8 @@ def test_dask(capsys):
dask = importorskip('dask') dask = importorskip('dask')
schedule = [dask.delayed(sleep)(i / 10) for i in range(5)] schedule = [dask.delayed(sleep)(i / 10) for i in range(5)]
with ProgressBar(): with ProgressBar(desc="computing"):
dask.compute(schedule) dask.compute(schedule)
_, err = capsys.readouterr() _, err = capsys.readouterr()
assert "computing: " in err
assert '5/5' in err assert '5/5' in err

View File

@ -1,14 +1,12 @@
from __future__ import division from __future__ import division
from tqdm import tqdm from .tests_tqdm import importorskip, mark
from .tests_tqdm import StringIO, closing, importorskip, mark
pytestmark = mark.slow pytestmark = mark.slow
@mark.filterwarnings("ignore:.*:DeprecationWarning") @mark.filterwarnings("ignore:.*:DeprecationWarning")
def test_keras(): def test_keras(capsys):
"""Test tqdm.keras.TqdmCallback""" """Test tqdm.keras.TqdmCallback"""
TqdmCallback = importorskip('tqdm.keras').TqdmCallback TqdmCallback = importorskip('tqdm.keras').TqdmCallback
np = importorskip('numpy') np = importorskip('numpy')
@ -27,67 +25,58 @@ def test_keras():
batches = len(x) / batch_size batches = len(x) / batch_size
epochs = 5 epochs = 5
with closing(StringIO()) as our_file: # just epoch (no batch) progress
model.fit(
x,
x,
epochs=epochs,
batch_size=batch_size,
verbose=False,
callbacks=[
TqdmCallback(
epochs,
desc="training",
data_size=len(x),
batch_size=batch_size,
verbose=0,
)],
)
_, res = capsys.readouterr()
assert "training: " in res
assert "{epochs}/{epochs}".format(epochs=epochs) in res
assert "{batches}/{batches}".format(batches=batches) not in res
class Tqdm(tqdm): # full (epoch and batch) progress
"""redirected I/O class""" model.fit(
def __init__(self, *a, **k): x,
k.setdefault("file", our_file) x,
super(Tqdm, self).__init__(*a, **k) epochs=epochs,
batch_size=batch_size,
verbose=False,
callbacks=[
TqdmCallback(
epochs,
desc="training",
data_size=len(x),
batch_size=batch_size,
verbose=2,
)],
)
_, res = capsys.readouterr()
assert "training: " in res
assert "{epochs}/{epochs}".format(epochs=epochs) in res
assert "{batches}/{batches}".format(batches=batches) in res
# just epoch (no batch) progress # auto-detect epochs and batches
model.fit( model.fit(
x, x,
x, x,
epochs=epochs, epochs=epochs,
batch_size=batch_size, batch_size=batch_size,
verbose=False, verbose=False,
callbacks=[ callbacks=[TqdmCallback(desc="training", verbose=2)],
TqdmCallback( )
epochs, _, res = capsys.readouterr()
data_size=len(x), assert "training: " in res
batch_size=batch_size, assert "{epochs}/{epochs}".format(epochs=epochs) in res
verbose=0, assert "{batches}/{batches}".format(batches=batches) in res
tqdm_class=Tqdm,
)],
)
res = our_file.getvalue()
assert "{epochs}/{epochs}".format(epochs=epochs) in res
assert "{batches}/{batches}".format(batches=batches) not in res
# full (epoch and batch) progress
our_file.seek(0)
our_file.truncate()
model.fit(
x,
x,
epochs=epochs,
batch_size=batch_size,
verbose=False,
callbacks=[
TqdmCallback(
epochs,
data_size=len(x),
batch_size=batch_size,
verbose=2,
tqdm_class=Tqdm,
)],
)
res = our_file.getvalue()
assert "{epochs}/{epochs}".format(epochs=epochs) in res
assert "{batches}/{batches}".format(batches=batches) in res
# auto-detect epochs and batches
our_file.seek(0)
our_file.truncate()
model.fit(
x,
x,
epochs=epochs,
batch_size=batch_size,
verbose=False,
callbacks=[TqdmCallback(verbose=2, tqdm_class=Tqdm)],
)
res = our_file.getvalue()
assert "{epochs}/{epochs}".format(epochs=epochs) in res
assert "{batches}/{batches}".format(batches=batches) in res