mirror of https://github.com/tqdm/tqdm.git
fix: display wrong epoch on keras resume training
This commit is contained in:
parent
d9372a2d47
commit
3f899b9f77
|
@ -80,3 +80,22 @@ def test_keras(capsys):
|
|||
assert "training: " in res
|
||||
assert "{epochs}/{epochs}".format(epochs=epochs) in res
|
||||
assert "{batches}/{batches}".format(batches=batches) in res
|
||||
|
||||
# continue training (start from epoch != 0)
|
||||
initial_epoch = 3
|
||||
model.fit(
|
||||
x,
|
||||
x,
|
||||
initial_epoch=initial_epoch,
|
||||
epochs=epochs,
|
||||
batch_size=batch_size,
|
||||
verbose=False,
|
||||
callbacks=[TqdmCallback(
|
||||
desc="training",
|
||||
verbose=2
|
||||
)],
|
||||
)
|
||||
_, res = capsys.readouterr()
|
||||
assert "training: " in res
|
||||
assert "{epochs}/{epochs}".format(epochs=epochs) in res
|
||||
assert "{batches}/{batches}".format(batches=batches) in res
|
||||
|
|
|
@ -69,10 +69,12 @@ class TqdmCallback(keras.callbacks.Callback):
|
|||
def on_train_begin(self, *_, **__):
|
||||
params = self.params.get
|
||||
auto_total = params('epochs', params('nb_epoch', None))
|
||||
if auto_total is not None:
|
||||
if auto_total is not None and auto_total != self.epoch_bar.total:
|
||||
self.epoch_bar.reset(total=auto_total)
|
||||
|
||||
def on_epoch_begin(self, *_, **__):
|
||||
def on_epoch_begin(self, epoch, *_, **__):
|
||||
if self.epoch_bar.n < epoch:
|
||||
self.epoch_bar.update(epoch-self.epoch_bar.n)
|
||||
if self.verbose:
|
||||
params = self.params.get
|
||||
total = params('samples', params(
|
||||
|
|
Loading…
Reference in New Issue