From 669002c45c79b73ec77a1df2baee8d4caa07c0c0 Mon Sep 17 00:00:00 2001 From: Casper da Costa-Luis Date: Thu, 5 Nov 2020 18:31:02 +0000 Subject: [PATCH] keras: use kwargs for all bars --- tqdm/keras.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tqdm/keras.py b/tqdm/keras.py index b54e7563..c70ed3c3 100644 --- a/tqdm/keras.py +++ b/tqdm/keras.py @@ -1,6 +1,7 @@ from __future__ import absolute_import, division from .auto import tqdm as tqdm_auto from copy import copy +from functools import partial try: import keras except ImportError as e: @@ -44,12 +45,12 @@ class TqdmCallback(keras.callbacks.Callback): tqdm_class : optional `tqdm` class to use for bars [default: `tqdm.auto.tqdm`]. tqdm_kwargs : optional - Any other arguments used for initial bars. - Instead, for passing arguments to all bars, create a custom - `tqdm_class`. + Any other arguments used for all bars. """ + if tqdm_kwargs: + tqdm_class = partial(tqdm_class, **tqdm_kwargs) self.tqdm_class = tqdm_class - self.epoch_bar = tqdm_class(total=epochs, unit='epoch', **tqdm_kwargs) + self.epoch_bar = tqdm_class(total=epochs, unit='epoch') self.on_epoch_end = self.bar2callback(self.epoch_bar) if data_size and batch_size: self.batches = batches = (data_size + batch_size - 1) // batch_size @@ -58,7 +59,7 @@ class TqdmCallback(keras.callbacks.Callback): self.verbose = verbose if verbose == 1: self.batch_bar = tqdm_class(total=batches, unit='batch', - leave=False, **tqdm_kwargs) + leave=False) self.on_batch_end = self.bar2callback( self.batch_bar, pop=['batch', 'size'],