Interval Sum

Question

Thoughts

Solution

segmentTree

Correct, but it gets "Memory Limit Exceeded" on LintCode.

"""
Definition of Interval.
class Interval(object):
    def __init__(self, start, end):
        self.start = start
        self.end = end
"""


class Solution:
    """Answer inclusive range-sum queries over a fixed integer array."""

    def intervalSum(self, A, queries):
        """Return the sum of A[start..end] (both ends inclusive) per query.

        @param A: Given an integer array
        @param queries: Given an Interval list; the ith query is
                        [queries[i-1].start, queries[i-1].end]
        @return: The result list, one sum per query, in query order
        """
        segT = SegmentTree()
        # The valid indices are 0 .. len(A) - 1; building up to len(A) would
        # allocate a leaf for an index that does not exist. For an empty
        # array build() returns None, and query() answers 0 for a None root.
        root = segT.build(0, len(A) - 1)
        for idx, num in enumerate(A):
            segT.modify(root, idx, num)

        # Every query is answered unconditionally: an Interval cannot be
        # meaningfully ordered against an integer, so no such guard is used.
        return [segT.query(root, q.start, q.end) for q in queries]
"""
class SegmentTreeNode:
    def __init__(self, start, end, count):
        self.start = start
        self.end = end
        self.count = count
        self.left, self.right = None, None
"""

class SegmentTree:

    def __init__(self):
        self.root = None

    def build(self, start, end):
        if start > end:
            return None
        root = SegmentTreeNode(start, end, 0)

        if start < end:
            mid = (start + end) / 2
            root.left = self.build(start, mid)
            root.right = self.build(mid+1, end)
        else:
            root.count = 0

        return root

    def query(self, root, start, end):
        if not root:
            return 0
        rStart = root.start
        rEnd = root.end

        if (rStart > end) or ( rEnd < start):
            return 0

        #should be >= and <=
        #if query range if larger than the segment tree range, just return the count
        if rStart >= start and rEnd <= end:
            return root.count

        leftCount = 0
        rightCount = 0

        mid = (rStart + rEnd) / 2
        if start <= mid:
            if end > mid:
                #some range are in right part
                leftCount = self.query(root.left, start, mid)
            else:
                leftCount = self.query(root.left, start, end)
        if mid < end:
            if start <= mid:
                #some range are in left part
                #right part should start with mid + 1
                rightCount = self.query(root.right, mid + 1, end)
            else:
                rightCount = self.query(root.right, start, end)
        return leftCount + rightCount

    def modify(self, root, index, value):
        if root.start == index and root.end == index:
            root.count += value
            return

        mid = (root.start + root.end) / 2
        if root.start <= index and index <= mid:
            self.modify(root.left, index, value)
        elif root.end >= index and index > mid:
            self.modify(root.right, index, value)

        root.count = root.left.count + root.right.count

Java code; this version does not exceed the memory limit.

public class Solution {
    /**
     * Node of a sum segment tree covering the inclusive index range
     * [start, end]; {@code sum} is the total of the array values in it.
     */
    class SegmentTreeNode {
        public int start, end;
        public Long sum;
        public SegmentTreeNode left, right;

        public SegmentTreeNode(int start, int end, Long sum) {
            this.start = start;
            this.end = end;
            this.sum = sum;
            this.left = this.right = null;
        }
    }

    /**
     * Builds the segment tree for A[start..end] (both ends inclusive).
     *
     * @param start first index covered by the returned node
     * @param end last index covered by the returned node
     * @param A the source array
     * @return the root of the subtree, or null when start > end
     */
    public SegmentTreeNode build(int start, int end, int[] A) {
        if (start > end) { // empty range: nothing to build
            return null;
        }

        SegmentTreeNode root = new SegmentTreeNode(start, end, 0L);

        if (start != end) {
            // Written this way so that start + end cannot overflow.
            int mid = start + (end - start) / 2;
            root.left = build(start, mid, A);
            root.right = build(mid + 1, end, A);

            root.sum = root.left.sum + root.right.sum;
        } else {
            // Leaf: holds the single array element.
            root.sum = Long.valueOf(A[start]);
        }
        return root;
    }

    /**
     * Returns the sum of the elements in [start, end] (both ends inclusive).
     * The requested range is clipped to the range covered by {@code root};
     * a null root or an empty intersection yields 0.
     *
     * @param root root of the (sub)tree to search, may be null
     * @param start first index of the query
     * @param end last index of the query
     * @return the sum over the intersection of the two ranges
     */
    public Long query(SegmentTreeNode root, int start, int end) {
        if (root == null) {
            return 0L;
        }

        // Clip the query to this node so the recursion never descends into
        // a child that does not overlap the request (leaves have no children).
        int lo = Math.max(start, root.start);
        int hi = Math.min(end, root.end);
        if (lo > hi) { // no intersection
            return 0L;
        }
        if (lo == root.start && hi == root.end) { // node fully covered
            return root.sum;
        }

        // Must match the midpoint used by build().
        int mid = root.start + (root.end - root.start) / 2;
        long total = 0L;
        // Left sub-interval: [root.start, mid]
        if (lo <= mid) {
            total += query(root.left, lo, Math.min(hi, mid));
        }
        // Right sub-interval: [mid + 1, root.end]
        if (mid < hi) {
            total += query(root.right, Math.max(lo, mid + 1), hi);
        }
        return total;
    }

    /**
     * Answers every range-sum query against the array.
     *
     * @param A Given an integer array
     * @param queries Given a query list; each query is an inclusive
     *                [start, end] index range
     * @return The result list, one sum per query, in query order
     */
    public ArrayList<Long> intervalSum(int[] A,
                                       ArrayList<Interval> queries) {
        ArrayList<Long> ans = new ArrayList<Long>();
        if (queries == null) {
            return ans;
        }
        // A null or empty array gives a null root, for which query() returns 0.
        SegmentTreeNode root = (A == null) ? null : build(0, A.length - 1, A);
        for (Interval interval : queries) {
            ans.add(query(root, interval.start, interval.end));
        }
        return ans;
    }
}