diff --git a/Objects/object.c b/Objects/object.c index 20950c1e8d8..7bf2a634488 100644 --- a/Objects/object.c +++ b/Objects/object.c @@ -376,40 +376,28 @@ try_rich_compare_bool(PyObject *v, PyObject *w, int op) static int try_rich_to_3way_compare(PyObject *v, PyObject *w) { + static struct { int op; int outcome; } tries[3] = { + /* Try this operator, and if it is true, use this outcome: */ + {Py_EQ, 0}, + {Py_LT, -1}, + {Py_GT, 1}, + }; + int i; + if (v->ob_type->tp_richcompare == NULL && w->ob_type->tp_richcompare == NULL) return 2; /* Shortcut */ - switch (try_rich_compare_bool(v, w, Py_LT)) { - case -1: /* Error */ - return -1; - case 0: /* False: not less */ - break; - case 1: /* True: less */ - return -1; - case 2: /* NotImplemented */ - break; + + for (i = 0; i < 3; i++) { + switch (try_rich_compare_bool(v, w, tries[i].op)) { + case -1: + return -1; + case 1: + return tries[i].outcome; + } } - switch (try_rich_compare_bool(v, w, Py_GT)) { - case -1: /* Error */ - return -1; - case 0: /* False: not greater */ - break; - case 1: /* True: greater */ - return 1; - case 2: /* NotImplemented */ - break; - } - switch (try_rich_compare_bool(v, w, Py_EQ)) { - case -1: /* Error */ - return -1; - case 0: /* False: not equal */ - break; - case 1: /* True: equal */ - return 0; - case 2: /* NotImplemented */ - break; - } - return 2; /* XXX Even if all three returned FALSE?! */ + + return 2; } /* Try a 3-way comparison, returning an int. Return: @@ -530,9 +518,7 @@ do_cmp(PyObject *v, PyObject *w) return default_3way_compare(v, w); } -PyObject *_PyCompareState_Key; - -/* _PyCompareState_nesting is incremented before calling compare (for +/* compare_nesting is incremented before calling compare (for some types) and decremented on exit. If the count exceeds the nesting limit, enable code to detect circular data structures. */ @@ -541,25 +527,31 @@ PyObject *_PyCompareState_Key; #else #define NESTING_LIMIT 500 #endif -int _PyCompareState_nesting = 0; +static int compare_nesting = 0; static PyObject* get_inprogress_dict(void) { + static PyObject *key; PyObject *tstate_dict, *inprogress; + if (key == NULL) { + key = PyString_InternFromString("cmp_state"); + if (key == NULL) + return NULL; + } + tstate_dict = PyThreadState_GetDict(); if (tstate_dict == NULL) { PyErr_BadInternalCall(); return NULL; } - inprogress = PyDict_GetItem(tstate_dict, _PyCompareState_Key); + inprogress = PyDict_GetItem(tstate_dict, key); if (inprogress == NULL) { inprogress = PyDict_New(); if (inprogress == NULL) return NULL; - if (PyDict_SetItem(tstate_dict, _PyCompareState_Key, - inprogress) == -1) { + if (PyDict_SetItem(tstate_dict, key, inprogress) == -1) { Py_DECREF(inprogress); return NULL; } @@ -656,8 +648,8 @@ PyObject_Compare(PyObject *v, PyObject *w) return 0; vtp = v->ob_type; wtp = w->ob_type; - _PyCompareState_nesting++; - if (_PyCompareState_nesting > NESTING_LIMIT && + compare_nesting++; + if (compare_nesting > NESTING_LIMIT && (vtp->tp_as_mapping || PyInstance_Check(v) || (vtp->tp_as_sequence && !PyString_Check(v)))) { @@ -690,7 +682,7 @@ PyObject_Compare(PyObject *v, PyObject *w) result = do_cmp(v, w); } exit_cmp: - _PyCompareState_nesting--; + compare_nesting--; return result < 0 ? -1 : result; } @@ -738,30 +730,45 @@ PyObject_RichCompare(PyObject *v, PyObject *w, int op) assert(Py_LT <= op && op <= Py_GE); - if (_PyCompareState_nesting > NESTING_LIMIT) { - /* Too deeply nested -- assume equal */ - /* XXX This is an unfair shortcut! - Should use the same logic as PyObject_Compare. */ - switch (op) { - case Py_LT: - case Py_NE: - case Py_GT: - res = Py_False; - break; - case Py_LE: - case Py_EQ: - case Py_GE: - res = Py_True; - break; + compare_nesting++; + if (compare_nesting > NESTING_LIMIT && + (v->ob_type->tp_as_mapping + || PyInstance_Check(v) + || (v->ob_type->tp_as_sequence && !PyString_Check(v)))) { + /* try to detect circular data structures */ + PyObject *inprogress, *pair; + + inprogress = get_inprogress_dict(); + if (inprogress == NULL) { + res = NULL; + goto exit_cmp; } - Py_INCREF(res); - return res; + pair = make_pair(v, w); + if (PyDict_GetItem(inprogress, pair)) { + /* already comparing these objects. assume + they're equal until shown otherwise */ + Py_DECREF(pair); + if (op == Py_EQ || op == Py_LE || op == Py_GE) + res = Py_True; + else + res = Py_False; + Py_INCREF(res); + goto exit_cmp; + } + if (PyDict_SetItem(inprogress, pair, pair) == -1) { + res = NULL; + goto exit_cmp; + } + res = do_richcmp(v, w, op); + /* XXX DelItem shouldn't fail */ + PyDict_DelItem(inprogress, pair); + Py_DECREF(pair); } - - _PyCompareState_nesting++; - res = do_richcmp(v, w, op); - _PyCompareState_nesting--; - + else { + res = do_richcmp(v, w, op); + } + exit_cmp: + compare_nesting--; return res; }