diff --git a/boltons/treeutils.py b/boltons/treeutils.py index c394617..e363487 100644 --- a/boltons/treeutils.py +++ b/boltons/treeutils.py @@ -28,7 +28,7 @@ class Tree(object): else: cur = cur[2] cur = stack[-1] - if item <= cur[0]: + if item < cur[0]: cur[1] = [item, None, None, 1] else: cur[2] = [item, None, None, 1] @@ -73,79 +73,52 @@ class Tree(object): node = stack[i] left, right = node[1], node[2] height = max(0, left and left[3], right and right[3]) + 1 - if height == node[3]: - # if we have not changed heights, we're done rotating - return + #if height == node[3]: + # # if we have not changed heights, we're done rotating + # return + print 'i: %s, height: %s, new height: %s' % (i, node[3], height) node[3] = height while 1: - balance = (left and left[3] or 0) - (right and right[3] or 0) + balance = (node[1] and node[1][3] or 0) - (node[2] and node[2][3] or 0) if abs(balance) < 2: break # we're balanced - if balance > 1: - leftleft, leftright = left and (left[1], left[2]) or (None, None) - leftbalance = (leftleft and leftleft[3] or 0) - (leftright and leftright[3] or 0) - if leftbalance < 0: # left rotate left - node[1] = leftright # left = leftright - left[2] = leftright[1] # left.right = leftright.left - leftright[1] = left # leftright.left = left - left = node[1] # set left to new value for next rotate - #right rotate - if i > 0: # right rotate around parent - parent = stack[i - 1] - if parent[1] is node: - parent[1] = left - elif parent[2] is node: - parent[2] = left - else: # right rotate around root - self.root = left - node[1] = left[2] # right = leftright - left[2] = node # leftright = node + rel_side = balance / abs(balance) # -1: rotate left + side, other_side = (balance < 0) + 1, (balance > 0) + 1 + child = node[side] + cbal = (child[1] and child[1][3] or 0) - (child[2] and child[2][3] or 0) - node = parent if i > 0 else self.root #set node to (new) correct value for balance + if (rel_side * cbal) < 0: + grandchild = child[other_side] + node[side] = grandchild + child[other_side] = grandchild[side] + grandchild[side] = child + gc = grandchild + gc[3] = max(0, gc[1] and gc[1][3], gc[2] and gc[2][3]) + 1 + child = node[side] # we're done with the old child - if balance < -1: - rightleft = right and right[1] - rightright = right and right[2] - rightbalance = (rightleft and rightleft[3] or 0) - (rightright and rightright[3] or 0) - if rightbalance > 0: #right rotate right - node[2] = rightleft #right = rightleft - right[1] = rightleft[2] #right.left = rightleft.right - rightleft[2] = right #rightleft.right = right - right = node[2] #set right to (new) correct value for next rotation - #rotate left - if i > 0: #left rotate around parent - parent = stack[i-1] - if parent[1] is node: - parent[1] = right - elif parent[2] is node: - parent[2] = right - else: #left rotate around root - self.root = right - node[2] = right[1] #right = rightleft - right[1] = node #rightleft = node + if i == 0: + self.root = child + else: + parent = stack[i - 1] + if parent[1] is node: + parent[1] = child + if parent[2] is node: + parent[2] = child + node[side] = child[other_side] + ns = node[side] + if ns: + ns[3] = max(ns[1] and ns[1][3], ns[2] and ns[2][3], 0) + 1 - node = parent if i > 0 else self.root #set node to (new) correct value for balance + child[other_side] = node + cos = child[other_side] + if cos: + cos[3] = max(cos[1] and cos[1][3], cos[2] and cos[2][3], 0) + 1 - left = node[1] #update left and right to (new) correct value for balance - right = node[2] - if left: - for c in (1,2): - cur = left[c] - if cur: - cur[3] = max(cur[1] and cur[1][3] or 0, - cur[2] and cur[2][3] or 0) + 1 - left[3] = max(left[1] and left[1][3] or 0, left[2] and left[2][3] or 0) + 1 - if right: - for c in (1,2): - cur = right[c] - if cur: - cur[3] = max(cur[1] and cur[1][3] or 0, - cur[2] and cur[2][3] or 0) + 1 - right[3] = max(right[1] and right[1][3] or 0, right[2] and right[2][3] or 0) + 1 - node[3] = max(left and left[3] or 0, right and right[3] or 0) + 1 + node = parent if i > 0 else self.root + node[3] = max(node[1] and node[1][3], node[2] and node[2][3], 0) + 1 # continue inner while # continue outer for - + return def __iter__(self): cur = self.root @@ -171,23 +144,50 @@ class Tree(object): def test(): import os - #vals = [ord(x) for x in os.urandom(2048)] - vals = range(2048) - tree = Tree() - for v in vals: + vals = [ord(x) for x in os.urandom(2048)] + #vals = range(2048) + #vals = sorted(range(100) * 20) + #vals = list(reversed(range(100))) + tree, old_tree = Tree(), Tree() + sorted_vals = sorted(vals) + offset = 0 + treed_vals = None + for i, v in enumerate(vals): + print i, v tree.insert(v) - sorted_vals = list(tree) - import pdb;pdb.set_trace() - return sorted(vals) == sorted_vals + treed_vals = list(tree) + new_offset = len(treed_vals) - (i + 1) + if new_offset != offset: # sorted_vals[:i]: + offset = new_offset + import pdb;pdb.set_trace() + old_tree.insert(v) + + #treed_vals = list(tree) + print len(vals), len(treed_vals) + return sorted(vals) == sorted_vals if __name__ == '__main__': import signal, pdb def pdb_int_handler(sig, frame): pdb.set_trace() signal.signal(signal.SIGINT, pdb_int_handler) - res = test() + try: + res = test() + except: + import pdb;pdb.post_mortem() + raise if res: print 'tests pass' else: print 'tests failed' + + + +""" +[156, [45, [20, None, None, 1], [148, [91, None, None, 1], None, 2], 3], [212, [159, None, None, 1], [240, None, None, 1], 2], 4] + ++ 136 + += [156, [45, [20, None, None, 1], None, 2], [212, [159, None, None, 1], [240, None, None, 1], 2], 3] +"""