fix: wrong epoch displayed when resuming keras training

Author: Gonzalo Tixilima, 2021-03-19 11:18:43 -05:00
Committer: Casper da Costa-Luis
parent d9372a2d47
commit 3f899b9f77
2 changed files with 23 additions and 2 deletions


@@ -80,3 +80,22 @@ def test_keras(capsys):
     assert "training: " in res
     assert "{epochs}/{epochs}".format(epochs=epochs) in res
     assert "{batches}/{batches}".format(batches=batches) in res
+
+    # continue training (start from epoch != 0)
+    initial_epoch = 3
+    model.fit(
+        x,
+        x,
+        initial_epoch=initial_epoch,
+        epochs=epochs,
+        batch_size=batch_size,
+        verbose=False,
+        callbacks=[TqdmCallback(
+            desc="training",
+            verbose=2
+        )],
+    )
+    _, res = capsys.readouterr()
+    assert "training: " in res
+    assert "{epochs}/{epochs}".format(epochs=epochs) in res
+    assert "{batches}/{batches}".format(batches=batches) in res
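For context, the pattern the new assertions exercise looks roughly like the sketch below. This is a minimal stand-alone sketch, not the test itself: the toy model and data here are assumed placeholders for the fixtures defined earlier in test_keras.

# minimal sketch of resuming training with TqdmCallback (assumed toy model/data,
# not the fixtures used by test_keras)
import numpy as np
from tensorflow import keras
from tqdm.keras import TqdmCallback

x = np.random.rand(100, 1)
model = keras.Sequential([keras.layers.Dense(1, input_shape=(1,))])
model.compile(optimizer="sgd", loss="mse")

epochs, batch_size = 10, 16
# first run covers epochs 0..9
model.fit(x, x, epochs=epochs, batch_size=batch_size, verbose=False,
          callbacks=[TqdmCallback(desc="training", verbose=2)])
# resumed run starts at epoch 3; with this fix the epoch bar begins at 3/10
# instead of restarting from 0/10
model.fit(x, x, initial_epoch=3, epochs=epochs, batch_size=batch_size,
          verbose=False, callbacks=[TqdmCallback(desc="training", verbose=2)])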


@@ -69,10 +69,12 @@ class TqdmCallback(keras.callbacks.Callback):
     def on_train_begin(self, *_, **__):
         params = self.params.get
         auto_total = params('epochs', params('nb_epoch', None))
-        if auto_total is not None:
+        if auto_total is not None and auto_total != self.epoch_bar.total:
             self.epoch_bar.reset(total=auto_total)

-    def on_epoch_begin(self, *_, **__):
+    def on_epoch_begin(self, epoch, *_, **__):
+        if self.epoch_bar.n < epoch:
+            self.epoch_bar.update(epoch-self.epoch_bar.n)
         if self.verbose:
             params = self.params.get
             total = params('samples', params(
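The change is two-fold: on_train_begin now skips the reset when the epoch total is unchanged, and on_epoch_begin fast-forwards the bar to the epoch index Keras reports, so a run resumed with initial_epoch > 0 displays the correct epoch. The snippet below is only an illustrative sketch of the underlying tqdm bar arithmetic, using a plain tqdm bar rather than the callback itself:

# illustrative sketch of the bar arithmetic behind the fix (plain tqdm bar)
from tqdm import tqdm

epochs, initial_epoch = 10, 3
epoch_bar = tqdm(total=epochs, desc="training")

# what on_epoch_begin(initial_epoch) now does on a resumed run:
if epoch_bar.n < initial_epoch:
    epoch_bar.update(initial_epoch - epoch_bar.n)  # bar jumps to 3/10

# an unconditional reset would discard that position and show 0/10 again;
# the new guard in on_train_begin skips the reset when the total is unchanged:
# epoch_bar.reset(total=epochs)

epoch_bar.close()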