keras: fix resume from `initial_epoch`

- set initial epochs
- update tests
This commit is contained in:
Casper da Costa-Luis 2021-05-03 00:52:05 +01:00
parent 3f899b9f77
commit 74ec622661
No known key found for this signature in database
GPG Key ID: F5126E5FBD2512AD
2 changed files with 8 additions and 15 deletions

View File

@ -38,9 +38,7 @@ def test_keras(capsys):
desc="training", desc="training",
data_size=len(x), data_size=len(x),
batch_size=batch_size, batch_size=batch_size,
verbose=0, verbose=0)])
)],
)
_, res = capsys.readouterr() _, res = capsys.readouterr()
assert "training: " in res assert "training: " in res
assert "{epochs}/{epochs}".format(epochs=epochs) in res assert "{epochs}/{epochs}".format(epochs=epochs) in res
@ -59,9 +57,7 @@ def test_keras(capsys):
desc="training", desc="training",
data_size=len(x), data_size=len(x),
batch_size=batch_size, batch_size=batch_size,
verbose=2, verbose=2)])
)],
)
_, res = capsys.readouterr() _, res = capsys.readouterr()
assert "training: " in res assert "training: " in res
assert "{epochs}/{epochs}".format(epochs=epochs) in res assert "{epochs}/{epochs}".format(epochs=epochs) in res
@ -74,8 +70,7 @@ def test_keras(capsys):
epochs=epochs, epochs=epochs,
batch_size=batch_size, batch_size=batch_size,
verbose=False, verbose=False,
callbacks=[TqdmCallback(desc="training", verbose=2)], callbacks=[TqdmCallback(desc="training", verbose=2)])
)
_, res = capsys.readouterr() _, res = capsys.readouterr()
assert "training: " in res assert "training: " in res
assert "{epochs}/{epochs}".format(epochs=epochs) in res assert "{epochs}/{epochs}".format(epochs=epochs) in res
@ -90,12 +85,9 @@ def test_keras(capsys):
epochs=epochs, epochs=epochs,
batch_size=batch_size, batch_size=batch_size,
verbose=False, verbose=False,
callbacks=[TqdmCallback( callbacks=[TqdmCallback(desc="training", verbose=0,
desc="training", miniters=1, mininterval=0, maxinterval=0)])
verbose=2
)],
)
_, res = capsys.readouterr() _, res = capsys.readouterr()
assert "training: " in res assert "training: " in res
assert "{epochs}/{epochs}".format(epochs=initial_epoch - 1) not in res
assert "{epochs}/{epochs}".format(epochs=epochs) in res assert "{epochs}/{epochs}".format(epochs=epochs) in res
assert "{batches}/{batches}".format(batches=batches) in res

View File

@ -74,7 +74,8 @@ class TqdmCallback(keras.callbacks.Callback):
def on_epoch_begin(self, epoch, *_, **__): def on_epoch_begin(self, epoch, *_, **__):
if self.epoch_bar.n < epoch: if self.epoch_bar.n < epoch:
self.epoch_bar.update(epoch-self.epoch_bar.n) ebar = self.epoch_bar
ebar.n = ebar.last_print_n = ebar.initial = epoch
if self.verbose: if self.verbose:
params = self.params.get params = self.params.get
total = params('samples', params( total = params('samples', params(