Range Sum Query 2d - Mutable

question

Thoughts

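Because the matrix is mutable, a precomputed 2D prefix-sum (DP) table does not work well: every update would force us to rescan the matrix and rebuild the table. A 2D Binary Indexed (Fenwick) Tree supports both point updates and region-sum queries in O(log m · log n) instead.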

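Pitfalls

  • In `read`, be careful about which bounds are inclusive and which are exclusive! `read(r, c)` sums the cells with row < r and col < c, so `sumRegion` must pass `row2 + 1` and `col2 + 1` for the far corner.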

Solution

class NumMatrix(object):
    """Mutable 2D matrix supporting point updates and rectangular sum queries.

    Backed by a 2D Binary Indexed (Fenwick) Tree, so ``update`` and
    ``sumRegion`` each cost O(log(rows) * log(cols)).
    """

    def __init__(self, matrix):
        """
        Build the structure from ``matrix``.

        :type matrix: List[List[int]]

        An empty (or ``None``) matrix yields a valid 0 x 0 instance rather
        than an object with missing attributes.
        """
        # Always define every attribute, even for an empty matrix, so later
        # calls fail (if at all) with a meaningful error, not AttributeError.
        self.row_num = len(matrix) if matrix else 0
        self.col_num = len(matrix[0]) if self.row_num else 0
        # Current cell values; update() needs them to compute the delta.
        self.matrix = [[0] * self.col_num for _ in range(self.row_num)]
        # 1-based Fenwick tree: row 0 and column 0 are padding and stay 0.
        self.tree = [[0] * (self.col_num + 1) for _ in range(self.row_num + 1)]

        # Every cell starts at 0, so each update() adds exactly matrix[i][j].
        for i in range(self.row_num):
            for j in range(self.col_num):
                self.update(i, j, matrix[i][j])

    def lastBit(self, x):
        """Return the lowest set bit of ``x`` (e.g. 12 -> 4, 0 -> 0)."""
        # For Python ints ~(x - 1) == -x, so this is identical to x & ~(x - 1).
        return x & -x

    def update(self, row, col, val):
        """
        Set the element at matrix[row][col] to ``val``.

        :type row: int  (0-based)
        :type col: int  (0-based)
        :type val: int
        :rtype: None
        :raises IndexError: if (row, col) lies outside the matrix.
        """
        # Negative indices would otherwise be accepted by Python list indexing
        # and then map to tree index 0, where lastBit(0) == 0 makes the loops
        # below spin forever.
        if not (0 <= row < self.row_num and 0 <= col < self.col_num):
            raise IndexError(
                "cell ({}, {}) is outside the {}x{} matrix".format(
                    row, col, self.row_num, self.col_num))

        diff = val - self.matrix[row][col]
        self.matrix[row][col] = val

        # Shift to the tree's 1-based indexing and propagate the delta to every
        # node whose range covers (row, col), in both dimensions.
        r = row + 1
        while r <= self.row_num:
            c = col + 1
            while c <= self.col_num:
                self.tree[r][c] += diff
                c += self.lastBit(c)
            r += self.lastBit(r)

    def sumRegion(self, row1, col1, row2, col2):
        """
        Return the sum of matrix[row1..row2][col1..col2], both corners inclusive.

        :type row1: int
        :type col1: int
        :type row2: int
        :type col2: int
        :rtype: int
        """
        # read() uses exclusive upper bounds, hence the +1 on the far corner;
        # the two overlapping strips are subtracted and their overlap re-added.
        return (self.read(row2 + 1, col2 + 1)
                - self.read(row1, col2 + 1)
                - self.read(row2 + 1, col1)
                + self.read(row1, col1))

    def read(self, row, col):
        """
        Return the sum of all cells matrix[i][j] with i < row and j < col.

        ``row`` and ``col`` are exclusive bounds (equivalently, 1-based tree
        indices), so read(0, c) and read(r, 0) are both 0.
        """
        result = 0
        r = row
        while r:
            c = col
            while c:
                result += self.tree[r][c]
                c -= self.lastBit(c)
            r -= self.lastBit(r)
        return result