From 00474472944944b346d8409cfded84bb299f601a Mon Sep 17 00:00:00 2001 From: Pablo Galindo Salgado Date: Sun, 24 Jul 2022 15:58:52 +0100 Subject: [PATCH] gh-95185: Check recursion depth in the AST constructor (#95186) Co-authored-by: Serhiy Storchaka --- Include/internal/pycore_ast_state.h | 2 + Lib/test/test_ast.py | 21 ++++ ...2-07-24-00-27-47.gh-issue-95185.ghYTZx.rst | 3 + Parser/asdl_c.py | 37 +++++- Python/Python-ast.c | 107 +++++++++++++++++- 5 files changed, 168 insertions(+), 2 deletions(-) create mode 100644 Misc/NEWS.d/next/Core and Builtins/2022-07-24-00-27-47.gh-issue-95185.ghYTZx.rst diff --git a/Include/internal/pycore_ast_state.h b/Include/internal/pycore_ast_state.h index da78bba3b69..f15b4905eed 100644 --- a/Include/internal/pycore_ast_state.h +++ b/Include/internal/pycore_ast_state.h @@ -12,6 +12,8 @@ extern "C" { struct ast_state { int initialized; + int recursion_depth; + int recursion_limit; PyObject *AST_type; PyObject *Add_singleton; PyObject *Add_type; diff --git a/Lib/test/test_ast.py b/Lib/test/test_ast.py index 480089aa8af..9734218c21b 100644 --- a/Lib/test/test_ast.py +++ b/Lib/test/test_ast.py @@ -793,6 +793,27 @@ def next(self): return self enum._test_simple_enum(_Precedence, ast._Precedence) + @support.cpython_only + def test_ast_recursion_limit(self): + fail_depth = sys.getrecursionlimit() * 3 + crash_depth = sys.getrecursionlimit() * 300 + success_depth = int(fail_depth * 0.75) + + def check_limit(prefix, repeated): + expect_ok = prefix + repeated * success_depth + ast.parse(expect_ok) + for depth in (fail_depth, crash_depth): + broken = prefix + repeated * depth + details = "Compiling ({!r} + {!r} * {})".format( + prefix, repeated, depth) + with self.assertRaises(RecursionError, msg=details): + ast.parse(broken) + + check_limit("a", "()") + check_limit("a", ".b") + check_limit("a", "[0]") + check_limit("a", "*a") + class ASTHelpers_Test(unittest.TestCase): maxDiff = None diff --git a/Misc/NEWS.d/next/Core and Builtins/2022-07-24-00-27-47.gh-issue-95185.ghYTZx.rst b/Misc/NEWS.d/next/Core and Builtins/2022-07-24-00-27-47.gh-issue-95185.ghYTZx.rst new file mode 100644 index 00000000000..de156bab2f5 --- /dev/null +++ b/Misc/NEWS.d/next/Core and Builtins/2022-07-24-00-27-47.gh-issue-95185.ghYTZx.rst @@ -0,0 +1,3 @@ +Prevented crashes in the AST constructor when compiling some absurdly long +expressions like ``"+0"*1000000``. :exc:`RecursionError` is now raised +instead. Patch by Pablo Galindo diff --git a/Parser/asdl_c.py b/Parser/asdl_c.py index bf391a3ae16..13dd44ca0cd 100755 --- a/Parser/asdl_c.py +++ b/Parser/asdl_c.py @@ -1112,6 +1112,8 @@ def visitModule(self, mod): for dfn in mod.dfns: self.visit(dfn) self.file.write(textwrap.dedent(''' + state->recursion_depth = 0; + state->recursion_limit = 0; state->initialized = 1; return 1; } @@ -1259,8 +1261,14 @@ def func_begin(self, name): self.emit('if (!o) {', 1) self.emit("Py_RETURN_NONE;", 2) self.emit("}", 1) + self.emit("if (++state->recursion_depth > state->recursion_limit) {", 1) + self.emit("PyErr_SetString(PyExc_RecursionError,", 2) + self.emit('"maximum recursion depth exceeded during ast construction");', 3) + self.emit("return 0;", 2) + self.emit("}", 1) def func_end(self): + self.emit("state->recursion_depth--;", 1) self.emit("return result;", 1) self.emit("failed:", 0) self.emit("Py_XDECREF(value);", 1) @@ -1371,7 +1379,32 @@ class PartingShots(StaticVisitor): if (state == NULL) { return NULL; } - return ast2obj_mod(state, t); + + int recursion_limit = Py_GetRecursionLimit(); + int starting_recursion_depth; + /* Be careful here to prevent overflow. */ + int COMPILER_STACK_FRAME_SCALE = 3; + PyThreadState *tstate = _PyThreadState_GET(); + if (!tstate) { + return 0; + } + state->recursion_limit = (recursion_limit < INT_MAX / COMPILER_STACK_FRAME_SCALE) ? + recursion_limit * COMPILER_STACK_FRAME_SCALE : recursion_limit; + int recursion_depth = tstate->recursion_limit - tstate->recursion_remaining; + starting_recursion_depth = (recursion_depth < INT_MAX / COMPILER_STACK_FRAME_SCALE) ? + recursion_depth * COMPILER_STACK_FRAME_SCALE : recursion_depth; + state->recursion_depth = starting_recursion_depth; + + PyObject *result = ast2obj_mod(state, t); + + /* Check that the recursion depth counting balanced correctly */ + if (result && state->recursion_depth != starting_recursion_depth) { + PyErr_Format(PyExc_SystemError, + "AST constructor recursion depth mismatch (before=%d, after=%d)", + starting_recursion_depth, state->recursion_depth); + return 0; + } + return result; } /* mode is 0 for "exec", 1 for "eval" and 2 for "single" input */ @@ -1437,6 +1470,8 @@ def visit(self, object): def generate_ast_state(module_state, f): f.write('struct ast_state {\n') f.write(' int initialized;\n') + f.write(' int recursion_depth;\n') + f.write(' int recursion_limit;\n') for s in module_state: f.write(' PyObject *' + s + ';\n') f.write('};') diff --git a/Python/Python-ast.c b/Python/Python-ast.c index e52a72d43bc..f485af675cc 100644 --- a/Python/Python-ast.c +++ b/Python/Python-ast.c @@ -1851,6 +1851,8 @@ init_types(struct ast_state *state) "TypeIgnore(int lineno, string tag)"); if (!state->TypeIgnore_type) return 0; + state->recursion_depth = 0; + state->recursion_limit = 0; state->initialized = 1; return 1; } @@ -3610,6 +3612,11 @@ ast2obj_mod(struct ast_state *state, void* _o) if (!o) { Py_RETURN_NONE; } + if (++state->recursion_depth > state->recursion_limit) { + PyErr_SetString(PyExc_RecursionError, + "maximum recursion depth exceeded during ast construction"); + return 0; + } switch (o->kind) { case Module_kind: tp = (PyTypeObject *)state->Module_type; @@ -3665,6 +3672,7 @@ ast2obj_mod(struct ast_state *state, void* _o) Py_DECREF(value); break; } + state->recursion_depth--; return result; failed: Py_XDECREF(value); @@ -3681,6 +3689,11 @@ ast2obj_stmt(struct ast_state *state, void* _o) if (!o) { Py_RETURN_NONE; } + if (++state->recursion_depth > state->recursion_limit) { + PyErr_SetString(PyExc_RecursionError, + "maximum recursion depth exceeded during ast construction"); + return 0; + } switch (o->kind) { case FunctionDef_kind: tp = (PyTypeObject *)state->FunctionDef_type; @@ -4224,6 +4237,7 @@ ast2obj_stmt(struct ast_state *state, void* _o) if (PyObject_SetAttr(result, state->end_col_offset, value) < 0) goto failed; Py_DECREF(value); + state->recursion_depth--; return result; failed: Py_XDECREF(value); @@ -4240,6 +4254,11 @@ ast2obj_expr(struct ast_state *state, void* _o) if (!o) { Py_RETURN_NONE; } + if (++state->recursion_depth > state->recursion_limit) { + PyErr_SetString(PyExc_RecursionError, + "maximum recursion depth exceeded during ast construction"); + return 0; + } switch (o->kind) { case BoolOp_kind: tp = (PyTypeObject *)state->BoolOp_type; @@ -4701,6 +4720,7 @@ ast2obj_expr(struct ast_state *state, void* _o) if (PyObject_SetAttr(result, state->end_col_offset, value) < 0) goto failed; Py_DECREF(value); + state->recursion_depth--; return result; failed: Py_XDECREF(value); @@ -4843,6 +4863,11 @@ ast2obj_comprehension(struct ast_state *state, void* _o) if (!o) { Py_RETURN_NONE; } + if (++state->recursion_depth > state->recursion_limit) { + PyErr_SetString(PyExc_RecursionError, + "maximum recursion depth exceeded during ast construction"); + return 0; + } tp = (PyTypeObject *)state->comprehension_type; result = PyType_GenericNew(tp, NULL, NULL); if (!result) return NULL; @@ -4866,6 +4891,7 @@ ast2obj_comprehension(struct ast_state *state, void* _o) if (PyObject_SetAttr(result, state->is_async, value) == -1) goto failed; Py_DECREF(value); + state->recursion_depth--; return result; failed: Py_XDECREF(value); @@ -4882,6 +4908,11 @@ ast2obj_excepthandler(struct ast_state *state, void* _o) if (!o) { Py_RETURN_NONE; } + if (++state->recursion_depth > state->recursion_limit) { + PyErr_SetString(PyExc_RecursionError, + "maximum recursion depth exceeded during ast construction"); + return 0; + } switch (o->kind) { case ExceptHandler_kind: tp = (PyTypeObject *)state->ExceptHandler_type; @@ -4925,6 +4956,7 @@ ast2obj_excepthandler(struct ast_state *state, void* _o) if (PyObject_SetAttr(result, state->end_col_offset, value) < 0) goto failed; Py_DECREF(value); + state->recursion_depth--; return result; failed: Py_XDECREF(value); @@ -4941,6 +4973,11 @@ ast2obj_arguments(struct ast_state *state, void* _o) if (!o) { Py_RETURN_NONE; } + if (++state->recursion_depth > state->recursion_limit) { + PyErr_SetString(PyExc_RecursionError, + "maximum recursion depth exceeded during ast construction"); + return 0; + } tp = (PyTypeObject *)state->arguments_type; result = PyType_GenericNew(tp, NULL, NULL); if (!result) return NULL; @@ -4979,6 +5016,7 @@ ast2obj_arguments(struct ast_state *state, void* _o) if (PyObject_SetAttr(result, state->defaults, value) == -1) goto failed; Py_DECREF(value); + state->recursion_depth--; return result; failed: Py_XDECREF(value); @@ -4995,6 +5033,11 @@ ast2obj_arg(struct ast_state *state, void* _o) if (!o) { Py_RETURN_NONE; } + if (++state->recursion_depth > state->recursion_limit) { + PyErr_SetString(PyExc_RecursionError, + "maximum recursion depth exceeded during ast construction"); + return 0; + } tp = (PyTypeObject *)state->arg_type; result = PyType_GenericNew(tp, NULL, NULL); if (!result) return NULL; @@ -5033,6 +5076,7 @@ ast2obj_arg(struct ast_state *state, void* _o) if (PyObject_SetAttr(result, state->end_col_offset, value) < 0) goto failed; Py_DECREF(value); + state->recursion_depth--; return result; failed: Py_XDECREF(value); @@ -5049,6 +5093,11 @@ ast2obj_keyword(struct ast_state *state, void* _o) if (!o) { Py_RETURN_NONE; } + if (++state->recursion_depth > state->recursion_limit) { + PyErr_SetString(PyExc_RecursionError, + "maximum recursion depth exceeded during ast construction"); + return 0; + } tp = (PyTypeObject *)state->keyword_type; result = PyType_GenericNew(tp, NULL, NULL); if (!result) return NULL; @@ -5082,6 +5131,7 @@ ast2obj_keyword(struct ast_state *state, void* _o) if (PyObject_SetAttr(result, state->end_col_offset, value) < 0) goto failed; Py_DECREF(value); + state->recursion_depth--; return result; failed: Py_XDECREF(value); @@ -5098,6 +5148,11 @@ ast2obj_alias(struct ast_state *state, void* _o) if (!o) { Py_RETURN_NONE; } + if (++state->recursion_depth > state->recursion_limit) { + PyErr_SetString(PyExc_RecursionError, + "maximum recursion depth exceeded during ast construction"); + return 0; + } tp = (PyTypeObject *)state->alias_type; result = PyType_GenericNew(tp, NULL, NULL); if (!result) return NULL; @@ -5131,6 +5186,7 @@ ast2obj_alias(struct ast_state *state, void* _o) if (PyObject_SetAttr(result, state->end_col_offset, value) < 0) goto failed; Py_DECREF(value); + state->recursion_depth--; return result; failed: Py_XDECREF(value); @@ -5147,6 +5203,11 @@ ast2obj_withitem(struct ast_state *state, void* _o) if (!o) { Py_RETURN_NONE; } + if (++state->recursion_depth > state->recursion_limit) { + PyErr_SetString(PyExc_RecursionError, + "maximum recursion depth exceeded during ast construction"); + return 0; + } tp = (PyTypeObject *)state->withitem_type; result = PyType_GenericNew(tp, NULL, NULL); if (!result) return NULL; @@ -5160,6 +5221,7 @@ ast2obj_withitem(struct ast_state *state, void* _o) if (PyObject_SetAttr(result, state->optional_vars, value) == -1) goto failed; Py_DECREF(value); + state->recursion_depth--; return result; failed: Py_XDECREF(value); @@ -5176,6 +5238,11 @@ ast2obj_match_case(struct ast_state *state, void* _o) if (!o) { Py_RETURN_NONE; } + if (++state->recursion_depth > state->recursion_limit) { + PyErr_SetString(PyExc_RecursionError, + "maximum recursion depth exceeded during ast construction"); + return 0; + } tp = (PyTypeObject *)state->match_case_type; result = PyType_GenericNew(tp, NULL, NULL); if (!result) return NULL; @@ -5194,6 +5261,7 @@ ast2obj_match_case(struct ast_state *state, void* _o) if (PyObject_SetAttr(result, state->body, value) == -1) goto failed; Py_DECREF(value); + state->recursion_depth--; return result; failed: Py_XDECREF(value); @@ -5210,6 +5278,11 @@ ast2obj_pattern(struct ast_state *state, void* _o) if (!o) { Py_RETURN_NONE; } + if (++state->recursion_depth > state->recursion_limit) { + PyErr_SetString(PyExc_RecursionError, + "maximum recursion depth exceeded during ast construction"); + return 0; + } switch (o->kind) { case MatchValue_kind: tp = (PyTypeObject *)state->MatchValue_type; @@ -5349,6 +5422,7 @@ ast2obj_pattern(struct ast_state *state, void* _o) if (PyObject_SetAttr(result, state->end_col_offset, value) < 0) goto failed; Py_DECREF(value); + state->recursion_depth--; return result; failed: Py_XDECREF(value); @@ -5365,6 +5439,11 @@ ast2obj_type_ignore(struct ast_state *state, void* _o) if (!o) { Py_RETURN_NONE; } + if (++state->recursion_depth > state->recursion_limit) { + PyErr_SetString(PyExc_RecursionError, + "maximum recursion depth exceeded during ast construction"); + return 0; + } switch (o->kind) { case TypeIgnore_kind: tp = (PyTypeObject *)state->TypeIgnore_type; @@ -5382,6 +5461,7 @@ ast2obj_type_ignore(struct ast_state *state, void* _o) Py_DECREF(value); break; } + state->recursion_depth--; return result; failed: Py_XDECREF(value); @@ -12234,7 +12314,32 @@ PyObject* PyAST_mod2obj(mod_ty t) if (state == NULL) { return NULL; } - return ast2obj_mod(state, t); + + int recursion_limit = Py_GetRecursionLimit(); + int starting_recursion_depth; + /* Be careful here to prevent overflow. */ + int COMPILER_STACK_FRAME_SCALE = 3; + PyThreadState *tstate = _PyThreadState_GET(); + if (!tstate) { + return 0; + } + state->recursion_limit = (recursion_limit < INT_MAX / COMPILER_STACK_FRAME_SCALE) ? + recursion_limit * COMPILER_STACK_FRAME_SCALE : recursion_limit; + int recursion_depth = tstate->recursion_limit - tstate->recursion_remaining; + starting_recursion_depth = (recursion_depth < INT_MAX / COMPILER_STACK_FRAME_SCALE) ? + recursion_depth * COMPILER_STACK_FRAME_SCALE : recursion_depth; + state->recursion_depth = starting_recursion_depth; + + PyObject *result = ast2obj_mod(state, t); + + /* Check that the recursion depth counting balanced correctly */ + if (result && state->recursion_depth != starting_recursion_depth) { + PyErr_Format(PyExc_SystemError, + "AST constructor recursion depth mismatch (before=%d, after=%d)", + starting_recursion_depth, state->recursion_depth); + return 0; + } + return result; } /* mode is 0 for "exec", 1 for "eval" and 2 for "single" input */