mirror of https://github.com/tqdm/tqdm.git
keras: fix resume from `initial_epoch`
- set initial epochs - update tests
This commit is contained in:
parent
3f899b9f77
commit
74ec622661
|
@ -38,9 +38,7 @@ def test_keras(capsys):
|
||||||
desc="training",
|
desc="training",
|
||||||
data_size=len(x),
|
data_size=len(x),
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
verbose=0,
|
verbose=0)])
|
||||||
)],
|
|
||||||
)
|
|
||||||
_, res = capsys.readouterr()
|
_, res = capsys.readouterr()
|
||||||
assert "training: " in res
|
assert "training: " in res
|
||||||
assert "{epochs}/{epochs}".format(epochs=epochs) in res
|
assert "{epochs}/{epochs}".format(epochs=epochs) in res
|
||||||
|
@ -59,9 +57,7 @@ def test_keras(capsys):
|
||||||
desc="training",
|
desc="training",
|
||||||
data_size=len(x),
|
data_size=len(x),
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
verbose=2,
|
verbose=2)])
|
||||||
)],
|
|
||||||
)
|
|
||||||
_, res = capsys.readouterr()
|
_, res = capsys.readouterr()
|
||||||
assert "training: " in res
|
assert "training: " in res
|
||||||
assert "{epochs}/{epochs}".format(epochs=epochs) in res
|
assert "{epochs}/{epochs}".format(epochs=epochs) in res
|
||||||
|
@ -74,8 +70,7 @@ def test_keras(capsys):
|
||||||
epochs=epochs,
|
epochs=epochs,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
verbose=False,
|
verbose=False,
|
||||||
callbacks=[TqdmCallback(desc="training", verbose=2)],
|
callbacks=[TqdmCallback(desc="training", verbose=2)])
|
||||||
)
|
|
||||||
_, res = capsys.readouterr()
|
_, res = capsys.readouterr()
|
||||||
assert "training: " in res
|
assert "training: " in res
|
||||||
assert "{epochs}/{epochs}".format(epochs=epochs) in res
|
assert "{epochs}/{epochs}".format(epochs=epochs) in res
|
||||||
|
@ -90,12 +85,9 @@ def test_keras(capsys):
|
||||||
epochs=epochs,
|
epochs=epochs,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
verbose=False,
|
verbose=False,
|
||||||
callbacks=[TqdmCallback(
|
callbacks=[TqdmCallback(desc="training", verbose=0,
|
||||||
desc="training",
|
miniters=1, mininterval=0, maxinterval=0)])
|
||||||
verbose=2
|
|
||||||
)],
|
|
||||||
)
|
|
||||||
_, res = capsys.readouterr()
|
_, res = capsys.readouterr()
|
||||||
assert "training: " in res
|
assert "training: " in res
|
||||||
|
assert "{epochs}/{epochs}".format(epochs=initial_epoch - 1) not in res
|
||||||
assert "{epochs}/{epochs}".format(epochs=epochs) in res
|
assert "{epochs}/{epochs}".format(epochs=epochs) in res
|
||||||
assert "{batches}/{batches}".format(batches=batches) in res
|
|
||||||
|
|
|
@ -74,7 +74,8 @@ class TqdmCallback(keras.callbacks.Callback):
|
||||||
|
|
||||||
def on_epoch_begin(self, epoch, *_, **__):
|
def on_epoch_begin(self, epoch, *_, **__):
|
||||||
if self.epoch_bar.n < epoch:
|
if self.epoch_bar.n < epoch:
|
||||||
self.epoch_bar.update(epoch-self.epoch_bar.n)
|
ebar = self.epoch_bar
|
||||||
|
ebar.n = ebar.last_print_n = ebar.initial = epoch
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
params = self.params.get
|
params = self.params.get
|
||||||
total = params('samples', params(
|
total = params('samples', params(
|
||||||
|
|
Loading…
Reference in New Issue