keras: use kwargs for all bars

This commit is contained in:
Casper da Costa-Luis 2020-11-05 18:31:02 +00:00
parent 983a05b495
commit 669002c45c
No known key found for this signature in database
GPG Key ID: 986B408043AE090D
1 changed files with 6 additions and 5 deletions

View File

@ -1,6 +1,7 @@
from __future__ import absolute_import, division
from .auto import tqdm as tqdm_auto
from copy import copy
from functools import partial
try:
import keras
except ImportError as e:
@ -44,12 +45,12 @@ class TqdmCallback(keras.callbacks.Callback):
tqdm_class : optional
`tqdm` class to use for bars [default: `tqdm.auto.tqdm`].
tqdm_kwargs : optional
Any other arguments used for initial bars.
Instead, for passing arguments to all bars, create a custom
`tqdm_class`.
Any other arguments used for all bars.
"""
if tqdm_kwargs:
tqdm_class = partial(tqdm_class, **tqdm_kwargs)
self.tqdm_class = tqdm_class
self.epoch_bar = tqdm_class(total=epochs, unit='epoch', **tqdm_kwargs)
self.epoch_bar = tqdm_class(total=epochs, unit='epoch')
self.on_epoch_end = self.bar2callback(self.epoch_bar)
if data_size and batch_size:
self.batches = batches = (data_size + batch_size - 1) // batch_size
@ -58,7 +59,7 @@ class TqdmCallback(keras.callbacks.Callback):
self.verbose = verbose
if verbose == 1:
self.batch_bar = tqdm_class(total=batches, unit='batch',
leave=False, **tqdm_kwargs)
leave=False)
self.on_batch_end = self.bar2callback(
self.batch_bar,
pop=['batch', 'size'],