From 299789b910ca845c429998db0beb3f20ab5c7d15 Mon Sep 17 00:00:00 2001 From: Casper da Costa-Luis Date: Fri, 5 Mar 2021 14:13:28 +0000 Subject: [PATCH] tests: test kwargs for `keras`, `dask` --- tests/tests_dask.py | 3 +- tests/tests_keras.py | 121 ++++++++++++++++++++----------------------- 2 files changed, 57 insertions(+), 67 deletions(-) diff --git a/tests/tests_dask.py b/tests/tests_dask.py index 2834abd6..8bf4b64f 100644 --- a/tests/tests_dask.py +++ b/tests/tests_dask.py @@ -13,7 +13,8 @@ def test_dask(capsys): dask = importorskip('dask') schedule = [dask.delayed(sleep)(i / 10) for i in range(5)] - with ProgressBar(): + with ProgressBar(desc="computing"): dask.compute(schedule) _, err = capsys.readouterr() + assert "computing: " in err assert '5/5' in err diff --git a/tests/tests_keras.py b/tests/tests_keras.py index 6d044207..b26cdbb7 100644 --- a/tests/tests_keras.py +++ b/tests/tests_keras.py @@ -1,14 +1,12 @@ from __future__ import division -from tqdm import tqdm - -from .tests_tqdm import StringIO, closing, importorskip, mark +from .tests_tqdm import importorskip, mark pytestmark = mark.slow @mark.filterwarnings("ignore:.*:DeprecationWarning") -def test_keras(): +def test_keras(capsys): """Test tqdm.keras.TqdmCallback""" TqdmCallback = importorskip('tqdm.keras').TqdmCallback np = importorskip('numpy') @@ -27,67 +25,58 @@ def test_keras(): batches = len(x) / batch_size epochs = 5 - with closing(StringIO()) as our_file: + # just epoch (no batch) progress + model.fit( + x, + x, + epochs=epochs, + batch_size=batch_size, + verbose=False, + callbacks=[ + TqdmCallback( + epochs, + desc="training", + data_size=len(x), + batch_size=batch_size, + verbose=0, + )], + ) + _, res = capsys.readouterr() + assert "training: " in res + assert "{epochs}/{epochs}".format(epochs=epochs) in res + assert "{batches}/{batches}".format(batches=batches) not in res - class Tqdm(tqdm): - """redirected I/O class""" - def __init__(self, *a, **k): - k.setdefault("file", our_file) - super(Tqdm, self).__init__(*a, **k) + # full (epoch and batch) progress + model.fit( + x, + x, + epochs=epochs, + batch_size=batch_size, + verbose=False, + callbacks=[ + TqdmCallback( + epochs, + desc="training", + data_size=len(x), + batch_size=batch_size, + verbose=2, + )], + ) + _, res = capsys.readouterr() + assert "training: " in res + assert "{epochs}/{epochs}".format(epochs=epochs) in res + assert "{batches}/{batches}".format(batches=batches) in res - # just epoch (no batch) progress - model.fit( - x, - x, - epochs=epochs, - batch_size=batch_size, - verbose=False, - callbacks=[ - TqdmCallback( - epochs, - data_size=len(x), - batch_size=batch_size, - verbose=0, - tqdm_class=Tqdm, - )], - ) - res = our_file.getvalue() - assert "{epochs}/{epochs}".format(epochs=epochs) in res - assert "{batches}/{batches}".format(batches=batches) not in res - - # full (epoch and batch) progress - our_file.seek(0) - our_file.truncate() - model.fit( - x, - x, - epochs=epochs, - batch_size=batch_size, - verbose=False, - callbacks=[ - TqdmCallback( - epochs, - data_size=len(x), - batch_size=batch_size, - verbose=2, - tqdm_class=Tqdm, - )], - ) - res = our_file.getvalue() - assert "{epochs}/{epochs}".format(epochs=epochs) in res - assert "{batches}/{batches}".format(batches=batches) in res - - # auto-detect epochs and batches - our_file.seek(0) - our_file.truncate() - model.fit( - x, - x, - epochs=epochs, - batch_size=batch_size, - verbose=False, - callbacks=[TqdmCallback(verbose=2, tqdm_class=Tqdm)], - ) - res = our_file.getvalue() - assert "{epochs}/{epochs}".format(epochs=epochs) in res - assert "{batches}/{batches}".format(batches=batches) in res + # auto-detect epochs and batches + model.fit( + x, + x, + epochs=epochs, + batch_size=batch_size, + verbose=False, + callbacks=[TqdmCallback(desc="training", verbose=2)], + ) + _, res = capsys.readouterr() + assert "training: " in res + assert "{epochs}/{epochs}".format(epochs=epochs) in res + assert "{batches}/{batches}".format(batches=batches) in res