Binary Search Trees

Posted by Beetle B. on Wed 27 August 2014

A binary search tree is a fairly useful structure for storing ordered data.

The invariant for a binary search tree is straightforward: for each node, every key in its left subtree is smaller than the node's key, and every key in its right subtree is larger (assuming all keys are distinct).
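The invariant is about whole subtrees, not just a node and its immediate children. The sketch below checks it by passing down the bounds inherited from the ancestors. The `Node` class and `is_bst` function are hypothetical names for illustration, and the keys are assumed to be numbers so that `math.inf` works as an open bound.

```python
import math

class Node:
    # Hypothetical minimal node used only for this sketch.
    def __init__(self, key, left=None, right=None):
        self.key = key
        self.left = left
        self.right = right

def is_bst(node, lo=-math.inf, hi=math.inf):
    """Check that every key lies strictly between the bounds set by its ancestors."""
    if node is None:
        return True
    if not (lo < node.key < hi):
        return False
    # Going left tightens the upper bound; going right tightens the lower bound.
    return is_bst(node.left, lo, node.key) and is_bst(node.right, node.key, hi)

good = Node(5, Node(3, Node(1), Node(4)), Node(8))
# 6 is larger than its parent 3, but it sits in 5's left subtree: invalid.
bad = Node(5, Node(3, Node(1), Node(6)), Node(8))
print(is_bst(good), is_bst(bad))  # True False
```

A check that only compared each node with its children would wrongly accept `bad`.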

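As such, it is easy to search for an element in the tree: start at the root and compare the target with the current node's key. If they are equal you are done; if the target is smaller, go left; if it is larger, go right. If you reach an empty link, the element is not in the tree.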
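A minimal iterative sketch of that search, assuming a hypothetical `Node` class that stores a key, a value and two child links. Each comparison discards one whole subtree, so the number of steps is at most the height of the tree plus one.

```python
class Node:
    # Hypothetical minimal node: key, value, and two child links.
    def __init__(self, key, value, left=None, right=None):
        self.key = key
        self.value = value
        self.left = left
        self.right = right

def search(root, key):
    """Return the value stored under key, or None if it is absent."""
    node = root
    while node is not None:
        if key < node.key:
            node = node.left      # key can only be in the left subtree
        elif key > node.key:
            node = node.right     # key can only be in the right subtree
        else:
            return node.value     # found it
    return None                   # fell off the tree: key not present

root = Node(5, "five", Node(3, "three"), Node(8, "eight", Node(7, "seven")))
print(search(root, 7))  # seven
print(search(root, 6))  # None
```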

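Note that a binary search tree need not be well balanced. Its shape depends on the order in which the keys were inserted: inserting them in sorted order, for example, produces a tree in which every node has only a right child, so its height is \(N - 1\). Since search, insertion and deletion all take time proportional to the height, a badly unbalanced tree is no faster than a linked list.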
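To see how much the insertion order matters, this sketch inserts the same seven keys in two different orders and measures the resulting height (edges on the longest root-to-leaf path). The helper names `insert`, `height` and `build` are hypothetical, chosen for this illustration.

```python
class Node:
    # Hypothetical minimal node used only for this sketch.
    def __init__(self, key):
        self.key = key
        self.left = None
        self.right = None

def insert(node, key):
    """Insert key into the subtree rooted at node; return the subtree root."""
    if node is None:
        return Node(key)
    if key < node.key:
        node.left = insert(node.left, key)
    elif key > node.key:
        node.right = insert(node.right, key)
    return node

def height(node):
    """Edges on the longest root-to-leaf path (-1 for an empty tree)."""
    if node is None:
        return -1
    return 1 + max(height(node.left), height(node.right))

def build(keys):
    root = None
    for k in keys:
        root = insert(root, k)
    return root

print(height(build([4, 2, 6, 1, 3, 5, 7])))  # 2: perfectly balanced
print(height(build([1, 2, 3, 4, 5, 6, 7])))  # 6: a linked list in disguise
```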

Operations

Min/Max

For min, start at the root and keep following left links until you reach a node with no left child. For max, do the same with right links.
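Both walks are a few lines each. This sketch assumes a non-empty tree and a hypothetical minimal `Node` class; `min_node` and `max_node` are illustrative names.

```python
class Node:
    # Hypothetical minimal node used only for this sketch.
    def __init__(self, key, left=None, right=None):
        self.key = key
        self.left = left
        self.right = right

def min_node(node):
    """Follow left links until there is no left child."""
    while node.left is not None:
        node = node.left
    return node

def max_node(node):
    """Mirror image: follow right links until there is no right child."""
    while node.right is not None:
        node = node.right
    return node

root = Node(5, Node(3, Node(1), Node(4)), Node(8, Node(7)))
print(min_node(root).key, max_node(root).key)  # 1 8
```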

Floor/Ceiling

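Floor means find the largest element smaller than or equal to a given value. Starting at the root:

  1. If the current node's key equals the value, that is the floor.
  2. If the current node is bigger, the floor is in the left subtree. Recurse.
  3. If the current node is smaller, it is a candidate, but the floor is in the right subtree if the right subtree has any element smaller than or equal to the value. If not, then the floor is the current node.

Ceiling (the smallest element larger than or equal to a given value) is symmetric: swap left with right and smaller with larger.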
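The three cases translate directly into a recursive function. This is a sketch with a hypothetical `Node` class that returns the floor key itself (or `None` when every key is larger). It tests `better is not None` rather than truthiness so that a floor key of 0 is not mistaken for "not found".

```python
class Node:
    # Hypothetical minimal node used only for this sketch.
    def __init__(self, key, left=None, right=None):
        self.key = key
        self.left = left
        self.right = right

def floor(node, key):
    """Return the largest key <= key in the subtree, or None if there is none."""
    if node is None:
        return None
    if key == node.key:            # case 1: exact match is the floor
        return node.key
    if key < node.key:             # case 2: floor must be in the left subtree
        return floor(node.left, key)
    # case 3: node.key < key, so node is a candidate; a better one may be
    # in the right subtree, but only if something there is still <= key.
    better = floor(node.right, key)
    return better if better is not None else node.key

root = Node(5, Node(3, Node(1), Node(4)), Node(8, Node(7)))
print([floor(root, k) for k in (0, 2, 5, 6, 9)])  # [None, 1, 5, 5, 8]
```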

Insertion

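To insert an element into the tree, we “search” for it. If the key is already there, we just update its value. Otherwise the search runs off the bottom of the tree at an empty link, and that is where the new node belongs. However, we need to update the parent to point to the new child. One way to do it is to make the function recursive and have each call return a pointer to the root of the subtree it was given (the brand-new node, if it was given an empty subtree); the caller then stores that return value back into its own left or right link. This way we “rebuild” the path from the root to the new node on the way back up the function stack, although effectively only one link changes.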
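Here is a minimal sketch of that "return the subtree root" pattern, with a hypothetical `Node` class holding a key and a value. Note that the caller must also store the result for the root, since inserting into an empty tree creates the root.

```python
class Node:
    # Hypothetical minimal node used only for this sketch.
    def __init__(self, key, value):
        self.key = key
        self.value = value
        self.left = None
        self.right = None

def insert(node, key, value):
    """Insert (key, value) below node and return the root of that subtree.

    Every call returns the node it was given (or a new node when it was
    given None), and the caller stores that return value back into its
    own link. Only the link that used to be None actually changes.
    """
    if node is None:
        return Node(key, value)            # the new leaf
    if key < node.key:
        node.left = insert(node.left, key, value)
    elif key > node.key:
        node.right = insert(node.right, key, value)
    else:
        node.value = value                 # existing key: overwrite value
    return node

root = None
for k, v in [(5, "a"), (3, "b"), (8, "c"), (3, "d")]:
    root = insert(root, k, v)              # the root itself may change
print(root.key, root.left.key, root.left.value, root.right.key)  # 5 3 d 8
```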

In-Order Traversal

An in-order traversal visits the keys in sorted order. The way to do this is to traverse the tree recursively, carrying a queue along with it: at each node, traverse the left subtree, then insert the current node into the queue, then traverse the right subtree. Leaf nodes need no special case: both of their subtrees are empty, so the recursion simply inserts the leaf itself.
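A sketch of the recursive traversal, using `collections.deque` as the queue. The `Node` class is a hypothetical minimal one, and for brevity the queue collects keys rather than nodes.

```python
from collections import deque

class Node:
    # Hypothetical minimal node used only for this sketch.
    def __init__(self, key, left=None, right=None):
        self.key = key
        self.left = left
        self.right = right

def inorder(node, queue):
    """Append the keys of the subtree to queue in ascending order."""
    if node is None:
        return                  # empty subtree: nothing to add
    inorder(node.left, queue)   # everything smaller first
    queue.append(node.key)      # then this node
    inorder(node.right, queue)  # then everything larger

root = Node(5, Node(3, Node(1), Node(4)), Node(8, Node(7)))
q = deque()
inorder(root, q)
print(list(q))  # [1, 3, 4, 5, 7, 8]
```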

Deletion

Deletion is nontrivial.

To delete the minimal element, follow left links from the root until you reach a node with no left child; that node is the minimum. Then update its parent’s link to point to the minimum’s right child (if any). The minimum has no left child, so nothing else is lost. This is straightforward.
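Using the same "return the new subtree root" trick as insertion, delete-min becomes a short recursion. This sketch assumes a non-empty tree; `Node`, `delete_min` and the `keys` helper (which lists keys in order) are hypothetical names.

```python
class Node:
    # Hypothetical minimal node used only for this sketch.
    def __init__(self, key, left=None, right=None):
        self.key = key
        self.left = left
        self.right = right

def delete_min(node):
    """Remove the smallest key in the subtree; return the new subtree root."""
    if node.left is None:
        return node.right       # node is the minimum: splice in its right child
    node.left = delete_min(node.left)
    return node

def keys(node):
    return keys(node.left) + [node.key] + keys(node.right) if node else []

root = Node(5, Node(3, None, Node(4)), Node(8))
root = delete_min(root)       # removes 3; its right child 4 takes its place
print(keys(root), root.left.key)  # [4, 5, 8] 4
```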

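To delete an arbitrary element, one approach is to use Hibbard deletion. First search for the node to delete, then:

  1. If the node has no children or only one child, replace the node with that child, or with nothing if it has no children (and update the parent).
  2. If the node has 2 children, it’s not obvious what to replace it with. We can’t necessarily just replace it with one of its children, as the child may have 2 children itself (resulting in a node with three children). So we find its successor, the next larger element (i.e. the minimal element in its right subtree), remove the successor from the right subtree (it has no left child, so this is just the delete-min operation above), and put it in the deleted node’s place. Why the successor? It is larger than every key in the left subtree and smaller than every remaining key in the right subtree, so the invariant still holds on both sides.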
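A self-contained sketch of Hibbard deletion. It repeats small `min_node` and `delete_min` helpers so it runs on its own; all names here, including the `Node` class, are hypothetical. The order of the two-child steps matters: `delete_min` is called on the original right subtree before the successor's links are overwritten.

```python
class Node:
    # Hypothetical minimal node used only for this sketch.
    def __init__(self, key, left=None, right=None):
        self.key = key
        self.left = left
        self.right = right

def min_node(node):
    while node.left is not None:
        node = node.left
    return node

def delete_min(node):
    if node.left is None:
        return node.right
    node.left = delete_min(node.left)
    return node

def delete(node, key):
    """Hibbard deletion: return the root of the subtree with key removed."""
    if node is None:
        return None                        # key not present
    if key < node.key:
        node.left = delete(node.left, key)
    elif key > node.key:
        node.right = delete(node.right, key)
    else:
        if node.left is None:              # zero or one child:
            return node.right              # splice the child (or None) in
        if node.right is None:
            return node.left
        succ = min_node(node.right)        # two children: use the successor
        succ.right = delete_min(node.right)
        succ.left = node.left
        return succ
    return node

def keys(node):
    return keys(node.left) + [node.key] + keys(node.right) if node else []

root = Node(5, Node(3, Node(1), Node(4)), Node(8, Node(7, Node(6)), Node(9)))
root = delete(root, 5)   # 5 has two children; its successor is 6
print(root.key, keys(root))  # 6 [1, 3, 4, 6, 7, 8, 9]
```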

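The problem with Hibbard deletion is that it is asymmetric (the replacement always comes from the right subtree), so after many random inserts and deletes the tree typically becomes unbalanced. This results in a tree whose height grows on the order of \(\sqrt{N}\), much worse than the \(O(\lg N)\) height expected when keys are only inserted in random order.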

It’s still an open problem to find a deletion scheme for binary search trees that is both simple and efficient.

Properties

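  • If you think about it, building a binary search tree by inserting keys one at a time is just like quicksort: the first key inserted (the root) plays the role of the first pivot, the remaining keys are divided into smaller ones on the left and larger ones on the right, and each subtree corresponds to a recursive subproblem. Much of the analysis one performs on quicksort applies to binary search trees as well.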
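  • If you insert distinct keys in random order, the expected height is asymptotically about \(4.311\ln N\), while the expected number of compares for a single search or insert is only about \(2\ln N\) (roughly \(1.39\lg N\)).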
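  • In each node, it is useful to store the number of nodes in the subtree rooted at it (the node itself plus everything in its left and right subtrees). This lets order queries such as rank run in time proportional to the height. One then needs to update these counts whenever an insertion or deletion occurs.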
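A quick way to check the quicksort correspondence from the first bullet above: count the compares made while inserting distinct keys into a BST, and the compares made by a quicksort that always takes the first element as the pivot and partitions without reordering the elements on each side. Under those assumptions the counts should match, because in both cases each key is compared exactly once with each of its ancestors in the tree. The dict-of-children representation and both function names are hypothetical, chosen to keep the sketch short.

```python
import random

def bst_compares(keys):
    """Compares needed to build a BST by inserting keys in the given order."""
    children = {}                          # key -> [left child, right child]
    root = None
    compares = 0
    for k in keys:
        if root is None:
            root = k
            children[k] = [None, None]
            continue
        node = root
        while True:
            compares += 1                  # one compare of k against node
            side = 0 if k < node else 1    # 0 = left, 1 = right
            nxt = children[node][side]
            if nxt is None:
                children[node][side] = k   # attach k as a new leaf
                children[k] = [None, None]
                break
            node = nxt
    return compares

def quicksort_compares(keys):
    """Compares made by first-element-pivot quicksort with an order-preserving partition."""
    if len(keys) <= 1:
        return 0
    pivot, rest = keys[0], keys[1:]
    smaller = [x for x in rest if x < pivot]
    larger = [x for x in rest if x > pivot]
    # Conceptually each element of rest is compared with the pivot once.
    return len(rest) + quicksort_compares(smaller) + quicksort_compares(larger)

keys = random.sample(range(1000), 200)    # distinct keys in random order
print(bst_compares(keys) == quicksort_compares(keys))  # True
```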
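To get a feel for the height claim in the second bullet, this sketch builds a tree from a random permutation of `range(n)` and prints its height next to \(4.311\ln n\). It is iterative to stay clear of Python's recursion limit, and the function name `random_bst_height` is hypothetical. The 4.311 constant describes the asymptotic behaviour, so at sizes that simulate quickly you should not expect a close match; treat the comparison as rough.

```python
import math
import random

def random_bst_height(n, seed=None):
    """Height (edges on the longest root-to-leaf path) of a BST built by
    inserting a random permutation of range(n); assumes n >= 1."""
    rng = random.Random(seed)
    keys = list(range(n))
    rng.shuffle(keys)
    left, right = {}, {}           # parent key -> child key
    root, height = keys[0], 0
    for k in keys[1:]:
        node, depth = root, 0
        while True:
            depth += 1
            links = left if k < node else right
            if node in links:
                node = links[node]     # keep descending
            else:
                links[node] = k        # attach k as a new leaf at this depth
                height = max(height, depth)
                break
    return height

n = 1 << 15
print(random_bst_height(n, seed=2014), round(4.311 * math.log(n), 1))
```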

Implementations

Python

from collections import deque

class Node:

    def __init__(self, key, value):
        self.key = key
        self.value = value
        self.left = None
        self.right = None
        self.size = 1  # Number of nodes in the subtree rooted here.

    def get_min(self):
        """ Return the node with the min key in this subtree. """
        node = self
        while node.left:
            node = node.left
        return node

    def get_max(self):
        """ Return the node with the max key in this subtree. """
        node = self
        while node.right:
            node = node.right
        return node

def size(node):
    """ Size of the subtree rooted at node (0 for an empty subtree). """
    if node:
        return node.size
    else:
        return 0

class BinarySearchTree:

    def __init__(self, root=None):
        self.root = root

    def insert(self, key, value):
        # Store the result: the root changes when the tree is empty.
        self.root = self.__doinsert(self.root, key, value)

    def __doinsert(self, node, key, value):
        if not node:
            return Node(key, value)
        if key < node.key:
            node.left = self.__doinsert(node.left, key, value)
        elif key > node.key:
            node.right = self.__doinsert(node.right, key, value)
        else:
            node.value = value  # Key already present: update its value.
        # Recompute rather than increment, so an update keeps the size.
        node.size = 1 + size(node.left) + size(node.right)
        return node

    def delete(self, key):
        self.root = self.__dodelete(self.root, key)

    def __dodelete(self, node, key):
        """
        Once you are at the node you want to delete, find the successor
        in the right subtree, and replace the current node with the
        successor. Return the successor so that the parent's link is
        updated.
        """
        if not node:
            return None
        if key < node.key:
            node.left = self.__dodelete(node.left, key)
        elif key > node.key:
            node.right = self.__dodelete(node.right, key)
        else:
            # Zero or one child: splice in the other subtree.
            if not node.right:
                return node.left
            if not node.left:
                return node.right
            # Two children: the successor takes this node's place.
            tmp = node
            node = tmp.right.get_min()
            node.right = self.__deletemin(tmp.right)
            node.left = tmp.left
        node.size = 1 + size(node.left) + size(node.right)
        return node

    def __deletemin(self, node):
        """
        Delete the minimum key in the subtree rooted at node, and return
        the new root of that subtree.

        If the node has a smaller element, it will just return itself
        after deleting the minimum. As in, the deleted (or 'unlinked')
        node may be far away, but the return value is simply the current
        node.

        If the node has no smaller element, it returns its right node
        (None if the node has no children, in which case the node is
        simply deleted).
        """
        if not node.left:
            return node.right
        node.left = self.__deletemin(node.left)
        node.size = 1 + size(node.left) + size(node.right)
        return node

    def search(self, key):
        """
        Searches the tree for a node with that key. If it finds one, it
        returns the node; otherwise it returns None.
        """
        node = self.root
        while node:
            if key == node.key:
                return node
            if key < node.key:
                node = node.left
            else:
                node = node.right
        return node

    def __getfloor(self, node, key):
        if not node:
            return None
        if node.key == key:
            return node
        if key < node.key:
            return self.__getfloor(node.left, key)
        # key > node.key: node is a candidate, but the right subtree may
        # hold a larger key that is still <= key.
        n = self.__getfloor(node.right, key)
        if n:  # Found it in the right subtree.
            return n
        else:  # Nothing <= key in the right subtree.
            return node

    def floor(self, key):
        """ Return the node with the largest key <= key (None if none). """
        return self.__getfloor(self.root, key)

    def ceiling(self, key):
        """ Return the node with the smallest key >= key (None if none). """
        return self.__getceiling(self.root, key)

    def __getceiling(self, node, key):
        if not node:
            return None
        if node.key == key:
            return node
        if key > node.key:
            return self.__getceiling(node.right, key)
        # key < node.key: node is a candidate, but the left subtree may
        # hold a smaller key that is still >= key.
        n = self.__getceiling(node.left, key)
        if n:  # Found it in the left subtree.
            return n
        else:  # Nothing >= key in the left subtree.
            return node

    def rank(self, k):
        """ How many keys are less than k? """
        return self.__getrank(self.root, k)

    def __getrank(self, node, k):
        if not node:
            return 0
        if k < node.key:
            return self.__getrank(node.left, k)
        if k > node.key:
            # 1 is for the current node, which is smaller than k.
            return 1 + size(node.left) + self.__getrank(node.right, k)
        return size(node.left)  # k == node.key

    def traverse(self):
        """ Return the nodes in key order, in a deque used as a FIFO queue. """
        q = deque()
        self.__inorder(self.root, q)
        return q

    def __inorder(self, node, queue):
        if not node:
            return
        self.__inorder(node.left, queue)
        queue.append(node)
        self.__inorder(node.right, queue)

C++

#include <cstddef> // For NULL definition.
#include <queue>

class Node;

class Node 
{
public:
  int key;
  int value;
  Node* left;
  Node* right;
  int size;

  Node(int key, int val);
  Node* get_min();
  Node* get_max();
};

// Initializers listed in declaration order (avoids -Wreorder warnings).
Node::Node(int key, int val) : key(key), value(val), left(NULL), right(NULL),
                               size(1) {}

Node* Node::get_min()
{
  if(NULL == left)
    {
      return this;
    }
  Node* node = this;
  while (NULL != node->left)
    {
      node = node->left;
    }
  return node;
}

Node* Node::get_max()
{
  if(NULL == right)
    {
      return this;
    }
  Node* node = this;
  while (NULL != node->right)
    {
      node = node->right;
    }
  return node;
}

int size(Node const * node)
{
  if(NULL != node)
    {
      return node->size;
    }
  return 0;
}

class BinarySearchTree
{
public:
  Node* root;
  
  // The tree takes ownership of root, which may be NULL for an empty tree.
  explicit BinarySearchTree(Node* root = NULL);
  ~BinarySearchTree();
  void insert(int key, int value);
  void deletekey(int key);
  Node* floor(int key) const;
  Node* ceiling(int key) const;
  Node* search(int key);
  int rank(int key) const;
  std::queue<Node*> traverse() const;

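private:
  // The tree owns (and deletes) its nodes, so a shallow copy would delete
  // them twice. Declared but deliberately not defined: copying is disabled.
  BinarySearchTree(BinarySearchTree const&);
  BinarySearchTree& operator=(BinarySearchTree const&);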
  Node* doinsert(Node* node, int key, int value);
  Node* dodelete(Node* node, int key);
  Node* deletemin(Node* node);
  Node* getfloor(Node* node, int key) const;
  Node* getceiling(Node* node, int key) const;
  int getrank(Node* node, int key) const;
  void inorder(Node* node, std::queue<Node*>& queue) const;
};

BinarySearchTree::BinarySearchTree(Node* root) : root(root) {}

Node* BinarySearchTree::doinsert(Node* node, int key, int value)
{
  if (NULL==node)
    {
      Node* newnode = new Node(key, value);
      return newnode;
    }
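  if (key < node->key)
    {
      node->left = doinsert(node->left, key, value);
    }
  else if (key > node->key)
    {
      node->right = doinsert(node->right, key, value);
    }
  else
    {
      node->value = value; // Key already present: update its value.
    }
  // Recompute rather than increment, so an update leaves the size unchanged.
  node->size = 1 + size(node->left) + size(node->right);
  return node;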
}

void BinarySearchTree::insert(int key, int value)
{
  // Store the result: the root changes when inserting into an empty tree.
  root = doinsert(root, key, value);
}

void BinarySearchTree::deletekey(int key)
{
  root = dodelete(root, key);
}

Node* BinarySearchTree::dodelete(Node* node, int key)
{
  if (NULL==node)
    {
      return NULL;
    }
  if (key < node->key)
    {
      node->left = dodelete(node->left, key);
    }
  else if (key > node->key)
    {
      node->right = dodelete(node->right, key);
    }
  else
    {
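      // Zero or one child: splice in the other link and free this node.
      if (NULL == node->right)
        {
          Node* left = node->left;
          delete node;
          return left;
        }
      if (NULL == node->left)
        {
          Node* right = node->right;
          delete node;
          return right;
        }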
      Node* tmp = node;
      node = tmp->right->get_min();
      node->right = deletemin(tmp->right);
      node->left = tmp->left;
      delete tmp;
    }
  node->size = 1 + size(node->left) + size(node->right);
  return node;
}

Node* BinarySearchTree::deletemin(Node* node)
{
  if (NULL == node->left)
    {
      return node->right;
    }
  node->left = deletemin(node->left);
  node->size = 1 + size(node->left) + size(node->right);
  return node;
}

Node* BinarySearchTree::search(int key)
{
  Node* node = root;
  while(NULL != node)
    {
      if(key == node->key)
        {
          return node;
        }
      else if(key < node->key)
        {
          node = node->left;
        }
      else if(key > node->key)
        {
          node = node->right;
        }
    }
  return node;
}

Node* BinarySearchTree::getfloor(Node* node, int key) const
{
  if (NULL == node)
    {
      return NULL; 
    }
  if(key == node->key)
    {
      return node;
    }
  if(key < node->key)
    {
      return getfloor(node->left, key);
    }
  // Here key > node->key: node is a candidate, but the right subtree
  // may hold a larger key that is still <= key.
  Node* n = getfloor(node->right, key);
  if (NULL != n)
    {
      return n;
    }
  return node;
}

Node* BinarySearchTree::floor(int key) const
{
  return getfloor(root, key);
}

Node* BinarySearchTree::getceiling(Node* node, int key) const
{
  if (NULL == node)
    {
      return NULL; 
    }
  if(key == node->key)
    {
      return node;
    }
  if(key > node->key)
    {
      return getceiling(node->right, key);
    }
  // Here key < node->key: node is a candidate, but the left subtree
  // may hold a smaller key that is still >= key.
  Node* n = getceiling(node->left, key);
  if (NULL != n)
    {
      return n;
    }
  return node;
}

Node* BinarySearchTree::ceiling(int key) const
{
  return getceiling(root, key);
}

int BinarySearchTree::rank(int key) const
{
  return getrank(root, key);
}

int BinarySearchTree::getrank(Node* node, int key) const
{
  if (NULL == node)
    {
      return 0; 
    }
  if(key < node->key)
    {
      return getrank(node->left, key);
    }
  if(key > node->key)
    {
      return 1 + size(node->left) + getrank(node->right, key);
    }
  // Here key == node->key: only the left subtree holds smaller keys.
  return size(node->left);
}

void BinarySearchTree::inorder(Node* node, std::queue<Node*>& queue) const
{
  if (NULL == node)
    return;
  inorder(node->left, queue);
  queue.push(node);
  inorder(node->right, queue);
}

std::queue<Node*> BinarySearchTree::traverse() const
{
  std::queue<Node*> queue;
  inorder(root, queue);
  return queue;
}

// Not exactly the fastest way of doing things!
BinarySearchTree::~BinarySearchTree()
{
  std::queue<Node*> queue = traverse();
  while(!queue.empty())
    {
      Node* node = queue.front();
      queue.pop();
      node->left = NULL;
      node->right = NULL;
      delete node;
      node = NULL;
    }
  
}