Add `DataFrame.progress_apply` next to `DataFrameGroupBy.progress_apply`.

Review revisions relative to the first version of this patch:
- the new test is named `test_pandas_apply` so that it no longer redefines
  (and thereby silently disables) the existing `test_pandas`;
- the total for a plain DataFrame is read from `shape` and honours the
  `axis` keyword, instead of `size // len(...)`, which raised
  ZeroDivisionError on an empty frame and miscounted for `axis=1`;
- the optional-import guard in the test catches ImportError only.
The blob ids on the `index` lines predate these revisions.

diff --git a/tqdm/_tqdm_pandas.py b/tqdm/_tqdm_pandas.py
index 59632d9a..5e9cd429 100644
--- a/tqdm/_tqdm_pandas.py
+++ b/tqdm/_tqdm_pandas.py
@@ -1,7 +1,7 @@
 # future division is important to divide integers and get as
 # a result precise floating numbers (instead of truncated int)
 from __future__ import absolute_import
-
+from __future__ import division
 
 __author__ = "github.com/casperdcl"
 __all__ = ['tqdm_pandas']
@@ -29,20 +29,28 @@ def tqdm_pandas(t):  # pragma: no cover
     https://stackoverflow.com/questions/18603270/
     progress-indicator-during-pandas-operations-python
     """
+    from pandas.core.frame import DataFrame
     from pandas.core.groupby import DataFrameGroupBy
 
     def inner(groups, func, *args, **kwargs):
         """
         Parameters
         ----------
-        groups : DataFrameGroupBy
-            Grouped data.
+        groups : DataFrame[GroupBy]
+            (Grouped) data.
         func : function
-            To be applied on the grouped data.
+            To be applied on the (grouped) data.
 
         *args and *kwargs are transmitted to DataFrameGroupBy.apply()
         """
-        t.total = len(groups) + 1  # pandas calls update once too many
+        t.total = getattr(groups, 'ngroups', None)
+        if t.total is None:  # not grouped
+            # `apply` calls `func` once per column (axis=0, the default)
+            # or once per row (axis=1); `shape` is also safe when empty
+            axis = kwargs.get('axis', 0)
+            t.total = groups.shape[0 if axis in (1, 'columns') else 1]
+        else:
+            t.total += 1  # pandas calls update once too many
 
         def wrapper(*args, **kwargs):
             t.update()
@@ -55,4 +63,5 @@ def tqdm_pandas(t):  # pragma: no cover
         return result
 
     # Enable custom tqdm progress in pandas!
+    DataFrame.progress_apply = inner
     DataFrameGroupBy.progress_apply = inner
diff --git a/tqdm/tests/tests_pandas.py b/tqdm/tests/tests_pandas.py
index 7737ad4d..9f22d1a5 100644
--- a/tqdm/tests/tests_pandas.py
+++ b/tqdm/tests/tests_pandas.py
@@ -30,6 +30,30 @@ def test_pandas():
                 nexres, our_file.read()))
 
 
+@with_setup(pretest, posttest)
+def test_pandas_apply():
+    """ Test pandas.DataFrame.progress_apply """
+    try:
+        from numpy.random import randint
+        from tqdm import tqdm_pandas
+        import pandas as pd
+    except ImportError:
+        raise SkipTest
+
+    with closing(StringIO()) as our_file:
+        df = pd.DataFrame(randint(0, 100, (1000, 6)))
+        tqdm_pandas(tqdm(file=our_file, leave=True, ascii=True))
+        df.progress_apply(lambda x: None)
+
+        our_file.seek(0)
+
+        # one update per column: 6 columns -> the bar must show a total of 6
+        if '/6' not in our_file.read():
+            our_file.seek(0)
+            raise AssertionError("\nExpected:\n{0}\nIn:{1}\n".format(
+                '/6', our_file.read()))
+
+
 @with_setup(pretest, posttest)
 def test_pandas_leave():
     """ Test pandas with `leave=True` """