Add a test suite for WeakSet mostly derived from test_set and fix some

issues in the weakset implementation discovered with it.
This commit is contained in:
Georg Brandl 2008-05-18 16:27:29 +00:00
parent c4dc0d4000
commit 9dba5d9764
2 changed files with 353 additions and 12 deletions

View File

@ -23,6 +23,9 @@ def __iter__(self):
if item is not None:
yield item
def __len__(self):
return sum(x() is not None for x in self.data)
def __contains__(self, item):
return ref(item) in self.data
@ -61,7 +64,9 @@ def update(self, other):
else:
for element in other:
self.add(element)
__ior__ = update
def __ior__(self, other):
self.update(other)
return self
# Helper functions for simple delegating methods.
def _apply(self, other, method):
@ -72,43 +77,68 @@ def _apply(self, other, method):
newset.data = newdata
return newset
def _apply_mutate(self, other, method):
if not isinstance(other, self.__class__):
other = self.__class__(other)
method(other)
def difference(self, other):
return self._apply(other, self.data.difference)
__sub__ = difference
def difference_update(self, other):
self._apply_mutate(self, self.data.difference_update)
__isub__ = difference_update
if self is other:
self.data.clear()
else:
self.data.difference_update(ref(item) for item in other)
def __isub__(self, other):
if self is other:
self.data.clear()
else:
self.data.difference_update(ref(item) for item in other)
return self
def intersection(self, other):
return self._apply(other, self.data.intersection)
__and__ = intersection
def intersection_update(self, other):
self._apply_mutate(self, self.data.intersection_update)
__iand__ = intersection_update
self.data.intersection_update(ref(item) for item in other)
def __iand__(self, other):
self.data.intersection_update(ref(item) for item in other)
return self
def issubset(self, other):
return self.data.issubset(ref(item) for item in other)
__lt__ = issubset
def __le__(self, other):
return self.data <= set(ref(item) for item in other)
def issuperset(self, other):
return self.data.issuperset(ref(item) for item in other)
__gt__ = issuperset
def __ge__(self, other):
return self.data >= set(ref(item) for item in other)
def __eq__(self, other):
return self.data == set(ref(item) for item in other)
def symmetric_difference(self, other):
return self._apply(other, self.data.symmetric_difference)
__xor__ = symmetric_difference
def symmetric_difference_update(self, other):
self._apply_mutate(other, self.data.symmetric_difference_update)
__ixor__ = symmetric_difference_update
if self is other:
self.data.clear()
else:
self.data.symmetric_difference_update(ref(item) for item in other)
def __ixor__(self, other):
if self is other:
self.data.clear()
else:
self.data.symmetric_difference_update(ref(item) for item in other)
return self
def union(self, other):
return self._apply(other, self.data.union)
__or__ = union
def isdisjoint(self, other):
return len(self.intersection(other)) == 0

311
Lib/test/test_weakset.py Normal file
View File

@ -0,0 +1,311 @@
import unittest
from test import test_support
from weakref import proxy, ref, WeakSet
import operator
import copy
import string
import os
from random import randrange, shuffle
import sys
import warnings
import collections
from collections import UserString as ustr
class Foo:
pass
class TestWeakSet(unittest.TestCase):
def setUp(self):
# need to keep references to them
self.items = [ustr(c) for c in ('a', 'b', 'c')]
self.items2 = [ustr(c) for c in ('x', 'y', 'z')]
self.letters = [ustr(c) for c in string.ascii_letters]
self.s = WeakSet(self.items)
self.d = dict.fromkeys(self.items)
self.obj = ustr('F')
self.fs = WeakSet([self.obj])
def test_methods(self):
weaksetmethods = dir(WeakSet)
for method in dir(set):
if method.startswith('_'):
continue
self.assert_(method in weaksetmethods)
def test_new_or_init(self):
self.assertRaises(TypeError, WeakSet, [], 2)
def test_len(self):
self.assertEqual(len(self.s), len(self.d))
self.assertEqual(len(self.fs), 1)
del self.obj
self.assertEqual(len(self.fs), 0)
def test_contains(self):
for c in self.letters:
self.assertEqual(c in self.s, c in self.d)
self.assertRaises(TypeError, self.s.__contains__, [[]])
self.assert_(self.obj in self.fs)
del self.obj
self.assert_(ustr('F') not in self.fs)
def test_union(self):
u = self.s.union(self.items2)
for c in self.letters:
self.assertEqual(c in u, c in self.d or c in self.items2)
self.assertEqual(self.s, WeakSet(self.items))
self.assertEqual(type(u), WeakSet)
self.assertRaises(TypeError, self.s.union, [[]])
for C in set, frozenset, dict.fromkeys, list, tuple:
x = WeakSet(self.items + self.items2)
c = C(self.items2)
self.assertEqual(self.s.union(c), x)
def test_or(self):
i = self.s.union(self.items2)
self.assertEqual(self.s | set(self.items2), i)
self.assertEqual(self.s | frozenset(self.items2), i)
def test_intersection(self):
i = self.s.intersection(self.items2)
for c in self.letters:
self.assertEqual(c in i, c in self.d and c in self.items2)
self.assertEqual(self.s, WeakSet(self.items))
self.assertEqual(type(i), WeakSet)
for C in set, frozenset, dict.fromkeys, list, tuple:
x = WeakSet([])
self.assertEqual(self.s.intersection(C(self.items2)), x)
def test_isdisjoint(self):
self.assert_(self.s.isdisjoint(WeakSet(self.items2)))
self.assert_(not self.s.isdisjoint(WeakSet(self.letters)))
def test_and(self):
i = self.s.intersection(self.items2)
self.assertEqual(self.s & set(self.items2), i)
self.assertEqual(self.s & frozenset(self.items2), i)
def test_difference(self):
i = self.s.difference(self.items2)
for c in self.letters:
self.assertEqual(c in i, c in self.d and c not in self.items2)
self.assertEqual(self.s, WeakSet(self.items))
self.assertEqual(type(i), WeakSet)
self.assertRaises(TypeError, self.s.difference, [[]])
def test_sub(self):
i = self.s.difference(self.items2)
self.assertEqual(self.s - set(self.items2), i)
self.assertEqual(self.s - frozenset(self.items2), i)
def test_symmetric_difference(self):
i = self.s.symmetric_difference(self.items2)
for c in self.letters:
self.assertEqual(c in i, (c in self.d) ^ (c in self.items2))
self.assertEqual(self.s, WeakSet(self.items))
self.assertEqual(type(i), WeakSet)
self.assertRaises(TypeError, self.s.symmetric_difference, [[]])
def test_xor(self):
i = self.s.symmetric_difference(self.items2)
self.assertEqual(self.s ^ set(self.items2), i)
self.assertEqual(self.s ^ frozenset(self.items2), i)
def test_sub_and_super(self):
pl, ql, rl = map(lambda s: [ustr(c) for c in s], ['ab', 'abcde', 'def'])
p, q, r = map(WeakSet, (pl, ql, rl))
self.assert_(p < q)
self.assert_(p <= q)
self.assert_(q <= q)
self.assert_(q > p)
self.assert_(q >= p)
self.failIf(q < r)
self.failIf(q <= r)
self.failIf(q > r)
self.failIf(q >= r)
self.assert_(set('a').issubset('abc'))
self.assert_(set('abc').issuperset('a'))
self.failIf(set('a').issubset('cbs'))
self.failIf(set('cbs').issuperset('a'))
def test_gc(self):
# Create a nest of cycles to exercise overall ref count check
class A:
pass
s = set(A() for i in range(1000))
for elem in s:
elem.cycle = s
elem.sub = elem
elem.set = set([elem])
def test_subclass_with_custom_hash(self):
# Bug #1257731
class H(WeakSet):
def __hash__(self):
return int(id(self) & 0x7fffffff)
s=H()
f=set()
f.add(s)
self.assert_(s in f)
f.remove(s)
f.add(s)
f.discard(s)
def test_init(self):
s = WeakSet()
s.__init__(self.items)
self.assertEqual(s, self.s)
s.__init__(self.items2)
self.assertEqual(s, WeakSet(self.items2))
self.assertRaises(TypeError, s.__init__, s, 2);
self.assertRaises(TypeError, s.__init__, 1);
def test_constructor_identity(self):
s = WeakSet(self.items)
t = WeakSet(s)
self.assertNotEqual(id(s), id(t))
def test_set_literal(self):
s = set([1,2,3])
t = {1,2,3}
self.assertEqual(s, t)
def test_hash(self):
self.assertRaises(TypeError, hash, self.s)
def test_clear(self):
self.s.clear()
self.assertEqual(self.s, set())
self.assertEqual(len(self.s), 0)
def test_copy(self):
dup = self.s.copy()
self.assertEqual(self.s, dup)
self.assertNotEqual(id(self.s), id(dup))
def test_add(self):
x = ustr('Q')
self.s.add(x)
self.assert_(x in self.s)
dup = self.s.copy()
self.s.add(x)
self.assertEqual(self.s, dup)
self.assertRaises(TypeError, self.s.add, [])
self.fs.add(Foo())
self.assert_(len(self.fs) == 1)
self.fs.add(self.obj)
self.assert_(len(self.fs) == 1)
def test_remove(self):
x = ustr('a')
self.s.remove(x)
self.assert_(x not in self.s)
self.assertRaises(KeyError, self.s.remove, x)
self.assertRaises(TypeError, self.s.remove, [])
def test_discard(self):
a, q = ustr('a'), ustr('Q')
self.s.discard(a)
self.assert_(a not in self.s)
self.s.discard(q)
self.assertRaises(TypeError, self.s.discard, [])
def test_pop(self):
for i in range(len(self.s)):
elem = self.s.pop()
self.assert_(elem not in self.s)
self.assertRaises(KeyError, self.s.pop)
def test_update(self):
retval = self.s.update(self.items2)
self.assertEqual(retval, None)
for c in (self.items + self.items2):
self.assert_(c in self.s)
self.assertRaises(TypeError, self.s.update, [[]])
def test_update_set(self):
self.s.update(set(self.items2))
for c in (self.items + self.items2):
self.assert_(c in self.s)
def test_ior(self):
self.s |= set(self.items2)
for c in (self.items + self.items2):
self.assert_(c in self.s)
def test_intersection_update(self):
retval = self.s.intersection_update(self.items2)
self.assertEqual(retval, None)
for c in (self.items + self.items2):
if c in self.items2 and c in self.items:
self.assert_(c in self.s)
else:
self.assert_(c not in self.s)
self.assertRaises(TypeError, self.s.intersection_update, [[]])
def test_iand(self):
self.s &= set(self.items2)
for c in (self.items + self.items2):
if c in self.items2 and c in self.items:
self.assert_(c in self.s)
else:
self.assert_(c not in self.s)
def test_difference_update(self):
retval = self.s.difference_update(self.items2)
self.assertEqual(retval, None)
for c in (self.items + self.items2):
if c in self.items and c not in self.items2:
self.assert_(c in self.s)
else:
self.assert_(c not in self.s)
self.assertRaises(TypeError, self.s.difference_update, [[]])
self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]])
def test_isub(self):
self.s -= set(self.items2)
for c in (self.items + self.items2):
if c in self.items and c not in self.items2:
self.assert_(c in self.s)
else:
self.assert_(c not in self.s)
def test_symmetric_difference_update(self):
retval = self.s.symmetric_difference_update(self.items2)
self.assertEqual(retval, None)
for c in (self.items + self.items2):
if (c in self.items) ^ (c in self.items2):
self.assert_(c in self.s)
else:
self.assert_(c not in self.s)
self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]])
def test_ixor(self):
self.s ^= set(self.items2)
for c in (self.items + self.items2):
if (c in self.items) ^ (c in self.items2):
self.assert_(c in self.s)
else:
self.assert_(c not in self.s)
def test_inplace_on_self(self):
t = self.s.copy()
t |= t
self.assertEqual(t, self.s)
t &= t
self.assertEqual(t, self.s)
t -= t
self.assertEqual(t, WeakSet())
t = self.s.copy()
t ^= t
self.assertEqual(t, WeakSet())
def test_main(verbose=None):
test_support.run_unittest(TestWeakSet)
if __name__ == "__main__":
test_main(verbose=True)