fix keras display, expose initial bar(s) `**tqdm_kwargs`

This commit is contained in:
Casper da Costa-Luis 2020-10-25 01:47:41 +01:00
parent 410c136b49
commit 983a05b495
No known key found for this signature in database
GPG Key ID: 986B408043AE090D
1 changed files with 7 additions and 3 deletions

View File

@ -28,7 +28,7 @@ class TqdmCallback(keras.callbacks.Callback):
return callback return callback
def __init__(self, epochs=None, data_size=None, batch_size=None, verbose=1, def __init__(self, epochs=None, data_size=None, batch_size=None, verbose=1,
tqdm_class=tqdm_auto): tqdm_class=tqdm_auto, **tqdm_kwargs):
""" """
Parameters Parameters
---------- ----------
@ -43,9 +43,13 @@ class TqdmCallback(keras.callbacks.Callback):
are given. are given.
tqdm_class : optional tqdm_class : optional
`tqdm` class to use for bars [default: `tqdm.auto.tqdm`]. `tqdm` class to use for bars [default: `tqdm.auto.tqdm`].
tqdm_kwargs : optional
Any other arguments used for initial bars.
Instead, for passing arguments to all bars, create a custom
`tqdm_class`.
""" """
self.tqdm_class = tqdm_class self.tqdm_class = tqdm_class
self.epoch_bar = tqdm_class(total=epochs, unit='epoch') self.epoch_bar = tqdm_class(total=epochs, unit='epoch', **tqdm_kwargs)
self.on_epoch_end = self.bar2callback(self.epoch_bar) self.on_epoch_end = self.bar2callback(self.epoch_bar)
if data_size and batch_size: if data_size and batch_size:
self.batches = batches = (data_size + batch_size - 1) // batch_size self.batches = batches = (data_size + batch_size - 1) // batch_size
@ -54,7 +58,7 @@ class TqdmCallback(keras.callbacks.Callback):
self.verbose = verbose self.verbose = verbose
if verbose == 1: if verbose == 1:
self.batch_bar = tqdm_class(total=batches, unit='batch', self.batch_bar = tqdm_class(total=batches, unit='batch',
leave=False) leave=False, **tqdm_kwargs)
self.on_batch_end = self.bar2callback( self.on_batch_end = self.bar2callback(
self.batch_bar, self.batch_bar,
pop=['batch', 'size'], pop=['batch', 'size'],