This article is about a Python library I created to manage binary search trees. I will go over the following:

- Node class
- Insert method
- Lookup method
- Delete method
- Print method
- Comparing 2 trees
- Generator returning the tree elements one by one

You can check out the library code on GitHub: git clone https://laurentluce@github.com/laurentluce/python-algorithms.git. The repository contains other libraries as well, but we are only going to focus on the Binary Tree one.

As a reminder, here is a binary search tree definition (Wikipedia).

A binary search tree (BST) or ordered binary tree is a node-based binary tree data structure which has the following properties:

- The left subtree of a node contains only nodes with keys less than the node’s key.
- The right subtree of a node contains only nodes with keys greater than the node’s key.
- Both the left and right subtrees must also be binary search trees.
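The three properties above can be turned into a small checker. This is only a sketch, assuming Python 3; the `is_bst()` helper and this stripped-down `Node` (with optional `left` and `right` constructor arguments) are hypothetical and not part of the library:

```python
class Node:
    """Minimal stand-in for a tree node (hypothetical, for this sketch only)."""

    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


def is_bst(node, low=None, high=None):
    """Return True if every key in the subtree lies strictly between low and high."""
    if node is None:
        return True
    if low is not None and node.data <= low:
        return False
    if high is not None and node.data >= high:
        return False
    # keys on the left must stay below node.data, keys on the right above it
    return (is_bst(node.left, low, node.data)
            and is_bst(node.right, node.data, high))


valid = Node(8, Node(3, Node(1), Node(6)), Node(10))
invalid = Node(8, Node(3, Node(1), Node(9)), Node(10))  # 9 sits in the left subtree of 8

print(is_bst(valid))    # True
print(is_bst(invalid))  # False
```

Passing the bounds down matters: checking each node only against its direct children would accept the second tree, even though 9 is greater than 8 and sits in its left subtree.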

Here is an example of a binary search tree:

## Node class

We need to represent a tree node. To do that, we create a new class named Node with 3 attributes:

- Left node
- Right node
- Node’s data (the same as the key in the definition above)

    class Node:
        """
        Tree node: left and right child + data which can be any object
        """
        def __init__(self, data):
            """
            Node constructor

            @param data node data object
            """
            self.left = None
            self.right = None
            self.data = data

Let’s create a tree node containing the integer 8. You can pass any object as the data, as long as it can be compared with the other values in the tree, so it is flexible. When you create a node, both its left and right children are None.

    root = Node(8)

Note that we just created a tree with a single node.

## Insert method

We need a method to help us populate our tree. This method takes the node’s data as an argument and inserts a new node in the tree.

    class Node:
        ...
        def insert(self, data):
            """
            Insert new node with data

            @param data node data object to insert
            """
            # compare against None explicitly so that falsy data such as 0
            # is not mistaken for an empty node and overwritten
            if self.data is not None:
                if data < self.data:
                    if self.left is None:
                        self.left = Node(data)
                    else:
                        self.left.insert(data)
                elif data > self.data:
                    if self.right is None:
                        self.right = Node(data)
                    else:
                        self.right.insert(data)
            else:
                self.data = data

insert() calls itself recursively on the left or right subtree until it finds the place where the new node belongs. If the data is already in the tree, nothing is inserted.

Let’s add 3 nodes to our root node which we created above and let’s look at what the code does.

    root.insert(3)
    root.insert(10)
    root.insert(1)

This is what happens when we add the second node (3):

- 1- root node’s method insert() is called with data = 3.
- 2- 3 is less than 8 and left child is None so we attach the new node to it.

This is what happens when we add the third node (10):

- 1- root node’s method insert() is called with data = 10.
- 2- 10 is greater than 8 and right child is None so we attach the new node to it.

This is what happens when we add the fourth node (1):

- 1- root node’s method insert() is called with data = 1.
- 2- 1 is less than 8 so the root’s left child (3) insert() method is called with data = 1. Note how we call the method on a subtree.
- 3- 1 is less than 3 and left child is None so we attach the new node to it.
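The three walkthroughs above can be reproduced with a small tracing variant of the insertion. This is a sketch, assuming Python 3; `insert_with_path()` is a hypothetical iterative helper (the library's insert() is recursive) that returns the nodes visited on the way down:

```python
class Node:
    """Tree node with the same three attributes as in the article."""

    def __init__(self, data):
        self.left = None
        self.right = None
        self.data = data


def insert_with_path(root, data):
    """Insert data below root and return the data of the nodes visited."""
    path = []
    node = root
    while True:
        path.append(node.data)
        if data < node.data:
            if node.left is None:
                node.left = Node(data)
                return path
            node = node.left
        elif data > node.data:
            if node.right is None:
                node.right = Node(data)
                return path
            node = node.right
        else:
            return path  # data already in the tree: nothing is inserted


root = Node(8)
print(insert_with_path(root, 3))   # [8]
print(insert_with_path(root, 10))  # [8]
print(insert_with_path(root, 1))   # [8, 3]
```

The last call visits 8 and then 3 before attaching the new node, which matches the three steps listed for node 1.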

This is what the tree looks like now:

        8
       / \
      3   10
     /
    1

Let’s continue and complete our tree so we can move on to the next section which is about looking up nodes in the tree.

    root.insert(6)
    root.insert(4)
    root.insert(7)
    root.insert(14)
    root.insert(13)

The complete tree looks like this:

            8
          /   \
         3     10
        / \      \
       1   6      14
          / \    /
         4   7  13
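One way to double-check the shape of the complete tree is to dump it as nested tuples of the form (data, left, right). This is a minimal sketch, assuming Python 3; the `to_tuple()` helper is hypothetical and the `Node` class is a trimmed copy of the one above:

```python
class Node:
    """Trimmed copy of the article's Node, enough to build the tree."""

    def __init__(self, data):
        self.left = None
        self.right = None
        self.data = data

    def insert(self, data):
        if data < self.data:
            if self.left is None:
                self.left = Node(data)
            else:
                self.left.insert(data)
        elif data > self.data:
            if self.right is None:
                self.right = Node(data)
            else:
                self.right.insert(data)


def to_tuple(node):
    """Return the subtree as (data, left, right), with None for a missing child."""
    if node is None:
        return None
    return (node.data, to_tuple(node.left), to_tuple(node.right))


root = Node(8)
for value in (3, 10, 1, 6, 4, 7, 14, 13):
    root.insert(value)

print(to_tuple(root))
```

With the insertion order used in this article, 3 ends up with the children 1 and 6, 6 with the children 4 and 7, and 10 with the single right child 14, whose left child is 13.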

## Lookup method

We need a way to look for a specific node in the tree. We add a new method named lookup which takes a node’s data as an argument and returns the node if found or None if not. We also return the node’s parent for convenience.

    class Node:
        ...
        def lookup(self, data, parent=None):
            """
            Lookup node containing data

            @param data node data object to look up
            @param parent node's parent
            @returns node and node's parent if found or None, None
            """
            if data < self.data:
                if self.left is None:
                    return None, None
                return self.left.lookup(data, self)
            elif data > self.data:
                if self.right is None:
                    return None, None
                return self.right.lookup(data, self)
            else:
                return self, parent

Let’s look up the node containing 6.

    node, parent = root.lookup(6)

This is what happens when lookup() is called:

- 1- lookup() is called with data = 6, and default value parent = None.
- 2- data = 6 is less than root’s data which is 8.
- 3- root’s left child lookup() method is called with data = 6, parent = current node. Notice how we call lookup() on a subtree.
- 4- data = 6 is greater than node’s data which is now 3.
- 5- node’s right child lookup() method is called with data = 6 and parent = current node.
- 6- node’s data is equal to 6 so we return it and its parent which is node 3.
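The same search can also be written as a loop, which makes the parent bookkeeping from the steps above explicit. This is a sketch, assuming Python 3; `lookup_iterative()` is a hypothetical standalone function, not the library's method:

```python
class Node:
    """Trimmed copy of the article's Node, enough to build the tree."""

    def __init__(self, data):
        self.left = None
        self.right = None
        self.data = data

    def insert(self, data):
        if data < self.data:
            if self.left is None:
                self.left = Node(data)
            else:
                self.left.insert(data)
        elif data > self.data:
            if self.right is None:
                self.right = Node(data)
            else:
                self.right.insert(data)


def lookup_iterative(root, data):
    """Return (node, parent) for data, or (None, None) if it is not in the tree."""
    parent = None
    node = root
    while node is not None:
        if data < node.data:
            parent, node = node, node.left
        elif data > node.data:
            parent, node = node, node.right
        else:
            return node, parent
    return None, None


root = Node(8)
for value in (3, 10, 1, 6, 4, 7, 14, 13):
    root.insert(value)

node, parent = lookup_iterative(root, 6)
print(node.data, parent.data)     # 6 3
print(lookup_iterative(root, 5))  # (None, None)
```

Looking up the root itself returns the root with None as its parent, which is the situation the delete method has to treat as a special case.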

## Delete method

The method delete() takes the data of the node to remove as an argument.

    class Node:
        ...
        def delete(self, data):
            """
            Delete node containing data

            @param data node's content to delete
            """
            # get node containing data
            node, parent = self.lookup(data)
            if node is not None:
                children_count = node.children_count()
                ...

There are 3 possibilities to handle:

- 1- The node to remove has no child.
- 2- The node to remove has 1 child.
- 3- The node to remove has 2 children.

Let’s tackle the first possibility which is the easiest. We look for the node to remove and we set its parent’s left or right child to None. If it is the root node (no parent), we clear its data.

    def delete(self, data):
        ...
        if children_count == 0:
            # if node has no children, just remove it
            if parent:
                if parent.left is node:
                    parent.left = None
                else:
                    parent.right = None
                del node
            else:
                self.data = None
        ...

Note: children_count() returns the number of children of a node.

Here is the function children_count:

    class Node:
        ...
        def children_count(self):
            """
            Returns the number of children

            @returns number of children: 0, 1, 2
            """
            cnt = 0
            if self.left:
                cnt += 1
            if self.right:
                cnt += 1
            return cnt

For example, we want to remove node 1. Node 3’s left child will be set to None.

    root.delete(1)

Let’s look at the second possibility: the node to remove has 1 child. We replace the node with its child. We also handle the special case when the node is the root of the tree.

    def delete(self, data):
        ...
        elif children_count == 1:
            # if node has 1 child
            # replace node with its child
            if node.left:
                n = node.left
            else:
                n = node.right
            if parent:
                if parent.left is node:
                    parent.left = n
                else:
                    parent.right = n
                del node
            else:
                self.left = n.left
                self.right = n.right
                self.data = n.data
        ...

For example, we want to remove node 14. Node 14 has a parent (node 10), so node 10’s right child is set to node 13, which was node 14’s left child.

    root.delete(14)

Let’s look at the last possibility: the node to remove has 2 children. We replace its data with its successor’s data and we fix the successor’s parent’s child.

    def delete(self, data):
        ...
        else:
            # if node has 2 children
            # find its successor
            parent = node
            successor = node.right
            while successor.left:
                parent = successor
                successor = successor.left
            # replace node data by its successor data
            node.data = successor.data
            # fix successor's parent's child
            # (identity check, consistent with the other cases)
            if parent.left is successor:
                parent.left = successor.right
            else:
                parent.right = successor.right

For example, we want to remove node 3 from the complete tree, where it has 2 children (1 and 6). We look for its successor by going right once, then left until we reach a node with no left child. Its successor is node 4. We replace 3 with 4. Node 4 doesn’t have a child so we set node 6’s left child to None.

    root.delete(3)
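To see the three cases working together, here is a condensed sketch that assembles the fragments above into one runnable class. It assumes Python 3 and takes a few liberties: the `children` list replaces children_count(), and the `inorder()` and `build()` helpers are hypothetical additions used only to show the result. Each deletion is run on a fresh copy of the complete tree.

```python
class Node:
    """Tree node mirroring the article's class (condensed Python 3 sketch)."""

    def __init__(self, data):
        self.left = None
        self.right = None
        self.data = data

    def insert(self, data):
        if data < self.data:
            if self.left is None:
                self.left = Node(data)
            else:
                self.left.insert(data)
        elif data > self.data:
            if self.right is None:
                self.right = Node(data)
            else:
                self.right.insert(data)

    def lookup(self, data, parent=None):
        if data < self.data:
            return self.left.lookup(data, self) if self.left else (None, None)
        if data > self.data:
            return self.right.lookup(data, self) if self.right else (None, None)
        return self, parent

    def delete(self, data):
        node, parent = self.lookup(data)
        if node is None:
            return  # data is not in the tree
        children = [c for c in (node.left, node.right) if c is not None]
        if not children:
            # case 1: no child, detach the node (or clear the root's data)
            if parent is None:
                self.data = None
            elif parent.left is node:
                parent.left = None
            else:
                parent.right = None
        elif len(children) == 1:
            # case 2: one child, the child takes the node's place
            child = children[0]
            if parent is None:
                self.left, self.right, self.data = child.left, child.right, child.data
            elif parent.left is node:
                parent.left = child
            else:
                parent.right = child
        else:
            # case 3: two children, copy the successor's data and unlink it
            parent, successor = node, node.right
            while successor.left is not None:
                parent, successor = successor, successor.left
            node.data = successor.data
            if parent.left is successor:
                parent.left = successor.right
            else:
                parent.right = successor.right

    def inorder(self):
        """Hypothetical helper: return the tree content as a sorted list."""
        left = self.left.inorder() if self.left else []
        right = self.right.inorder() if self.right else []
        return left + [self.data] + right


def build():
    """Build the complete example tree used throughout the article."""
    root = Node(8)
    for value in (3, 10, 1, 6, 4, 7, 14, 13):
        root.insert(value)
    return root


for value in (1, 14, 3):
    tree = build()
    tree.delete(value)
    print(value, tree.inorder())
```

In each run, the deleted value disappears from the list while the remaining values stay in sorted order, which is a quick sanity check that the binary search tree properties still hold.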

## Print method

We add a method to print the tree inorder. This method takes no argument. We use recursion inside print_tree() to walk the tree depth-first. We first traverse the left subtree, then we print the root node, and then we traverse the right subtree.

    class Node:
        ...
        def print_tree(self):
            """
            Print tree content inorder
            """
            if self.left:
                self.left.print_tree()
            # Python 3 print(); end=' ' keeps all the values on one line
            print(self.data, end=' ')
            if self.right:
                self.right.print_tree()

Let’s print our tree:

    root.print_tree()

For the complete tree built in the insert section (before any deletion), the output will be: 1 3 4 6 7 8 10 13 14
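When testing, it can be handy to collect the values in a list instead of printing them. This is a minimal sketch, assuming Python 3; the standalone `inorder()` helper is hypothetical and not part of the library:

```python
class Node:
    """Trimmed copy of the article's Node, enough to build the tree."""

    def __init__(self, data):
        self.left = None
        self.right = None
        self.data = data

    def insert(self, data):
        if data < self.data:
            if self.left is None:
                self.left = Node(data)
            else:
                self.left.insert(data)
        elif data > self.data:
            if self.right is None:
                self.right = Node(data)
            else:
                self.right.insert(data)


def inorder(node):
    """Return the values of the subtree: left subtree, node, right subtree."""
    if node is None:
        return []
    return inorder(node.left) + [node.data] + inorder(node.right)


values = (3, 10, 1, 6, 4, 7, 14, 13)
root = Node(8)
for value in values:
    root.insert(value)

print(inorder(root))                                # [1, 3, 4, 6, 7, 8, 10, 13, 14]
print(inorder(root) == sorted((8,) + values))       # True
```

The second line of output illustrates a useful property: an inorder walk of a binary search tree visits the keys in ascending order.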

## Comparing 2 trees

To compare 2 trees, we add a method which compares each subtree recursively. It returns False as soon as a node differs between the 2 trees, either because the node is missing in one of them or because its data is different. We need to pass the root of the tree to compare to as an argument.

    class Node:
        ...
        def compare_trees(self, node):
            """
            Compare 2 trees

            @param node tree's root node to compare to
            @returns True if the tree passed is identical to this tree
            """
            if node is None:
                return False
            if self.data != node.data:
                return False
            res = True
            if self.left is None:
                if node.left:
                    return False
            else:
                res = self.left.compare_trees(node.left)
            if res is False:
                return False
            if self.right is None:
                if node.right:
                    return False
            else:
                res = self.right.compare_trees(node.right)
            return res

For example, we want to compare the tree (3, 8, 10) with the tree (3, 8, 11), both with 8 as the root.

    # root2 is the root of tree 2
    root.compare_trees(root2)

This is what happens in the code when we call compare_trees().

- 1- The root node’s compare_trees() method is called with the root node of the tree to compare to.
- 2- The root node has a left child so we call the left child’s compare_trees() method.
- 3- The left subtree comparison returns True.
- 4- The root node has a right child so we call the right child’s compare_trees() method.
- 5- The right subtree comparison returns False because the data is different (10 versus 11).
- 6- compare_trees() returns False.
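The steps above can also be condensed into a short standalone function. This is a sketch, assuming Python 3; `same_tree()` and this `Node` constructor with optional `left` and `right` arguments are hypothetical, and unlike compare_trees() the function also accepts two empty trees:

```python
class Node:
    """Minimal stand-in for a tree node (hypothetical, for this sketch only)."""

    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


def same_tree(a, b):
    """Return True if both trees have the same shape and the same data."""
    if a is None and b is None:
        return True   # both subtrees are empty
    if a is None or b is None:
        return False  # a node is missing in one of the trees
    return (a.data == b.data
            and same_tree(a.left, b.left)
            and same_tree(a.right, b.right))


tree1 = Node(8, Node(3), Node(10))
tree2 = Node(8, Node(3), Node(11))
tree3 = Node(8, Node(3), Node(10))

print(same_tree(tree1, tree2))  # False
print(same_tree(tree1, tree3))  # True
```

Because `and` short-circuits, the comparison stops at the first difference, just like the early `return False` statements in compare_trees().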

## Generator returning the tree elements one by one

It is sometimes useful to create a generator which returns the tree node values one by one. It is memory efficient as it doesn’t have to build the full list of nodes right away. Each time the generator is asked for a value, it returns the next node value.

To do that, we use the yield keyword which returns an object and pauses the function right there, so it continues from that point the next time a value is requested.
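To illustrate how yield pauses and resumes a function, here is a tiny generator that has nothing to do with trees. It is a sketch, assuming Python 3; `count_up_to()` is a hypothetical example function:

```python
def count_up_to(limit):
    """Yield 1, 2, ..., limit, pausing after each value."""
    value = 1
    while value <= limit:
        yield value  # execution pauses here until the next value is requested
        value += 1


gen = count_up_to(3)
print(next(gen))  # 1
print(next(gen))  # 2
print(list(gen))  # [3]
```

Calling count_up_to(3) does not run the body at all; it only creates the generator object. Each next() call, or each iteration of a for loop, runs the body up to the next yield.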

A plain recursive call would not pass the yielded values up to the caller, so we use a stack instead.

Here is the code:

    class Node:
        ...
        def tree_data(self):
            """
            Generator to get the tree nodes data
            """
            # we use a stack to traverse the tree in a non-recursive way
            stack = []
            node = self
            while stack or node:
                if node:
                    stack.append(node)
                    node = node.left
                else:
                    # we are returning so we pop the node and we yield it
                    node = stack.pop()
                    yield node.data
                    node = node.right

For example, we want to access the tree nodes using a for loop:

    for data in root.tree_data():
        print(data)
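As a side note, Python 3.3 and later offer `yield from`, which delegates to another generator and makes a recursive version straightforward. The sketch below assumes Python 3.3+; the method name `tree_data_recursive()` is hypothetical and is not part of the library:

```python
class Node:
    """Trimmed copy of the article's Node with a recursive generator."""

    def __init__(self, data):
        self.left = None
        self.right = None
        self.data = data

    def insert(self, data):
        if data < self.data:
            if self.left is None:
                self.left = Node(data)
            else:
                self.left.insert(data)
        elif data > self.data:
            if self.right is None:
                self.right = Node(data)
            else:
                self.right.insert(data)

    def tree_data_recursive(self):
        """Yield the tree data inorder, delegating to the subtrees."""
        if self.left is not None:
            yield from self.left.tree_data_recursive()
        yield self.data
        if self.right is not None:
            yield from self.right.tree_data_recursive()


root = Node(8)
for value in (3, 10, 1, 6, 4, 7, 14, 13):
    root.insert(value)

for data in root.tree_data_recursive():
    print(data)
```

The trade-off is that every value travels through one generator per tree level, whereas the stack-based version keeps a single generator and a single list of pending nodes.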

Let’s look at what happens in the code with the same example we have been using:

- 1- The root node tree_data() method is called.
- 2- Node 8 is added to the stack. We go to the left child of 8.
- 3- Node 3 is added to the stack. We go to the left child of 3.
- 4- Node 1 is added to the stack. Node is set to None because there is no left child.
- 5- We pop a node which is Node 1. We yield it (returns 1 and stops here until the next value is requested).
- 6- The for loop requests the next value, so tree_data() resumes right after the yield.
- 7- Node is set to None because Node 1 doesn’t have a right child.
- 8- We pop a node which is Node 3. We yield it (returns 3 and stops here until the next value is requested).
- …
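The steps above can be observed directly by recording every push and pop. This is a sketch, assuming Python 3; `traced_tree_data()` is a hypothetical standalone copy of the stack-based generator that appends its actions to a `trace` list:

```python
class Node:
    """Trimmed copy of the article's Node, enough to build the tree."""

    def __init__(self, data):
        self.left = None
        self.right = None
        self.data = data

    def insert(self, data):
        if data < self.data:
            if self.left is None:
                self.left = Node(data)
            else:
                self.left.insert(data)
        elif data > self.data:
            if self.right is None:
                self.right = Node(data)
            else:
                self.right.insert(data)


def traced_tree_data(root, trace):
    """Stack-based inorder generator that logs each push and pop in trace."""
    stack = []
    node = root
    while stack or node:
        if node:
            stack.append(node)
            trace.append("push %s" % node.data)
            node = node.left
        else:
            node = stack.pop()
            trace.append("pop %s" % node.data)
            yield node.data
            node = node.right


root = Node(8)
for value in (3, 10, 1, 6, 4, 7, 14, 13):
    root.insert(value)

trace = []
gen = traced_tree_data(root, trace)
print(next(gen), trace)  # 1 ['push 8', 'push 3', 'push 1', 'pop 1']
print(next(gen), trace)  # 3 ['push 8', 'push 3', 'push 1', 'pop 1', 'pop 3']
```

Nothing is pushed between the first and the second value: the generator resumes after the yield, finds that Node 1 has no right child, and pops Node 3 straight away.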

Here you go, I hope you enjoyed this tutorial. Don’t hesitate to add comments if you have any feedback.