diff --git a/tqdm/keras.py b/tqdm/keras.py index 54c99079..b54e7563 100644 --- a/tqdm/keras.py +++ b/tqdm/keras.py @@ -28,7 +28,7 @@ class TqdmCallback(keras.callbacks.Callback): return callback def __init__(self, epochs=None, data_size=None, batch_size=None, verbose=1, - tqdm_class=tqdm_auto): + tqdm_class=tqdm_auto, **tqdm_kwargs): """ Parameters ---------- @@ -43,9 +43,13 @@ class TqdmCallback(keras.callbacks.Callback): are given. tqdm_class : optional `tqdm` class to use for bars [default: `tqdm.auto.tqdm`]. + tqdm_kwargs : optional + Any other arguments used for initial bars. + Instead, for passing arguments to all bars, create a custom + `tqdm_class`. """ self.tqdm_class = tqdm_class - self.epoch_bar = tqdm_class(total=epochs, unit='epoch') + self.epoch_bar = tqdm_class(total=epochs, unit='epoch', **tqdm_kwargs) self.on_epoch_end = self.bar2callback(self.epoch_bar) if data_size and batch_size: self.batches = batches = (data_size + batch_size - 1) // batch_size @@ -54,7 +58,7 @@ class TqdmCallback(keras.callbacks.Callback): self.verbose = verbose if verbose == 1: self.batch_bar = tqdm_class(total=batches, unit='batch', - leave=False) + leave=False, **tqdm_kwargs) self.on_batch_end = self.bar2callback( self.batch_bar, pop=['batch', 'size'],