diff --git a/Doc/library/itertools.rst b/Doc/library/itertools.rst index 56eb452b0b0..befc6beb055 100644 --- a/Doc/library/itertools.rst +++ b/Doc/library/itertools.rst @@ -90,13 +90,15 @@ loops that truncate the stream. parameter (which defaults to :const:`0`). Elements may be any addable type including :class:`Decimal` or :class:`Fraction`. Equivalent to:: - def accumulate(iterable, start=0): + def accumulate(iterable): 'Return running totals' - # accumulate([1,2,3,4,5]) --> 1 3 6 10 15 - total = start - for element in iterable: - total += element - yield total + # accumulate([1,2,3,4,5]) --> 1 3 6 10 15 + it = iter(iterable) + total = next(it) + yield total + for element in it: + total += element + yield total .. versionadded:: 3.2 diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py index 8a67cff60ce..b8f6eecbbeb 100644 --- a/Lib/test/test_itertools.py +++ b/Lib/test/test_itertools.py @@ -59,18 +59,18 @@ class TestBasicOps(unittest.TestCase): def test_accumulate(self): self.assertEqual(list(accumulate(range(10))), # one positional arg - [0, 1, 3, 6, 10, 15, 21, 28, 36, 45]) - self.assertEqual(list(accumulate(range(10), 100)), # two positional args - [100, 101, 103, 106, 110, 115, 121, 128, 136, 145]) - self.assertEqual(list(accumulate(iterable=range(10), start=100)), # kw args - [100, 101, 103, 106, 110, 115, 121, 128, 136, 145]) + [0, 1, 3, 6, 10, 15, 21, 28, 36, 45]) + self.assertEqual(list(accumulate(iterable=range(10))), # kw arg + [0, 1, 3, 6, 10, 15, 21, 28, 36, 45]) for typ in int, complex, Decimal, Fraction: # multiple types - self.assertEqual(list(accumulate(range(10), typ(0))), + self.assertEqual( + list(accumulate(map(typ, range(10)))), list(map(typ, [0, 1, 3, 6, 10, 15, 21, 28, 36, 45]))) self.assertEqual(list(accumulate([])), []) # empty iterable - self.assertRaises(TypeError, accumulate, range(10), 0, 5) # too many args + self.assertEqual(list(accumulate([7])), [7]) # iterable of length one + self.assertRaises(TypeError, accumulate, range(10), 5) # too many args self.assertRaises(TypeError, accumulate) # too few args - self.assertRaises(TypeError, accumulate, range(10), x=7) # unexpected kwd args + self.assertRaises(TypeError, accumulate, x=range(10)) # unexpected kwd arg self.assertRaises(TypeError, list, accumulate([1, []])) # args that don't add def test_chain(self): diff --git a/Modules/itertoolsmodule.c b/Modules/itertoolsmodule.c index 04bfffc5b0d..b202e5262ba 100644 --- a/Modules/itertoolsmodule.c +++ b/Modules/itertoolsmodule.c @@ -2597,41 +2597,27 @@ static PyTypeObject accumulate_type; static PyObject * accumulate_new(PyTypeObject *type, PyObject *args, PyObject *kwds) { - static char *kwargs[] = {"iterable", "start", NULL}; + static char *kwargs[] = {"iterable", NULL}; PyObject *iterable; PyObject *it; - PyObject *start = NULL; accumulateobject *lz; - if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O:accumulate", - kwargs, &iterable, &start)) - return NULL; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O:accumulate", kwargs, &iterable)) + return NULL; /* Get iterator. */ it = PyObject_GetIter(iterable); if (it == NULL) return NULL; - /* Default start value */ - if (start == NULL) { - start = PyLong_FromLong(0); - if (start == NULL) { - Py_DECREF(it); - return NULL; - } - } else { - Py_INCREF(start); - } - /* create accumulateobject structure */ lz = (accumulateobject *)type->tp_alloc(type, 0); if (lz == NULL) { Py_DECREF(it); - Py_DECREF(start); - return NULL; + return NULL; } - lz->total = start; + lz->total = NULL; lz->it = it; return (PyObject *)lz; } @@ -2661,11 +2647,17 @@ accumulate_next(accumulateobject *lz) val = PyIter_Next(lz->it); if (val == NULL) return NULL; - + + if (lz->total == NULL) { + Py_INCREF(val); + lz->total = val; + return lz->total; + } + newtotal = PyNumber_Add(lz->total, val); - Py_DECREF(val); + Py_DECREF(val); if (newtotal == NULL) - return NULL; + return NULL; oldtotal = lz->total; lz->total = newtotal; @@ -2676,7 +2668,7 @@ accumulate_next(accumulateobject *lz) } PyDoc_STRVAR(accumulate_doc, -"accumulate(iterable, start=0) --> accumulate object\n\ +"accumulate(iterable) --> accumulate object\n\ \n\ Return series of accumulated sums.");