From 9716df166f50802eda1cb247c689322205b014f3 Mon Sep 17 00:00:00 2001 From: Mahmoud Hashemi Date: Wed, 13 Mar 2013 00:18:02 -0700 Subject: [PATCH] IndexedSet getting much closer --- boltons/setutils.py | 143 +++++++++++++++++++++++++++++++++++--------- 1 file changed, 114 insertions(+), 29 deletions(-) diff --git a/boltons/setutils.py b/boltons/setutils.py index f7b5bc5..2b63749 100644 --- a/boltons/setutils.py +++ b/boltons/setutils.py @@ -1,16 +1,31 @@ +# -*- coding: utf-8 -*- + from bisect import bisect_left, insort -from itertools import ifilter +from itertools import ifilter, chain +from collections import MutableSet + +_MISSING = object() + +# TODO: .sort(), .reverse() +# TODO: slicing +# TODO: in-place set operations +# TODO: better exception messages -class IndexedSet(object): - """ +class IndexedSet(MutableSet): + """\ IndexedSet maintains insertion order and uniqueness of inserted elements. It's a hybrid type, mostly like an OrderedSet, but also list-like, in that it supports indexing and slicing. - >>> IndexedSet(range(4) + range(6)) - [0, 1, 2, 3, 4, 5] + >>> x = IndexedSet(range(4) + range(8)) + >>> x + IndexedSet([0, 1, 2, 3, 4, 5, 6, 7]) + >>> x - set(range(2)) + IndexedSet([2, 3, 4, 5, 6, 7]) + >>> x[-1] + 7 """ def __init__(self, other=None): self.item_index_map = dict() @@ -60,9 +75,16 @@ class IndexedSet(object): def __iter__(self): return ifilter(lambda e: e is not _MISSING, iter(self.item_list)) + def __reversed__(self): + return ifilter(lambda e: e is not _MISSING, reversed(self.item_list)) + def __repr__(self): - cn = self.__class__.__name__ - return '%s(%r)' % (cn, list(self)) + return '%s(%r)' % (self.__class__.__name__, list(self)) + + def __eq__(self, other): + if isinstance(other, IndexedSet): + return len(self) == len(other) and list(self) == list(other) + return set(self) == set(other) #set operations def remove(self, item): # O(1) + (amortized O(n) cull) @@ -81,40 +103,103 @@ class IndexedSet(object): pass def add(self, item): - self.item_index_map[item] = len(self.item_list) - self.item_list.append(item) + if item not in self.item_index_map: + self.item_index_map[item] = len(self.item_list) + self.item_list.append(item) - def update(self, other): # O(n) - for item in other: - self.add(item) + def update(self, other): + for o in other: + self.add(o) - #TODO: a bunch of set operators - #general scheme: add all of the "others" items to the right + def isdisjoint(self, other): + iim = self.item_index_map + for k in other: + if k in iim: + return False + return True + + def issubset(self, other): + if len(other) < len(self): + return False + for k in self.item_index_map: + if k not in other: + return False + return True + + def issuperset(self, other): + if len(other) > len(self): + return False + iim = self.item_index_map + for k in other: + if k not in iim: + return False + return True + + def union(self, *others): + return IndexedSet(chain(self, *others)) + + def intersection(self, *others): + if len(others) == 1: + other = others[0] + return IndexedSet(k for k in self if k in other) + ret = IndexedSet() + for k in self: + for other in others: + if k not in other: + break + else: + ret.add(k) + return ret + + def difference(self, *others): + if len(others) == 1: + other = others[0] + return IndexedSet(k for k in self if k not in other) + ret = IndexedSet() + for k in self: + for other in others: + if k in other: + break + else: + ret.add(k) + return ret + + def symmetric_difference(self, *others): + ret = self.union(*others) + return ret.difference(self.intersection(*others)) + + __or__ = union + __and__ = intersection + __sub__ = difference + __xor__ = symmetric_difference #list operations - def __getitem__(self, key): - if key < 0: - key += len(self) - phy_key = bisect_left(self.dead_indices, key) + def __getitem__(self, index): # TODO: support slicing + if index < 0: + index += len(self) + real_index = index + bisect_left(self.dead_indices, index) try: - return self.item_list[phy_key] + return self.item_list[real_index] except IndexError: raise #TODO: message def pop(self, index=None): # O(1) + (amortized O(n) cull) - if index is None or index == -1 or index == len(self): + item_index_map = self.item_index_map + len_self = len(item_index_map) + if index is None or index == -1 or index == len_self: ret = self.item_list.pop() - del self.item_index_map[ret] + del item_index_map[ret] else: if index < 0: - index += len(self) # TODO: not len(self) (extra fxn call)? - phy_index = index + bisect_left(self.dead_indices, index) - insort(self.dead_indices, phy_index) - del self.items[self.item_list[phy_index]] - self.item_list[phy_index] = _MISSING + index += len_self + real_index = index + bisect_left(self.dead_indices, index) + insort(self.dead_indices, real_index) + del item_index_map[self.item_list[real_index]] + self.item_list[real_index] = _MISSING self._cull() return ret - - -_MISSING = object() + def count(self, x): + if x in self.item_index_map: + return 1 + return 0