mirror of https://github.com/tqdm/tqdm.git
keras: use kwargs for all bars
This commit is contained in:
parent
983a05b495
commit
669002c45c
|
@ -1,6 +1,7 @@
|
||||||
from __future__ import absolute_import, division
|
from __future__ import absolute_import, division
|
||||||
from .auto import tqdm as tqdm_auto
|
from .auto import tqdm as tqdm_auto
|
||||||
from copy import copy
|
from copy import copy
|
||||||
|
from functools import partial
|
||||||
try:
|
try:
|
||||||
import keras
|
import keras
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
|
@ -44,12 +45,12 @@ class TqdmCallback(keras.callbacks.Callback):
|
||||||
tqdm_class : optional
|
tqdm_class : optional
|
||||||
`tqdm` class to use for bars [default: `tqdm.auto.tqdm`].
|
`tqdm` class to use for bars [default: `tqdm.auto.tqdm`].
|
||||||
tqdm_kwargs : optional
|
tqdm_kwargs : optional
|
||||||
Any other arguments used for initial bars.
|
Any other arguments used for all bars.
|
||||||
Instead, for passing arguments to all bars, create a custom
|
|
||||||
`tqdm_class`.
|
|
||||||
"""
|
"""
|
||||||
|
if tqdm_kwargs:
|
||||||
|
tqdm_class = partial(tqdm_class, **tqdm_kwargs)
|
||||||
self.tqdm_class = tqdm_class
|
self.tqdm_class = tqdm_class
|
||||||
self.epoch_bar = tqdm_class(total=epochs, unit='epoch', **tqdm_kwargs)
|
self.epoch_bar = tqdm_class(total=epochs, unit='epoch')
|
||||||
self.on_epoch_end = self.bar2callback(self.epoch_bar)
|
self.on_epoch_end = self.bar2callback(self.epoch_bar)
|
||||||
if data_size and batch_size:
|
if data_size and batch_size:
|
||||||
self.batches = batches = (data_size + batch_size - 1) // batch_size
|
self.batches = batches = (data_size + batch_size - 1) // batch_size
|
||||||
|
@ -58,7 +59,7 @@ class TqdmCallback(keras.callbacks.Callback):
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
if verbose == 1:
|
if verbose == 1:
|
||||||
self.batch_bar = tqdm_class(total=batches, unit='batch',
|
self.batch_bar = tqdm_class(total=batches, unit='batch',
|
||||||
leave=False, **tqdm_kwargs)
|
leave=False)
|
||||||
self.on_batch_end = self.bar2callback(
|
self.on_batch_end = self.bar2callback(
|
||||||
self.batch_bar,
|
self.batch_bar,
|
||||||
pop=['batch', 'size'],
|
pop=['batch', 'size'],
|
||||||
|
|
Loading…
Reference in New Issue