diff --git a/Lib/test/test_dict.py b/Lib/test/test_dict.py index d81ecf59c49..60f3f402a95 100644 --- a/Lib/test/test_dict.py +++ b/Lib/test/test_dict.py @@ -458,6 +458,23 @@ def __eq__(self, other): self.assertRaises(RuntimeError, lambda: d2.items() < d3.items()) self.assertRaises(RuntimeError, lambda: d3.items() > d2.items()) + def test_dictview_set_operations(self): + k1 = {1:1, 2:2}.keys() + k2 = {1:1, 2:2, 3:3}.keys() + k3 = {4:4}.keys() + + self.assertEquals(k1 - k2, set()) + self.assertEquals(k1 - k3, {1,2}) + self.assertEquals(k2 - k1, {3}) + self.assertEquals(k3 - k1, {4}) + self.assertEquals(k1 & k2, {1,2}) + self.assertEquals(k1 & k3, set()) + self.assertEquals(k1 | k2, {1,2,3}) + self.assertEquals(k1 ^ k2, {3}) + self.assertEquals(k1 ^ k3, {1,2,4}) + + # XXX similar tests for .items() + def test_missing(self): # Make sure dict doesn't have a __missing__ method self.assertEqual(hasattr(dict, "__missing__"), False) diff --git a/Objects/dictobject.c b/Objects/dictobject.c index 9ef1fcc340f..539d73418a6 100644 --- a/Objects/dictobject.c +++ b/Objects/dictobject.c @@ -2489,6 +2489,98 @@ static PySequenceMethods dictkeys_as_sequence = { (objobjproc)dictkeys_contains, /* sq_contains */ }; +static PyObject* +dictviews_sub(PyObject* self, PyObject *other) +{ + PyObject *result = PySet_New(self); + PyObject *tmp; + if (result == NULL) + return NULL; + + tmp = PyObject_CallMethod(result, "difference_update", "O", other); + if (tmp == NULL) { + Py_DECREF(result); + return NULL; + } + + Py_DECREF(tmp); + return result; +} + +static PyObject* +dictviews_and(PyObject* self, PyObject *other) +{ + PyObject *result = PySet_New(self); + PyObject *tmp; + if (result == NULL) + return NULL; + + tmp = PyObject_CallMethod(result, "intersection_update", "O", other); + if (tmp == NULL) { + Py_DECREF(result); + return NULL; + } + + Py_DECREF(tmp); + return result; +} + +static PyObject* +dictviews_or(PyObject* self, PyObject *other) +{ + PyObject *result = PySet_New(self); + PyObject *tmp; + if (result == NULL) + return NULL; + + tmp = PyObject_CallMethod(result, "update", "O", other); + if (tmp == NULL) { + Py_DECREF(result); + return NULL; + } + + Py_DECREF(tmp); + return result; +} + +static PyObject* +dictviews_xor(PyObject* self, PyObject *other) +{ + PyObject *result = PySet_New(self); + PyObject *tmp; + if (result == NULL) + return NULL; + + tmp = PyObject_CallMethod(result, "symmetric_difference_update", "O", + other); + if (tmp == NULL) { + Py_DECREF(result); + return NULL; + } + + Py_DECREF(tmp); + return result; +} + +static PyNumberMethods dictviews_as_number = { + 0, /*nb_add*/ + (binaryfunc)dictviews_sub, /*nb_subtract*/ + 0, /*nb_multiply*/ + 0, /*nb_remainder*/ + 0, /*nb_divmod*/ + 0, /*nb_power*/ + 0, /*nb_negative*/ + 0, /*nb_positive*/ + 0, /*nb_absolute*/ + 0, /*nb_bool*/ + 0, /*nb_invert*/ + 0, /*nb_lshift*/ + 0, /*nb_rshift*/ + (binaryfunc)dictviews_and, /*nb_and*/ + (binaryfunc)dictviews_xor, /*nb_xor*/ + (binaryfunc)dictviews_or, /*nb_or*/ +}; + static PyMethodDef dictkeys_methods[] = { {NULL, NULL} /* sentinel */ }; @@ -2505,7 +2597,7 @@ PyTypeObject PyDictKeys_Type = { 0, /* tp_setattr */ 0, /* tp_compare */ 0, /* tp_repr */ - 0, /* tp_as_number */ + &dictviews_as_number, /* tp_as_number */ &dictkeys_as_sequence, /* tp_as_sequence */ 0, /* tp_as_mapping */ 0, /* tp_hash */ @@ -2589,7 +2681,7 @@ PyTypeObject PyDictItems_Type = { 0, /* tp_setattr */ 0, /* tp_compare */ 0, /* tp_repr */ - 0, /* tp_as_number */ + &dictviews_as_number, /* tp_as_number */ &dictitems_as_sequence, /* tp_as_sequence */ 0, /* tp_as_mapping */ 0, /* tp_hash */