Count of Smaller Number before itself

Question

Given an integer array (indexed from 0 to n-1, where n is the size of the array, with values from 0 to 10000), count, for each element A[i], the number of elements before A[i] that are smaller than it, and return the array of these counts.

Example: for the array [1,2,7,8,5], return [0,1,2,3,2].

Thoughts

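Use a segment tree indexed by value, where each leaf counts how many times that value has been seen so far. Walk the array from left to right: insert each number into the tree and immediately query the total count over [0, num - 1]. That count is the number of earlier elements smaller than the current one, which is the answer for this position.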

Solution

class Solution:
    """Solution for "Count of Smaller Number before itself"."""

    def countOfSmallerNumberII(self, A):
        """For each A[i], count the elements A[j] with j < i and A[j] < A[i].

        A segment tree keyed by value is filled while scanning A from left
        to right; before inserting each element, the number of already
        inserted values strictly smaller than it is queried.

        @param A: A list of integers (the problem bounds them to 0..10000,
                  but any integers work because the tree spans
                  min(A)..max(A)).
        @return: A list of counts, one per element of A, in the same order.
        """
        if not A:
            return []

        # Size the tree to the values actually present instead of a fixed
        # 0..10000 range: fewer nodes to build, and values outside that
        # range are still counted correctly.
        low, high = min(A), max(A)
        seg_tree = SegmentTree()
        root = seg_tree.build(low, high)

        result = []
        for num in A:
            # Nothing can be smaller than the minimum value, and the query
            # range [low, num - 1] would be empty in that case anyway.
            smaller = seg_tree.query(root, low, num - 1) if num > low else 0
            result.append(smaller)
            seg_tree.modify(root, num, 1)
        return result


# Reference copy of the node class that SegmentTree below instantiates
# (fields: start, end, count, left, right). It lives inside a string
# literal, so this file does NOT define SegmentTreeNode itself.
# NOTE(review): presumably the judging environment supplies this class;
# confirm before running this file standalone.
"""
class SegmentTreeNode:
    def __init__(self, start, end, count):
        self.start = start
        self.end = end
        self.count = count
        self.left, self.right = None, None
"""

class SegmentTree:
    """Segment tree of counts over a contiguous integer range.

    Nodes are SegmentTreeNode instances (start, end, count, left, right).
    Each leaf holds the count for a single index; each internal node holds
    the sum of its two children's counts.
    """

    def __init__(self):
        # Not used by the methods below (build() returns the root instead);
        # kept so existing callers that read .root still work.
        self.root = None

    def build(self, start, end):
        """Build a tree covering the inclusive range [start, end].

        All counts start at 0.

        @return: The root node, or None when start > end.
        """
        if start > end:
            return None
        root = SegmentTreeNode(start, end, 0)
        if start < end:
            # Floor division: `/` yields a float in Python 3, and the
            # recursion would then never reach a single-index leaf.
            mid = (start + end) // 2
            root.left = self.build(start, mid)
            root.right = self.build(mid + 1, end)
        return root

    def query(self, root, start, end):
        """Return the sum of counts for indices in [start, end] (inclusive).

        Parts of the range outside the tree contribute 0, and an empty
        range (start > end) returns 0.
        """
        if root is None or start > end:
            return 0
        if root.start > end or root.end < start:
            # This node's segment does not overlap the query range.
            return 0
        if start <= root.start and root.end <= end:
            # This node's whole segment lies inside the query range.
            return root.count
        # Partial overlap (only possible on internal nodes). Each child
        # clips the range itself through the overlap checks above.
        return (self.query(root.left, start, end)
                + self.query(root.right, start, end))

    def modify(self, root, index, value):
        """Add `value` to the count at `index` and refresh ancestor sums.

        Indices outside the tree's range are ignored.
        """
        if root is None or not (root.start <= index <= root.end):
            return
        if root.start == root.end:
            root.count += value
            return
        mid = (root.start + root.end) // 2
        if index <= mid:
            self.modify(root.left, index, value)
        else:
            self.modify(root.right, index, value)
        root.count = root.left.count + root.right.count