From 9ebcece82fe11b87cc3d6e6b4c439aab9e3ab1e6 Mon Sep 17 00:00:00 2001 From: Erlend Egeberg Aasland Date: Tue, 12 Apr 2022 02:55:59 +0200 Subject: [PATCH] gh-79097: Add support for aggregate window functions in sqlite3 (GH-20903) --- Doc/includes/sqlite3/sumintwindow.py | 46 +++++ Doc/library/sqlite3.rst | 29 +++ Doc/whatsnew/3.11.rst | 4 + Lib/test/test_sqlite3/test_dbapi.py | 2 + Lib/test/test_sqlite3/test_userfunctions.py | 168 +++++++++++++++- .../2020-05-24-23-52-03.bpo-40617.lycF9q.rst | 3 + Modules/_sqlite/clinic/connection.c.h | 53 +++++- Modules/_sqlite/connection.c | 179 +++++++++++++++++- Modules/_sqlite/module.c | 4 + Modules/_sqlite/module.h | 2 + 10 files changed, 477 insertions(+), 13 deletions(-) create mode 100644 Doc/includes/sqlite3/sumintwindow.py create mode 100644 Misc/NEWS.d/next/Library/2020-05-24-23-52-03.bpo-40617.lycF9q.rst diff --git a/Doc/includes/sqlite3/sumintwindow.py b/Doc/includes/sqlite3/sumintwindow.py new file mode 100644 index 00000000000..0e915d6cc6a --- /dev/null +++ b/Doc/includes/sqlite3/sumintwindow.py @@ -0,0 +1,46 @@ +# Example taken from https://www.sqlite.org/windowfunctions.html#udfwinfunc +import sqlite3 + + +class WindowSumInt: + def __init__(self): + self.count = 0 + + def step(self, value): + """Adds a row to the current window.""" + self.count += value + + def value(self): + """Returns the current value of the aggregate.""" + return self.count + + def inverse(self, value): + """Removes a row from the current window.""" + self.count -= value + + def finalize(self): + """Returns the final value of the aggregate. + + Any clean-up actions should be placed here. + """ + return self.count + + +con = sqlite3.connect(":memory:") +cur = con.execute("create table test(x, y)") +values = [ + ("a", 4), + ("b", 5), + ("c", 3), + ("d", 8), + ("e", 1), +] +cur.executemany("insert into test values(?, ?)", values) +con.create_window_function("sumint", 1, WindowSumInt) +cur.execute(""" + select x, sumint(y) over ( + order by x rows between 1 preceding and 1 following + ) as sum_y + from test order by x +""") +print(cur.fetchall()) diff --git a/Doc/library/sqlite3.rst b/Doc/library/sqlite3.rst index 852b68437a2..60dfbefd2e2 100644 --- a/Doc/library/sqlite3.rst +++ b/Doc/library/sqlite3.rst @@ -473,6 +473,35 @@ Connection Objects .. literalinclude:: ../includes/sqlite3/mysumaggr.py + .. method:: create_window_function(name, num_params, aggregate_class, /) + + Creates user-defined aggregate window function *name*. + + *aggregate_class* must implement the following methods: + + * ``step``: adds a row to the current window + * ``value``: returns the current value of the aggregate + * ``inverse``: removes a row from the current window + * ``finalize``: returns the final value of the aggregate + + ``step`` and ``value`` accept *num_params* number of parameters, + unless *num_params* is ``-1``, in which case they may take any number of + arguments. ``finalize`` and ``value`` can return any of the types + supported by SQLite: + :class:`bytes`, :class:`str`, :class:`int`, :class:`float`, and + :const:`None`. Call :meth:`create_window_function` with + *aggregate_class* set to :const:`None` to clear window function *name*. + + Aggregate window functions are supported by SQLite 3.25.0 and higher. + :exc:`NotSupportedError` will be raised if used with older versions. + + .. versionadded:: 3.11 + + Example: + + .. literalinclude:: ../includes/sqlite3/sumintwindow.py + + .. method:: create_collation(name, callable) Creates a collation with the specified *name* and *callable*. The callable will diff --git a/Doc/whatsnew/3.11.rst b/Doc/whatsnew/3.11.rst index 354e2112338..d803801f273 100644 --- a/Doc/whatsnew/3.11.rst +++ b/Doc/whatsnew/3.11.rst @@ -389,6 +389,10 @@ sqlite3 serializing and deserializing databases. (Contributed by Erlend E. Aasland in :issue:`41930`.) +* Add :meth:`~sqlite3.Connection.create_window_function` to + :class:`sqlite3.Connection` for creating aggregate window functions. + (Contributed by Erlend E. Aasland in :issue:`34916`.) + sys --- diff --git a/Lib/test/test_sqlite3/test_dbapi.py b/Lib/test/test_sqlite3/test_dbapi.py index 02482816cb9..2d2e58a3d44 100644 --- a/Lib/test/test_sqlite3/test_dbapi.py +++ b/Lib/test/test_sqlite3/test_dbapi.py @@ -1084,6 +1084,8 @@ def test_check_connection_thread(self): if hasattr(sqlite.Connection, "serialize"): fns.append(lambda: self.con.serialize()) fns.append(lambda: self.con.deserialize(b"")) + if sqlite.sqlite_version_info >= (3, 25, 0): + fns.append(lambda: self.con.create_window_function("foo", 0, None)) for fn in fns: with self.subTest(fn=fn): diff --git a/Lib/test/test_sqlite3/test_userfunctions.py b/Lib/test/test_sqlite3/test_userfunctions.py index 9070c9e01b2..0970b0378ad 100644 --- a/Lib/test/test_sqlite3/test_userfunctions.py +++ b/Lib/test/test_sqlite3/test_userfunctions.py @@ -27,9 +27,9 @@ import re import sys import unittest -import unittest.mock import sqlite3 as sqlite +from unittest.mock import Mock, patch from test.support import bigmemtest, catch_unraisable_exception, gc_collect from test.test_sqlite3.test_dbapi import cx_limit @@ -393,7 +393,7 @@ def append_result(arg): # indices, which allows testing based on syntax, iso. the query optimizer. @unittest.skipIf(sqlite.sqlite_version_info < (3, 8, 3), "Requires SQLite 3.8.3 or higher") def test_func_non_deterministic(self): - mock = unittest.mock.Mock(return_value=None) + mock = Mock(return_value=None) self.con.create_function("nondeterministic", 0, mock, deterministic=False) if sqlite.sqlite_version_info < (3, 15, 0): self.con.execute("select nondeterministic() = nondeterministic()") @@ -404,7 +404,7 @@ def test_func_non_deterministic(self): @unittest.skipIf(sqlite.sqlite_version_info < (3, 8, 3), "Requires SQLite 3.8.3 or higher") def test_func_deterministic(self): - mock = unittest.mock.Mock(return_value=None) + mock = Mock(return_value=None) self.con.create_function("deterministic", 0, mock, deterministic=True) if sqlite.sqlite_version_info < (3, 15, 0): self.con.execute("select deterministic() = deterministic()") @@ -482,6 +482,164 @@ def test_func_return_illegal_value(self): self.con.execute, "select badreturn()") +class WindowSumInt: + def __init__(self): + self.count = 0 + + def step(self, value): + self.count += value + + def value(self): + return self.count + + def inverse(self, value): + self.count -= value + + def finalize(self): + return self.count + +class BadWindow(Exception): + pass + + +@unittest.skipIf(sqlite.sqlite_version_info < (3, 25, 0), + "Requires SQLite 3.25.0 or newer") +class WindowFunctionTests(unittest.TestCase): + def setUp(self): + self.con = sqlite.connect(":memory:") + self.cur = self.con.cursor() + + # Test case taken from https://www.sqlite.org/windowfunctions.html#udfwinfunc + values = [ + ("a", 4), + ("b", 5), + ("c", 3), + ("d", 8), + ("e", 1), + ] + with self.con: + self.con.execute("create table test(x, y)") + self.con.executemany("insert into test values(?, ?)", values) + self.expected = [ + ("a", 9), + ("b", 12), + ("c", 16), + ("d", 12), + ("e", 9), + ] + self.query = """ + select x, %s(y) over ( + order by x rows between 1 preceding and 1 following + ) as sum_y + from test order by x + """ + self.con.create_window_function("sumint", 1, WindowSumInt) + + def test_win_sum_int(self): + self.cur.execute(self.query % "sumint") + self.assertEqual(self.cur.fetchall(), self.expected) + + def test_win_error_on_create(self): + self.assertRaises(sqlite.ProgrammingError, + self.con.create_window_function, + "shouldfail", -100, WindowSumInt) + + @with_tracebacks(BadWindow) + def test_win_exception_in_method(self): + for meth in "__init__", "step", "value", "inverse": + with self.subTest(meth=meth): + with patch.object(WindowSumInt, meth, side_effect=BadWindow): + name = f"exc_{meth}" + self.con.create_window_function(name, 1, WindowSumInt) + msg = f"'{meth}' method raised error" + with self.assertRaisesRegex(sqlite.OperationalError, msg): + self.cur.execute(self.query % name) + self.cur.fetchall() + + @with_tracebacks(BadWindow) + def test_win_exception_in_finalize(self): + # Note: SQLite does not (as of version 3.38.0) propagate finalize + # callback errors to sqlite3_step(); this implies that OperationalError + # is _not_ raised. + with patch.object(WindowSumInt, "finalize", side_effect=BadWindow): + name = f"exception_in_finalize" + self.con.create_window_function(name, 1, WindowSumInt) + self.cur.execute(self.query % name) + self.cur.fetchall() + + @with_tracebacks(AttributeError) + def test_win_missing_method(self): + class MissingValue: + def step(self, x): pass + def inverse(self, x): pass + def finalize(self): return 42 + + class MissingInverse: + def step(self, x): pass + def value(self): return 42 + def finalize(self): return 42 + + class MissingStep: + def value(self): return 42 + def inverse(self, x): pass + def finalize(self): return 42 + + dataset = ( + ("step", MissingStep), + ("value", MissingValue), + ("inverse", MissingInverse), + ) + for meth, cls in dataset: + with self.subTest(meth=meth, cls=cls): + name = f"exc_{meth}" + self.con.create_window_function(name, 1, cls) + with self.assertRaisesRegex(sqlite.OperationalError, + f"'{meth}' method not defined"): + self.cur.execute(self.query % name) + self.cur.fetchall() + + @with_tracebacks(AttributeError) + def test_win_missing_finalize(self): + # Note: SQLite does not (as of version 3.38.0) propagate finalize + # callback errors to sqlite3_step(); this implies that OperationalError + # is _not_ raised. + class MissingFinalize: + def step(self, x): pass + def value(self): return 42 + def inverse(self, x): pass + + name = "missing_finalize" + self.con.create_window_function(name, 1, MissingFinalize) + self.cur.execute(self.query % name) + self.cur.fetchall() + + def test_win_clear_function(self): + self.con.create_window_function("sumint", 1, None) + self.assertRaises(sqlite.OperationalError, self.cur.execute, + self.query % "sumint") + + def test_win_redefine_function(self): + # Redefine WindowSumInt; adjust the expected results accordingly. + class Redefined(WindowSumInt): + def step(self, value): self.count += value * 2 + def inverse(self, value): self.count -= value * 2 + expected = [(v[0], v[1]*2) for v in self.expected] + + self.con.create_window_function("sumint", 1, Redefined) + self.cur.execute(self.query % "sumint") + self.assertEqual(self.cur.fetchall(), expected) + + def test_win_error_value_return(self): + class ErrorValueReturn: + def __init__(self): pass + def step(self, x): pass + def value(self): return 1 << 65 + + self.con.create_window_function("err_val_ret", 1, ErrorValueReturn) + self.assertRaisesRegex(sqlite.DataError, "string or blob too big", + self.cur.execute, self.query % "err_val_ret") + + class AggregateTests(unittest.TestCase): def setUp(self): self.con = sqlite.connect(":memory:") @@ -527,10 +685,10 @@ def test_aggr_no_step(self): def test_aggr_no_finalize(self): cur = self.con.cursor() - with self.assertRaises(sqlite.OperationalError) as cm: + msg = "user-defined aggregate's 'finalize' method not defined" + with self.assertRaisesRegex(sqlite.OperationalError, msg): cur.execute("select nofinalize(t) from test") val = cur.fetchone()[0] - self.assertEqual(str(cm.exception), "user-defined aggregate's 'finalize' method raised error") @with_tracebacks(ZeroDivisionError, name="AggrExceptionInInit") def test_aggr_exception_in_init(self): diff --git a/Misc/NEWS.d/next/Library/2020-05-24-23-52-03.bpo-40617.lycF9q.rst b/Misc/NEWS.d/next/Library/2020-05-24-23-52-03.bpo-40617.lycF9q.rst new file mode 100644 index 00000000000..123b49ddb5a --- /dev/null +++ b/Misc/NEWS.d/next/Library/2020-05-24-23-52-03.bpo-40617.lycF9q.rst @@ -0,0 +1,3 @@ +Add :meth:`~sqlite3.Connection.create_window_function` to +:class:`sqlite3.Connection` for creating aggregate window functions. +Patch by Erlend E. Aasland. diff --git a/Modules/_sqlite/clinic/connection.c.h b/Modules/_sqlite/clinic/connection.c.h index 99ef94ecd71..2b933f85224 100644 --- a/Modules/_sqlite/clinic/connection.c.h +++ b/Modules/_sqlite/clinic/connection.c.h @@ -235,6 +235,53 @@ exit: return return_value; } +#if defined(HAVE_WINDOW_FUNCTIONS) + +PyDoc_STRVAR(create_window_function__doc__, +"create_window_function($self, name, num_params, aggregate_class, /)\n" +"--\n" +"\n" +"Creates or redefines an aggregate window function. Non-standard.\n" +"\n" +" name\n" +" The name of the SQL aggregate window function to be created or\n" +" redefined.\n" +" num_params\n" +" The number of arguments the step and inverse methods takes.\n" +" aggregate_class\n" +" A class with step(), finalize(), value(), and inverse() methods.\n" +" Set to None to clear the window function."); + +#define CREATE_WINDOW_FUNCTION_METHODDEF \ + {"create_window_function", (PyCFunction)(void(*)(void))create_window_function, METH_METHOD|METH_FASTCALL|METH_KEYWORDS, create_window_function__doc__}, + +static PyObject * +create_window_function_impl(pysqlite_Connection *self, PyTypeObject *cls, + const char *name, int num_params, + PyObject *aggregate_class); + +static PyObject * +create_window_function(pysqlite_Connection *self, PyTypeObject *cls, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames) +{ + PyObject *return_value = NULL; + static const char * const _keywords[] = {"", "", "", NULL}; + static _PyArg_Parser _parser = {"siO:create_window_function", _keywords, 0}; + const char *name; + int num_params; + PyObject *aggregate_class; + + if (!_PyArg_ParseStackAndKeywords(args, nargs, kwnames, &_parser, + &name, &num_params, &aggregate_class)) { + goto exit; + } + return_value = create_window_function_impl(self, cls, name, num_params, aggregate_class); + +exit: + return return_value; +} + +#endif /* defined(HAVE_WINDOW_FUNCTIONS) */ + PyDoc_STRVAR(pysqlite_connection_create_aggregate__doc__, "create_aggregate($self, /, name, n_arg, aggregate_class)\n" "--\n" @@ -975,6 +1022,10 @@ exit: return return_value; } +#ifndef CREATE_WINDOW_FUNCTION_METHODDEF + #define CREATE_WINDOW_FUNCTION_METHODDEF +#endif /* !defined(CREATE_WINDOW_FUNCTION_METHODDEF) */ + #ifndef PYSQLITE_CONNECTION_ENABLE_LOAD_EXTENSION_METHODDEF #define PYSQLITE_CONNECTION_ENABLE_LOAD_EXTENSION_METHODDEF #endif /* !defined(PYSQLITE_CONNECTION_ENABLE_LOAD_EXTENSION_METHODDEF) */ @@ -990,4 +1041,4 @@ exit: #ifndef DESERIALIZE_METHODDEF #define DESERIALIZE_METHODDEF #endif /* !defined(DESERIALIZE_METHODDEF) */ -/*[clinic end generated code: output=d965a68f9229a56c input=a9049054013a1b77]*/ +/*[clinic end generated code: output=b9af1b52fda808bf input=a9049054013a1b77]*/ diff --git a/Modules/_sqlite/connection.c b/Modules/_sqlite/connection.c index 9d187cfa99d..d7c0a9e4616 100644 --- a/Modules/_sqlite/connection.c +++ b/Modules/_sqlite/connection.c @@ -33,6 +33,10 @@ #define HAVE_TRACE_V2 #endif +#if SQLITE_VERSION_NUMBER >= 3025000 +#define HAVE_WINDOW_FUNCTIONS +#endif + static const char * get_isolation_level(const char *level) { @@ -799,7 +803,7 @@ final_callback(sqlite3_context *context) goto error; } - /* Keep the exception (if any) of the last call to step() */ + // Keep the exception (if any) of the last call to step, value, or inverse PyErr_Fetch(&exception, &value, &tb); callback_context *ctx = (callback_context *)sqlite3_user_data(context); @@ -814,13 +818,20 @@ final_callback(sqlite3_context *context) Py_DECREF(function_result); } if (!ok) { - set_sqlite_error(context, - "user-defined aggregate's 'finalize' method raised error"); - } + int attr_err = PyErr_ExceptionMatches(PyExc_AttributeError); + _PyErr_ChainExceptions(exception, value, tb); - /* Restore the exception (if any) of the last call to step(), - but clear also the current exception if finalize() failed */ - PyErr_Restore(exception, value, tb); + /* Note: contrary to the step, value, and inverse callbacks, SQLite + * does _not_, as of SQLite 3.38.0, propagate errors to sqlite3_step() + * from the finalize callback. This implies that execute*() will not + * raise OperationalError, as it normally would. */ + set_sqlite_error(context, attr_err + ? "user-defined aggregate's 'finalize' method not defined" + : "user-defined aggregate's 'finalize' method raised error"); + } + else { + PyErr_Restore(exception, value, tb); + } error: PyGILState_Release(threadstate); @@ -968,6 +979,159 @@ pysqlite_connection_create_function_impl(pysqlite_Connection *self, Py_RETURN_NONE; } +#ifdef HAVE_WINDOW_FUNCTIONS +/* + * Regarding the 'inverse' aggregate callback: + * This method is only required by window aggregate functions, not + * ordinary aggregate function implementations. It is invoked to remove + * a row from the current window. The function arguments, if any, + * correspond to the row being removed. + */ +static void +inverse_callback(sqlite3_context *context, int argc, sqlite3_value **params) +{ + PyGILState_STATE gilstate = PyGILState_Ensure(); + + callback_context *ctx = (callback_context *)sqlite3_user_data(context); + assert(ctx != NULL); + + int size = sizeof(PyObject *); + PyObject **cls = (PyObject **)sqlite3_aggregate_context(context, size); + assert(cls != NULL); + assert(*cls != NULL); + + PyObject *method = PyObject_GetAttr(*cls, ctx->state->str_inverse); + if (method == NULL) { + set_sqlite_error(context, + "user-defined aggregate's 'inverse' method not defined"); + goto exit; + } + + PyObject *args = _pysqlite_build_py_params(context, argc, params); + if (args == NULL) { + set_sqlite_error(context, + "unable to build arguments for user-defined aggregate's " + "'inverse' method"); + goto exit; + } + + PyObject *res = PyObject_CallObject(method, args); + Py_DECREF(args); + if (res == NULL) { + set_sqlite_error(context, + "user-defined aggregate's 'inverse' method raised error"); + goto exit; + } + Py_DECREF(res); + +exit: + Py_XDECREF(method); + PyGILState_Release(gilstate); +} + +/* + * Regarding the 'value' aggregate callback: + * This method is only required by window aggregate functions, not + * ordinary aggregate function implementations. It is invoked to return + * the current value of the aggregate. + */ +static void +value_callback(sqlite3_context *context) +{ + PyGILState_STATE gilstate = PyGILState_Ensure(); + + callback_context *ctx = (callback_context *)sqlite3_user_data(context); + assert(ctx != NULL); + + int size = sizeof(PyObject *); + PyObject **cls = (PyObject **)sqlite3_aggregate_context(context, size); + assert(cls != NULL); + assert(*cls != NULL); + + PyObject *res = PyObject_CallMethodNoArgs(*cls, ctx->state->str_value); + if (res == NULL) { + int attr_err = PyErr_ExceptionMatches(PyExc_AttributeError); + set_sqlite_error(context, attr_err + ? "user-defined aggregate's 'value' method not defined" + : "user-defined aggregate's 'value' method raised error"); + } + else { + int rc = _pysqlite_set_result(context, res); + Py_DECREF(res); + if (rc < 0) { + set_sqlite_error(context, + "unable to set result from user-defined aggregate's " + "'value' method"); + } + } + + PyGILState_Release(gilstate); +} + +/*[clinic input] +_sqlite3.Connection.create_window_function as create_window_function + + cls: defining_class + name: str + The name of the SQL aggregate window function to be created or + redefined. + num_params: int + The number of arguments the step and inverse methods takes. + aggregate_class: object + A class with step(), finalize(), value(), and inverse() methods. + Set to None to clear the window function. + / + +Creates or redefines an aggregate window function. Non-standard. +[clinic start generated code]*/ + +static PyObject * +create_window_function_impl(pysqlite_Connection *self, PyTypeObject *cls, + const char *name, int num_params, + PyObject *aggregate_class) +/*[clinic end generated code: output=5332cd9464522235 input=46d57a54225b5228]*/ +{ + if (sqlite3_libversion_number() < 3025000) { + PyErr_SetString(self->NotSupportedError, + "create_window_function() requires " + "SQLite 3.25.0 or higher"); + return NULL; + } + + if (!pysqlite_check_thread(self) || !pysqlite_check_connection(self)) { + return NULL; + } + + int flags = SQLITE_UTF8; + int rc; + if (Py_IsNone(aggregate_class)) { + rc = sqlite3_create_window_function(self->db, name, num_params, flags, + 0, 0, 0, 0, 0, 0); + } + else { + callback_context *ctx = create_callback_context(cls, aggregate_class); + if (ctx == NULL) { + return NULL; + } + rc = sqlite3_create_window_function(self->db, name, num_params, flags, + ctx, + &step_callback, + &final_callback, + &value_callback, + &inverse_callback, + &destructor_callback); + } + + if (rc != SQLITE_OK) { + // Errors are not set on the database connection, so we cannot + // use _pysqlite_seterror(). + PyErr_SetString(self->ProgrammingError, sqlite3_errstr(rc)); + return NULL; + } + Py_RETURN_NONE; +} +#endif + /*[clinic input] _sqlite3.Connection.create_aggregate as pysqlite_connection_create_aggregate @@ -2092,6 +2256,7 @@ static PyMethodDef connection_methods[] = { GETLIMIT_METHODDEF SERIALIZE_METHODDEF DESERIALIZE_METHODDEF + CREATE_WINDOW_FUNCTION_METHODDEF {NULL, NULL} }; diff --git a/Modules/_sqlite/module.c b/Modules/_sqlite/module.c index 07f090c4a26..ffda836d7a3 100644 --- a/Modules/_sqlite/module.c +++ b/Modules/_sqlite/module.c @@ -630,8 +630,10 @@ module_clear(PyObject *module) Py_CLEAR(state->str___conform__); Py_CLEAR(state->str_executescript); Py_CLEAR(state->str_finalize); + Py_CLEAR(state->str_inverse); Py_CLEAR(state->str_step); Py_CLEAR(state->str_upper); + Py_CLEAR(state->str_value); return 0; } @@ -717,8 +719,10 @@ module_exec(PyObject *module) ADD_INTERNED(state, __conform__); ADD_INTERNED(state, executescript); ADD_INTERNED(state, finalize); + ADD_INTERNED(state, inverse); ADD_INTERNED(state, step); ADD_INTERNED(state, upper); + ADD_INTERNED(state, value); /* Set error constants */ if (add_error_constants(module) < 0) { diff --git a/Modules/_sqlite/module.h b/Modules/_sqlite/module.h index cca52d1e04b..fcea7096924 100644 --- a/Modules/_sqlite/module.h +++ b/Modules/_sqlite/module.h @@ -64,8 +64,10 @@ typedef struct { PyObject *str___conform__; PyObject *str_executescript; PyObject *str_finalize; + PyObject *str_inverse; PyObject *str_step; PyObject *str_upper; + PyObject *str_value; } pysqlite_state; extern pysqlite_state pysqlite_global_state;