tests: test kwargs for `keras`, `dask`

This commit is contained in:
Casper da Costa-Luis 2021-03-05 14:13:28 +00:00
parent 2a82405da5
commit 299789b910
No known key found for this signature in database
GPG Key ID: 986B408043AE090D
2 changed files with 57 additions and 67 deletions

View File

@ -13,7 +13,8 @@ def test_dask(capsys):
dask = importorskip('dask')
schedule = [dask.delayed(sleep)(i / 10) for i in range(5)]
with ProgressBar():
with ProgressBar(desc="computing"):
dask.compute(schedule)
_, err = capsys.readouterr()
assert "computing: " in err
assert '5/5' in err

View File

@ -1,14 +1,12 @@
from __future__ import division
from tqdm import tqdm
from .tests_tqdm import StringIO, closing, importorskip, mark
from .tests_tqdm import importorskip, mark
pytestmark = mark.slow
@mark.filterwarnings("ignore:.*:DeprecationWarning")
def test_keras():
def test_keras(capsys):
"""Test tqdm.keras.TqdmCallback"""
TqdmCallback = importorskip('tqdm.keras').TqdmCallback
np = importorskip('numpy')
@ -27,67 +25,58 @@ def test_keras():
batches = len(x) / batch_size
epochs = 5
with closing(StringIO()) as our_file:
# just epoch (no batch) progress
model.fit(
x,
x,
epochs=epochs,
batch_size=batch_size,
verbose=False,
callbacks=[
TqdmCallback(
epochs,
desc="training",
data_size=len(x),
batch_size=batch_size,
verbose=0,
)],
)
_, res = capsys.readouterr()
assert "training: " in res
assert "{epochs}/{epochs}".format(epochs=epochs) in res
assert "{batches}/{batches}".format(batches=batches) not in res
class Tqdm(tqdm):
"""redirected I/O class"""
def __init__(self, *a, **k):
k.setdefault("file", our_file)
super(Tqdm, self).__init__(*a, **k)
# full (epoch and batch) progress
model.fit(
x,
x,
epochs=epochs,
batch_size=batch_size,
verbose=False,
callbacks=[
TqdmCallback(
epochs,
desc="training",
data_size=len(x),
batch_size=batch_size,
verbose=2,
)],
)
_, res = capsys.readouterr()
assert "training: " in res
assert "{epochs}/{epochs}".format(epochs=epochs) in res
assert "{batches}/{batches}".format(batches=batches) in res
# just epoch (no batch) progress
model.fit(
x,
x,
epochs=epochs,
batch_size=batch_size,
verbose=False,
callbacks=[
TqdmCallback(
epochs,
data_size=len(x),
batch_size=batch_size,
verbose=0,
tqdm_class=Tqdm,
)],
)
res = our_file.getvalue()
assert "{epochs}/{epochs}".format(epochs=epochs) in res
assert "{batches}/{batches}".format(batches=batches) not in res
# full (epoch and batch) progress
our_file.seek(0)
our_file.truncate()
model.fit(
x,
x,
epochs=epochs,
batch_size=batch_size,
verbose=False,
callbacks=[
TqdmCallback(
epochs,
data_size=len(x),
batch_size=batch_size,
verbose=2,
tqdm_class=Tqdm,
)],
)
res = our_file.getvalue()
assert "{epochs}/{epochs}".format(epochs=epochs) in res
assert "{batches}/{batches}".format(batches=batches) in res
# auto-detect epochs and batches
our_file.seek(0)
our_file.truncate()
model.fit(
x,
x,
epochs=epochs,
batch_size=batch_size,
verbose=False,
callbacks=[TqdmCallback(verbose=2, tqdm_class=Tqdm)],
)
res = our_file.getvalue()
assert "{epochs}/{epochs}".format(epochs=epochs) in res
assert "{batches}/{batches}".format(batches=batches) in res
# auto-detect epochs and batches
model.fit(
x,
x,
epochs=epochs,
batch_size=batch_size,
verbose=False,
callbacks=[TqdmCallback(desc="training", verbose=2)],
)
_, res = capsys.readouterr()
assert "training: " in res
assert "{epochs}/{epochs}".format(epochs=epochs) in res
assert "{batches}/{batches}".format(batches=batches) in res