Remove Node in Binary Search Tree

Problem

Given the root of a binary search tree in which every node has a unique value, remove the node with a given value. If no node in the tree has that value, do nothing. The tree must still be a binary search tree after the removal.
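The requirement that the tree remain a binary search tree can be checked mechanically. A minimal sketch, assuming a TreeNode like the one in the solution below: the hypothetical `is_bst` passes each node an open interval (low, high) that its value must fall in, and the hypothetical `insert` helper exists only to build test trees. The strict comparisons rely on the node values being unique.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val: int) -> None:
        self.val = val
        self.left: Optional["TreeNode"] = None
        self.right: Optional["TreeNode"] = None


def is_bst(node: Optional[TreeNode],
           low: float = float("-inf"),
           high: float = float("inf")) -> bool:
    """Return True if every value lies strictly inside its ancestors' bounds."""
    if node is None:
        return True
    if not (low < node.val < high):
        return False
    # Left subtree must stay below node.val, right subtree above it.
    return is_bst(node.left, low, node.val) and is_bst(node.right, node.val, high)


def insert(root: Optional[TreeNode], val: int) -> TreeNode:
    """Plain BST insertion (hypothetical helper), used only to build test trees."""
    if root is None:
        return TreeNode(val)
    if val < root.val:
        root.left = insert(root.left, val)
    else:
        root.right = insert(root.right, val)
    return root


root = None
for v in [5, 3, 8, 1, 4]:
    root = insert(root, v)
print(is_bst(root))  # True
root.left.val = 9    # 9 now sits in 5's left subtree, breaking the invariant
print(is_bst(root))  # False
```

Checking only each node against its direct children is not enough; the bounds must be inherited from all ancestors, which is why the interval is passed down.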

Approach

Deleting a node falls into three cases.

Because deletion changes the parent's link to the node, keep a reference to the parent while searching for the node.
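The parent-tracking search can be sketched iteratively. This is a minimal sketch, not the solution's `findNode` (which is recursive and starts from a dummy parent); the function name `find_with_parent` is hypothetical, and it reports the root's parent as None.

```python
from typing import Optional, Tuple


class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def find_with_parent(
    root: Optional[TreeNode], value: int
) -> Tuple[Optional[TreeNode], Optional[TreeNode]]:
    """Return (parent, node) for the node holding `value`.

    parent is None when the match is the root; node is None when `value`
    is absent (parent is then the last node visited).
    """
    parent, node = None, root
    while node is not None and node.val != value:
        parent = node
        # BST property: smaller values live in the left subtree
        node = node.left if value < node.val else node.right
    return parent, node


root = TreeNode(5, TreeNode(3, None, TreeNode(4)), TreeNode(8))
parent, node = find_with_parent(root, 4)
print(parent.val, node.val)  # 3 4
```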

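  1. Leaf node

Delete the node by setting the parent's reference to it to null (None in Python).

  2. Node with only one child

Delete the node by linking the parent's reference directly to the node's only child.

  3. Node with two children

    • find the minimum node of the right subtree (its leftmost node, i.e. the in-order successor)
    • copy that node's value into the node being deleted
    • delete the now-duplicate minimum node; this is case 1 or 2, because the minimum node cannot have a left child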
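The three cases above can be sketched as one function. This is a minimal sketch under stated assumptions: the hypothetical `delete_node` receives a node that has already been found together with its parent, a dummy parent sits above the root so the root can be deleted like any other node, and TreeNode accepts optional children to make trees easy to build. Cases 1 and 2 collapse into one branch, since a leaf's "only child" is simply None.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def replace_child(parent: TreeNode, old: TreeNode, new: Optional[TreeNode]) -> None:
    """Point whichever of parent's links referenced `old` at `new` instead."""
    if parent.left is old:
        parent.left = new
    else:
        parent.right = new


def delete_node(parent: TreeNode, node: TreeNode) -> None:
    """Remove `node`, whose parent is `parent` (hypothetical helper)."""
    if node.left is None or node.right is None:
        # Cases 1 and 2: zero or one child. For a leaf the child is None,
        # which sets the parent's reference to None.
        child = node.left if node.left is not None else node.right
        replace_child(parent, node, child)
        return
    # Case 3: two children. Walk to the leftmost node of the right subtree.
    min_parent, min_node = node, node.right
    while min_node.left is not None:
        min_parent, min_node = min_node, min_node.left
    # Copy its value up, then unlink it; it has no left child (case 1 or 2).
    node.val = min_node.val
    replace_child(min_parent, min_node, min_node.right)


def inorder(node: Optional[TreeNode]) -> list:
    if node is None:
        return []
    return inorder(node.left) + [node.val] + inorder(node.right)


#        5
#      /   \
#     3     8
#    / \   /
#   1   4 6
#          \
#           7
root = TreeNode(5, TreeNode(3, TreeNode(1), TreeNode(4)),
                TreeNode(8, TreeNode(6, None, TreeNode(7))))
dummy = TreeNode(0, root)  # dummy parent above the root
delete_node(dummy, root)   # delete 5, which has two children
print(inorder(dummy.left))  # [1, 3, 4, 6, 7, 8]
```

Deleting 5 copies 6 up from the right subtree and splices 6's right child 7 into 6's old position, so the in-order sequence stays sorted.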

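Helper Functions

findNode() returns the target node together with its parent. findRightMinimum() returns the minimum (leftmost) node of the right subtree together with that node's parent.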

Solution

"""
Definition of TreeNode:
class TreeNode:
    def __init__(self, val):
        self.val = val
        self.left, self.right = None, None
"""
class Solution:
    """
    @param root: The root of the binary search tree.
    @param value: Remove the node with given value.
    @return: The root of the binary search tree after removal.
    """    
    def removeNode(self, root, value):
        # write your code here
        if not root: 
            return None
        dummy = TreeNode(0)
        dummy.left = root
        parent, target = self.findNode(root, value, dummy)
        #couldn't find target
        if not target:
            return root
        """
        找到了target, 分三种情况
        """
        if not target.left and not target.right:
            #leaf node
            if parent.left == target:
                parent.left = None
            else:    
                parent.right = None
        else:        
            if not target.left and target.right:
                #only right child
                if parent.left == target:
                   parent.left = target.right 
                else:
                    parent.right = target.right
            elif not target.right and target.left:
                #only left child 
                if parent.left == target:
                   parent.left = target.left
                else:
                    parent.right = target.left
            else:    
                #two children
                minParent, minNode = self.findRightMinimum(target.right)
                if minParent.left == minNode:
                    minParent.left = None
                else:
                    minParent.right = None
                target.val = minNode.val
        return dummy.left

    def findNode(self, root, value, parent):
        if not root: 
            return parent, None
        if root.val == value:
            return parent, root
        if root.val > value:
            return self.findNode(root.left, value, root)
        else:        
            return self.findNode(root.right, value, root)

    def findRightMinimum(self, root):
        """
        找到left most, root.left为null时停止
        """
        while root.left:  
            parent = root
            root = root.left
        return parent, root