mirror of https://github.com/tqdm/tqdm.git
fix keras display, expose initial bar(s) `**tqdm_kwargs`
This commit is contained in:
parent
410c136b49
commit
983a05b495
|
@ -28,7 +28,7 @@ class TqdmCallback(keras.callbacks.Callback):
|
|||
return callback
|
||||
|
||||
def __init__(self, epochs=None, data_size=None, batch_size=None, verbose=1,
|
||||
tqdm_class=tqdm_auto):
|
||||
tqdm_class=tqdm_auto, **tqdm_kwargs):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
|
@ -43,9 +43,13 @@ class TqdmCallback(keras.callbacks.Callback):
|
|||
are given.
|
||||
tqdm_class : optional
|
||||
`tqdm` class to use for bars [default: `tqdm.auto.tqdm`].
|
||||
tqdm_kwargs : optional
|
||||
Any other arguments used for initial bars.
|
||||
Instead, for passing arguments to all bars, create a custom
|
||||
`tqdm_class`.
|
||||
"""
|
||||
self.tqdm_class = tqdm_class
|
||||
self.epoch_bar = tqdm_class(total=epochs, unit='epoch')
|
||||
self.epoch_bar = tqdm_class(total=epochs, unit='epoch', **tqdm_kwargs)
|
||||
self.on_epoch_end = self.bar2callback(self.epoch_bar)
|
||||
if data_size and batch_size:
|
||||
self.batches = batches = (data_size + batch_size - 1) // batch_size
|
||||
|
@ -54,7 +58,7 @@ class TqdmCallback(keras.callbacks.Callback):
|
|||
self.verbose = verbose
|
||||
if verbose == 1:
|
||||
self.batch_bar = tqdm_class(total=batches, unit='batch',
|
||||
leave=False)
|
||||
leave=False, **tqdm_kwargs)
|
||||
self.on_batch_end = self.bar2callback(
|
||||
self.batch_bar,
|
||||
pop=['batch', 'size'],
|
||||
|
|
Loading…
Reference in New Issue