AVL樹

# -*- encoding:utf8 -*-

# ### 平衡二叉樹 ###
import BinarySortTree

class Node(BinarySortTree.TreeNode):
	"""Node of an AVL tree: a BinarySortTree.TreeNode that also caches its height.

	A newly created node has no children, so its height starts at 0.
	"""

	def __init__(self, key):
		base = BinarySortTree.TreeNode
		base.__init__(self, key)
		# Height of a childless node; the tree code keeps this up to date.
		self.height = 0

class AVLTree(BinarySortTree.BinaryTree):
	"""Self-balancing binary search tree (AVL tree).

	Each Node caches its height (a node without children is 0, an empty
	subtree is -1).  After every insertion or deletion the nodes on the
	path back to the root are rebalanced with rotations so that the
	heights of any node's two subtrees differ by at most 1.
	"""

	def __init__(self, root):
		"""Create a tree rooted at ``root``.

		``root`` may be a ready-made Node (used as-is), a bare key (wrapped
		in a new Node), or None for an empty tree.
		"""
		# NOTE(review): BinaryTree.__init__ is not called (as before);
		# confirm the base class needs no setup of its own.
		if root is not None and not isinstance(root, Node):
			root = Node(root)
		self.root = root

	def height(self, node):
		"""Return the cached height of ``node``; an empty subtree counts as -1."""
		if node is None:
			return -1
		return node.height

	def _refresh_height(self, node):
		"""Recompute ``node.height`` from its children's cached heights."""
		node.height = max(self.height(node.left), self.height(node.right)) + 1

	def singleLeftRotate(self, node):
		"""Rotate ``node``'s left child up into its place and return it.

		Repairs the left-left case (most textbooks call this a right rotation).
		"""
		k1 = node.left
		node.left = k1.right
		k1.right = node
		# ``node`` is now k1's child, so its height must be refreshed first.
		self._refresh_height(node)
		self._refresh_height(k1)
		return k1

	def singleRightRotate(self, node):
		"""Rotate ``node``'s right child up into its place and return it.

		Mirror of singleLeftRotate; repairs the right-right case.
		"""
		k1 = node.right
		node.right = k1.left
		k1.left = node
		self._refresh_height(node)
		self._refresh_height(k1)
		return k1

	def doubleLeftRotate(self, node):
		"""Repair the left-right case: lift the left child's right child to the top."""
		node.left = self.singleRightRotate(node.left)
		return self.singleLeftRotate(node)

	def doubleRightRotate(self, node):
		"""Repair the right-left case: lift the right child's left child to the top."""
		node.right = self.singleLeftRotate(node.right)
		return self.singleRightRotate(node)

	def _rebalance(self, node):
		"""Restore the AVL balance at ``node`` and return the subtree's new root.

		Expects both children to be valid AVL subtrees whose heights differ
		by at most 2, as is the case right after one insert or delete below
		``node``.  The returned root's cached height is up to date.
		"""
		balance = self.height(node.left) - self.height(node.right)
		if balance > 1:
			# ``>=`` rather than ``>``: after a deletion both grandchildren can
			# be equally tall, and then a single rotation is the correct fix.
			if self.height(node.left.left) >= self.height(node.left.right):
				return self.singleLeftRotate(node)
			return self.doubleLeftRotate(node)
		if balance < -1:
			if self.height(node.right.right) >= self.height(node.right.left):
				return self.singleRightRotate(node)
			return self.doubleRightRotate(node)
		self._refresh_height(node)
		return node

	def _put(self, value, node):
		"""Insert ``value`` into the subtree ``node``; return the new subtree root.

		A value equal to an existing key leaves the tree unchanged.
		"""
		if node is None:
			return Node(value)
		if value < node.key:
			node.left = self._put(value, node.left)
		elif value > node.key:
			node.right = self._put(value, node.right)
		return self._rebalance(node)

	def insert(self, value):
		"""Insert ``value`` into the tree (duplicates are ignored)."""
		self.root = self._put(value, self.root)

	def delete(self, value):
		"""Remove the node whose key equals ``value``.

		Raises:
			ValueError: if no node has that key (the tree is left unchanged).
		"""
		self.root = self._remove(value, self.root)

	def _remove(self, value, node):
		"""Remove ``value`` from the subtree ``node``; return the new subtree root.

		Raises:
			ValueError: if ``value`` is not present in the subtree.
		"""
		if node is None:
			raise ValueError("沒有該節點")  # "no such node"
		if value < node.key:
			node.left = self._remove(value, node.left)
		elif value > node.key:
			node.right = self._remove(value, node.right)
		elif node.left is None:
			# At most one child: that (already balanced) child takes the slot.
			return node.right
		elif node.right is None:
			return node.left
		else:
			# Two children: move the in-order successor (smallest node of the
			# right subtree) into this slot.  The right subtree shrinks by at
			# most one level, so the imbalance here stays within what
			# _rebalance can repair.
			rest, successor = self._detach_min(node.right)
			successor.left = node.left
			successor.right = rest
			node = successor
		return self._rebalance(node)

	def _detach_min(self, node):
		"""Unlink the smallest node from the non-empty subtree ``node``.

		Returns ``(new_subtree_root, smallest_node)``; the remaining subtree
		is rebalanced on the way back up.
		"""
		if node.left is None:
			return node.right, node
		node.left, smallest = self._detach_min(node.left)
		return self._rebalance(node), smallest

發表評論
所有評論
還沒有人評論，想成為第一個評論的人麼？請在上方評論欄輸入並且點擊發布。
相關文章