Friday, June 12, 2015

Binary Search Tree implementation in Java

Binary search trees keep their keys in sorted order so that lookup and other operations can use the principle of binary search: when looking for a key in a tree (or for a place to insert a new key), they traverse the tree from the root toward a leaf, comparing against the key stored in each node and deciding, based on the comparison, whether to continue in the left or the right subtree. This post is not about the basics of BSTs; if you are interested in those, please check this pdf.

I will first declare an interface Tree that all our tree implementations are expected to implement:
public interface Tree<T> {
    public boolean insert(T value);
    public T delete(T value);
    public void clear();
    public boolean contains(T value);
    public int size();
}

The BinarySearchTree class implements this interface as follows:
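public class BinarySearchTree<T extends Comparable<T>> implements Tree<T> {
    protected Node<T> root = null;
    protected int size = 0;
    // Alternates between the left and right subtree when delete picks a replacement node
    protected boolean toggleReplacementNodeSelection = true; // initial value is arbitrary
    public BinarySearchTree() { }
}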

We use generics and require the elements to be mutually comparable (T extends Comparable<T>). We first provide the implementation of the insert method:
@Override
public boolean insert(T value) {
     Node<T> nodeAdded = this.insertValue(value);
     return (nodeAdded != null);
}
protected Node<T> insertValue(T value) {
     Node<T> newNode = getNewNode(value);

     // If root is null, assign
     if (root == null) {
       root = newNode;
       size++;
       return newNode;
     }

     Node<T> currentNode = root;
     while (currentNode != null) {
        if (newNode.getData().compareTo(currentNode.getData()) <= 0) { // Less than or equal to goes left
           if (currentNode.getLeft() == null) {
              insertNodeToLeft(currentNode, newNode);
              break;
           }
           currentNode = currentNode.getLeft();
        } else { // Greater than goes right
           if (currentNode.getRight() == null) {
              insertNodeToRight(currentNode, newNode);
              break;
           }
           currentNode = currentNode.getRight();
        }
     }
     return newNode;
}
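
The insert code above relies on a Node<T> class and on the helpers getNewNode, insertNodeToLeft and insertNodeToRight, none of which are shown in this post. A minimal sketch of how they might look follows. The class name InsertSketch, the Node fields, and the choice to increment size inside the helpers are all assumptions made for illustration, not the original implementation.

```java
// Sketch only: in the post these helpers would be methods of BinarySearchTree.
class InsertSketch<T extends Comparable<T>> {

    // Hypothetical Node class; the post uses Node<T> without showing it.
    static class Node<E> {
        private final E data;
        Node<E> parent; // not private, because the post also reads node.parent directly
        private Node<E> left;
        private Node<E> right;

        Node(E data) { this.data = data; }

        E getData() { return data; }
        Node<E> getParent() { return parent; }
        Node<E> getLeft() { return left; }
        Node<E> getRight() { return right; }
        void setParent(Node<E> parent) { this.parent = parent; }
        void setLeft(Node<E> left) { this.left = left; }
        void setRight(Node<E> right) { this.right = right; }
    }

    protected int size = 0;

    protected Node<T> getNewNode(T value) {
        return new Node<>(value);
    }

    // Assumption: the helper increments size, since insertValue only does so for the root.
    protected void insertNodeToLeft(Node<T> parent, Node<T> child) {
        parent.setLeft(child);
        child.setParent(parent);
        size++;
    }

    // Mirror image of insertNodeToLeft for the right branch.
    protected void insertNodeToRight(Node<T> parent, Node<T> child) {
        parent.setRight(child);
        child.setParent(parent);
        size++;
    }
}
```

Keeping both the child link and the parent link in sync in one place matters here, because the deletion code later walks upward through getParent().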

The insert method places the new node in the left or right branch based on its value. Deletion is slightly trickier. The process is as follows:
  1. First locate the node that needs to be deleted.
  2. Then find a replacement node for it. The replacement can come from the left or the right subtree; we do not want to pick from the same subtree every time, as that can lead to a skewed tree.
  3. Then delete the node and put the replacement found in step 2 in its place.
This is performed by the following code:
 @Override
 public T delete(T value) {
    Node<T> nodeToRemove = this.deleteValue(value);
    return (nodeToRemove != null) ? nodeToRemove.getData() : null;
 }
 
 protected Node<T> deleteValue(T value) {
    Node<T> nodeToRemoved = this.searchAndGetNode(value);
    if (nodeToRemoved != null) nodeToRemoved = deleteNode(nodeToRemoved);
    return nodeToRemoved;
 }

 protected Node<T> searchAndGetNode(T value) {
    Node<T> currentNode = root;
    while (currentNode != null && currentNode.getData() != null) {
       int valueComparison = value.compareTo(currentNode.getData());
       if (valueComparison == 0) {
         return currentNode;
       } else if (valueComparison < 0) {
         currentNode = currentNode.getLeft();
       } else {
         currentNode = currentNode.getRight();
       }
    }
    return null;
 }

 protected Node<T> deleteNode(Node<T> nodeToRemoved) {
    if (nodeToRemoved != null) {
      Node<T> replacementNode = this.getReplacementNode(nodeToRemoved);
      deleteAndReplaceNode(nodeToRemoved, replacementNode);
    }
    return nodeToRemoved;
 }

 protected Node<T> getReplacementNode(Node<T> nodeToRemoved) {
    Node<T> replacement = null;
    if (nodeToRemoved.getLeft() != null && nodeToRemoved.getRight() == null) {
      // Using the lesser subtree (there is no greater subtree)
      replacement = nodeToRemoved.getLeft();
    } else if (nodeToRemoved.getRight() != null && nodeToRemoved.getLeft() == null) {
      // Using the greater subtree (there is no lesser subtree)
      replacement = nodeToRemoved.getRight();
    } else if (nodeToRemoved.getRight() != null && nodeToRemoved.getLeft() != null) {
      // Two children: alternate between the two subtrees so we don't always use the greatest/least on deletion
      if (toggleReplacementNodeSelection) {
         replacement = this.getLargestNodeInSubTree(nodeToRemoved.getLeft());
         if (replacement == null)
            replacement = nodeToRemoved.getLeft();
      } else {
         replacement = this.getSmallestNodeInSubTree(nodeToRemoved.getRight());
         if (replacement == null)
            replacement = nodeToRemoved.getRight();
      }
      toggleReplacementNodeSelection = !toggleReplacementNodeSelection;
    }
    return replacement;
 }

 protected void deleteAndReplaceNode(Node<T> nodeToDelete, Node<T> replacementNode) {
    if (replacementNode != null) {
      // Record left and right child of replacementNode for later use
      Node<T> leftChildOfReplacementNode = replacementNode.getLeft();
      Node<T> rightChildOfReplacementNode = replacementNode.getRight();

      // For left and right children of nodeToDelete, new parent is replacementNode
      Node<T> leftChildOfNodeToDelete = nodeToDelete.getLeft();
      if (leftChildOfNodeToDelete != null && !leftChildOfNodeToDelete.equals(replacementNode)) {
         replacementNode.setLeft(leftChildOfNodeToDelete);
         leftChildOfNodeToDelete.setParent(replacementNode);
      }
   
      Node<T> rightChildOfNodeToDelete = nodeToDelete.getRight();
      if (rightChildOfNodeToDelete != null && !rightChildOfNodeToDelete.equals(replacementNode)) {
         replacementNode.setRight(rightChildOfNodeToDelete);
         rightChildOfNodeToDelete.setParent(replacementNode);
      }

      // Update the link from the parent of replacementNode as well: the children of
      // replacementNode become children of its parent (grandchildren become children).
      Node<T> parentOfReplacementNode = replacementNode.getParent();
      if (parentOfReplacementNode != null && !parentOfReplacementNode.equals(nodeToDelete)) {
         Node<T> leftChildOfParentOfReplacementNode = parentOfReplacementNode.getLeft();
         Node<T> rightChildOfParentOfReplacementNode = parentOfReplacementNode.getRight();

         // Check whether the replacementNode is left or right child of its parent.
         if (leftChildOfParentOfReplacementNode != null && leftChildOfParentOfReplacementNode.equals(replacementNode)) {
            parentOfReplacementNode.setLeft(rightChildOfReplacementNode);
            if (rightChildOfReplacementNode != null)
               rightChildOfReplacementNode.setParent(parentOfReplacementNode);
         } else if (rightChildOfParentOfReplacementNode != null && rightChildOfParentOfReplacementNode.equals(replacementNode)) {
            parentOfReplacementNode.setRight(leftChildOfReplacementNode);
            if (leftChildOfReplacementNode != null)
               leftChildOfReplacementNode.setParent(parentOfReplacementNode);
         }
       }
     } 

     // Update the link in the tree from nodeToDelete to replacementNode
     Node<T> parentOfNodeToDelete = nodeToDelete.getParent();

     if (parentOfNodeToDelete == null) {
        // We are deleting the root node, so the replacement becomes the new root.
        root = replacementNode;
        if (root != null) root.setParent(null);
     } else if (parentOfNodeToDelete.getLeft() != null && (parentOfNodeToDelete.getLeft().getData().compareTo(nodeToDelete.getData()) == 0)) {
        parentOfNodeToDelete.setLeft(replacementNode);
        if (replacementNode != null) replacementNode.setParent(parentOfNodeToDelete);
     } else if (parentOfNodeToDelete.getRight() != null && (parentOfNodeToDelete.getRight().getData().compareTo(nodeToDelete.getData()) == 0)) {
        parentOfNodeToDelete.setRight(replacementNode);
        if (replacementNode != null) replacementNode.setParent(parentOfNodeToDelete);
     }
     size--;
 }
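
The getReplacementNode method above calls getLargestNodeInSubTree and getSmallestNodeInSubTree, which are not shown in this post. One plausible way to write them is sketched below: the largest key of a subtree sits at its rightmost node and the smallest at its leftmost. The class name SubtreeSketch and the reduced Node class are hypothetical, and the helpers are static here only to keep the example self-contained; in the post they are instance methods of BinarySearchTree.

```java
class SubtreeSketch {

    // Hypothetical stand-in for the post's Node<T>, reduced to what the helpers need.
    static class Node<T> {
        final T data;
        Node<T> left;
        Node<T> right;

        Node(T data) { this.data = data; }

        Node<T> getLeft() { return left; }
        Node<T> getRight() { return right; }
    }

    // Follow right links: the rightmost node holds the largest key of the subtree.
    static <T> Node<T> getLargestNodeInSubTree(Node<T> start) {
        Node<T> current = start;
        while (current != null && current.getRight() != null) {
            current = current.getRight();
        }
        return current; // null only when start itself is null
    }

    // Follow left links: the leftmost node holds the smallest key of the subtree.
    static <T> Node<T> getSmallestNodeInSubTree(Node<T> start) {
        Node<T> current = start;
        while (current != null && current.getLeft() != null) {
            current = current.getLeft();
        }
        return current; // null only when start itself is null
    }
}
```

With this version the null checks that follow each call in getReplacementNode never fire for a non-null subtree, but they remain a harmless safeguard if the real helpers behave differently.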

The remaining methods are straightforward:
 @Override
 public void clear() {
    root = null;
    size = 0;
 }

 @Override
 public boolean contains(T value) {
    Node<T> node = searchAndGetNode(value);
    return (node != null);
 }

 @Override
 public int size() {
    return size;
 }
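
One way to sanity-check insert and delete is an in-order walk: because a BST keeps its keys in sorted order, visiting left subtree, node, right subtree should yield a non-decreasing sequence. The sketch below assumes a simplified, hypothetical Node class and helper names (InOrderSketch, inOrder, isSorted) that are not part of the code above.

```java
import java.util.ArrayList;
import java.util.List;

class InOrderSketch {

    // Hypothetical stand-in for the post's Node<T>.
    static class Node<T> {
        final T data;
        Node<T> left;
        Node<T> right;

        Node(T data) { this.data = data; }
    }

    // Collect keys in left, node, right order.
    static <T> void inOrder(Node<T> node, List<T> out) {
        if (node == null) return;
        inOrder(node.left, out);
        out.add(node.data);
        inOrder(node.right, out);
    }

    // True if every key is less than or equal to its successor (duplicates allowed).
    static <T extends Comparable<T>> boolean isSorted(List<T> values) {
        for (int i = 1; i < values.size(); i++) {
            if (values.get(i - 1).compareTo(values.get(i)) > 0) return false;
        }
        return true;
    }

    // Convenience wrapper: walk the tree and check the resulting order.
    static <T extends Comparable<T>> boolean isValidOrder(Node<T> root) {
        List<T> keys = new ArrayList<>();
        inOrder(root, keys);
        return isSorted(keys);
    }
}
```

Running a check like isValidOrder after a series of inserts and deletes is a cheap way to catch relinking mistakes in deleteAndReplaceNode.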

This completes the implementation of BST in Java.
