IndexedSet getting much closer

This commit is contained in:
Mahmoud Hashemi 2013-03-13 00:18:02 -07:00
parent 066f8ce3db
commit 9716df166f
1 changed files with 114 additions and 29 deletions

View File

@ -1,16 +1,31 @@
# -*- coding: utf-8 -*-
from bisect import bisect_left, insort
from itertools import ifilter
from itertools import ifilter, chain
from collections import MutableSet
_MISSING = object()
# TODO: .sort(), .reverse()
# TODO: slicing
# TODO: in-place set operations
# TODO: better exception messages
class IndexedSet(object):
"""
class IndexedSet(MutableSet):
"""\
IndexedSet maintains insertion order and uniqueness of inserted
elements. It's a hybrid type, mostly like an OrderedSet, but also
list-like, in that it supports indexing and slicing.
>>> IndexedSet(range(4) + range(6))
[0, 1, 2, 3, 4, 5]
>>> x = IndexedSet(range(4) + range(8))
>>> x
IndexedSet([0, 1, 2, 3, 4, 5, 6, 7])
>>> x - set(range(2))
IndexedSet([2, 3, 4, 5, 6, 7])
>>> x[-1]
7
"""
def __init__(self, other=None):
self.item_index_map = dict()
@ -60,9 +75,16 @@ class IndexedSet(object):
def __iter__(self):
return ifilter(lambda e: e is not _MISSING, iter(self.item_list))
def __reversed__(self):
return ifilter(lambda e: e is not _MISSING, reversed(self.item_list))
def __repr__(self):
cn = self.__class__.__name__
return '%s(%r)' % (cn, list(self))
return '%s(%r)' % (self.__class__.__name__, list(self))
def __eq__(self, other):
if isinstance(other, IndexedSet):
return len(self) == len(other) and list(self) == list(other)
return set(self) == set(other)
#set operations
def remove(self, item): # O(1) + (amortized O(n) cull)
@ -81,40 +103,103 @@ class IndexedSet(object):
pass
def add(self, item):
self.item_index_map[item] = len(self.item_list)
self.item_list.append(item)
if item not in self.item_index_map:
self.item_index_map[item] = len(self.item_list)
self.item_list.append(item)
def update(self, other): # O(n)
for item in other:
self.add(item)
def update(self, other):
for o in other:
self.add(o)
#TODO: a bunch of set operators
#general scheme: add all of the "others" items to the right
def isdisjoint(self, other):
iim = self.item_index_map
for k in other:
if k in iim:
return False
return True
def issubset(self, other):
if len(other) < len(self):
return False
for k in self.item_index_map:
if k not in other:
return False
return True
def issuperset(self, other):
if len(other) > len(self):
return False
iim = self.item_index_map
for k in other:
if k not in iim:
return False
return True
def union(self, *others):
return IndexedSet(chain(self, *others))
def intersection(self, *others):
if len(others) == 1:
other = others[0]
return IndexedSet(k for k in self if k in other)
ret = IndexedSet()
for k in self:
for other in others:
if k not in other:
break
else:
ret.add(k)
return ret
def difference(self, *others):
if len(others) == 1:
other = others[0]
return IndexedSet(k for k in self if k not in other)
ret = IndexedSet()
for k in self:
for other in others:
if k in other:
break
else:
ret.add(k)
return ret
def symmetric_difference(self, *others):
ret = self.union(*others)
return ret.difference(self.intersection(*others))
__or__ = union
__and__ = intersection
__sub__ = difference
__xor__ = symmetric_difference
#list operations
def __getitem__(self, key):
if key < 0:
key += len(self)
phy_key = bisect_left(self.dead_indices, key)
def __getitem__(self, index): # TODO: support slicing
if index < 0:
index += len(self)
real_index = index + bisect_left(self.dead_indices, index)
try:
return self.item_list[phy_key]
return self.item_list[real_index]
except IndexError:
raise #TODO: message
def pop(self, index=None): # O(1) + (amortized O(n) cull)
if index is None or index == -1 or index == len(self):
item_index_map = self.item_index_map
len_self = len(item_index_map)
if index is None or index == -1 or index == len_self:
ret = self.item_list.pop()
del self.item_index_map[ret]
del item_index_map[ret]
else:
if index < 0:
index += len(self) # TODO: not len(self) (extra fxn call)?
phy_index = index + bisect_left(self.dead_indices, index)
insort(self.dead_indices, phy_index)
del self.items[self.item_list[phy_index]]
self.item_list[phy_index] = _MISSING
index += len_self
real_index = index + bisect_left(self.dead_indices, index)
insort(self.dead_indices, real_index)
del item_index_map[self.item_list[real_index]]
self.item_list[real_index] = _MISSING
self._cull()
return ret
_MISSING = object()
def count(self, x):
if x in self.item_index_map:
return 1
return 0