Python實現 《算法導論 第三版》中的算法 第4章 分治策略

第4章 分治策略

1. 最大子數組問題

P40。認真讀一下4.1節,有一定算法基礎就可以看懂。這個問題的全稱是最大連續子數組問題。那麼,最大子數組究竟是什麼呢?

要想弄清楚最大(連續)子數組是什麼,首先需要明白(連續)子數組。(連續)子數組是數組中連續的幾個元素組成的數組。此時會有兩種不同理解,子數組可以爲空和子數組不能爲空。

  1. 子數組可以爲空,意思是把空數組也看成子數組。子數組如果爲空,認爲它的和就是0。所以在這種理解下,該問題的結果不可能爲負,最小值爲0。對於元素全負的數組,最大子數組是空數組,其和爲0。一般以該種理解爲主,因爲在集合論中,一個集合的子集可以爲空。
  2. 子數組不能爲空,意思是不把空數組看成子數組。那麼,子數組至少含有一個元素。在這種理解下,對於元素全負的數組,該問題的結果爲負。

對於既有正數也有負數或只有正數的數組,兩種理解的結果相同。只有對於元素全負的數組,這兩種理解的結果纔不同。

書中的方法採取的是第2種理解方法,它對於元素全負的數組會輸出一個負數結果。

def find_maximum_subarray(A, low, high):
    """Return ``(start, end, total)`` of a maximum non-empty subarray of A[low..high].

    Divide-and-conquer algorithm (CLRS section 4.1). ``low`` and ``high`` are
    inclusive indices. Subarrays must be non-empty, so an all-negative range
    yields its largest (negative) element.

    Ties are broken in favour of the left half, then the right half, then the
    subarray crossing the midpoint.

    Raises:
        ValueError: if ``low > high`` (the range is empty and has no
            non-empty subarray).
    """
    if low > high:
        # Without this guard mid would never shrink the range and the
        # recursion would run until RecursionError.
        raise ValueError("find_maximum_subarray requires low <= high")
    if low == high:
        return low, high, A[low]

    mid = (low + high) // 2
    left = find_maximum_subarray(A, low, mid)
    right = find_maximum_subarray(A, mid + 1, high)
    cross = find_max_crossing_subarray(A, low, mid, high)

    # Each candidate is (start, end, total); compare on the total.
    if left[2] >= right[2] and left[2] >= cross[2]:
        return left
    if right[2] >= left[2] and right[2] >= cross[2]:
        return right
    return cross
        
        
def find_max_crossing_subarray(A, low, mid, high):
    """Return ``(start, end, total)`` of the best subarray that spans the midpoint.

    The subarray always contains both A[mid] and A[mid + 1]. It extends left
    from ``mid`` as far as maximises the left sum, and right from
    ``mid + 1`` as far as maximises the right sum. On ties, the index
    closest to the midpoint is kept.
    """
    import math

    # Grow leftwards from mid down to low.
    best_left, start = -math.inf, 0
    running = 0
    for idx in reversed(range(low, mid + 1)):
        running += A[idx]
        if running > best_left:
            best_left, start = running, idx

    # Grow rightwards from mid + 1 up to high.
    best_right, end = -math.inf, 0
    running = 0
    for idx in range(mid + 1, high + 1):
        running += A[idx]
        if running > best_right:
            best_right, end = running, idx

    return start, end, best_left + best_right


def main():
    """Print the divide-and-conquer maximum subarray for two sample arrays.

    ``A`` mixes positive and negative values. ``B`` is all-negative, so the
    result is its largest single element (subarrays must be non-empty here).
    """
    A = [13, -3, -25, 20, -3, -16, -23, 18, 20, -7, 12, -5, -22, 15, -4, 7]
    print(find_maximum_subarray(A, 0, len(A)-1))
    # Every element needs its own comma: "-6 -5" would evaluate to -11.
    B = [-6, -5, -2, -1, -3]
    print(find_maximum_subarray(B, 0, len(B)-1))


if __name__ == '__main__':
    main()

2. 最大子數組問題的其他解法

P42練習4.1-2。題目要求暴力求解,運行時間爲 $\Theta(n^2)$。下面代碼中第一個函數符合題目要求;第二個函數的時間複雜度是線性的 $\Theta(n)$。

def find_maximum_subarray_n2(nums):
    """Return the largest sum of a non-empty contiguous subarray, in Theta(n^2).

    Brute force: for every start index, extend the end one element at a time
    while keeping a running sum. An empty input yields ``-math.inf``.
    """
    import math

    best = -math.inf
    size = len(nums)
    for start in range(size):
        running = 0
        for stop in range(start, size):
            running += nums[stop]
            if running > best:
                best = running
    return best


def find_maximum_subarray_n(nums):
    """Return the largest sum of a non-empty contiguous subarray, in Theta(n).

    Single pass (Kadane's algorithm). The best sum is recorded before a
    negative running prefix is discarded, so an all-negative input still
    yields its largest element. An empty input yields ``-math.inf``.
    """
    import math

    best = -math.inf
    running = 0
    for value in nums:
        running += value
        best = running if running > best else best
        # A negative prefix can only hurt any subarray that continues past it.
        running = max(running, 0)
    return best


def main():
    """Print the Theta(n^2) and Theta(n) maximum-subarray sums for two arrays.

    ``B`` is all-negative, so both functions return its largest element.
    """
    A = [13, -3, -25, 20, -3, -16, -23, 18, 20, -7, 12, -5, -22, 15, -4, 7]
    print(find_maximum_subarray_n2(A))
    print(find_maximum_subarray_n(A))
    # Every element needs its own comma: "-6 -5" would evaluate to -11.
    B = [-6, -5, -2, -1, -3]
    print(find_maximum_subarray_n2(B))
    print(find_maximum_subarray_n(B))


if __name__ == '__main__':
    main()

P42練習4.1-4和4.1-5。此時採用第1種理解方法,允許子數組爲空。給出三種方法,時間複雜度分別是 $\Theta(n^3)$、$\Theta(n^2)$ 和 $\Theta(n)$。

class MaximumSubarray:
    """Maximum contiguous-subarray sum where the empty subarray is allowed.

    The empty subarray counts with sum 0, so every method returns a value
    >= 0. An all-negative (or empty) input therefore yields 0.
    """

    def find_maximum_n3(self, A):
        """Brute force in Theta(n^3): explicitly sum every slice A[i:j]."""
        maximum = 0
        n = len(A)
        for i in range(n):
            # j is an exclusive slice end, so it must reach n for the slice
            # to include the last element A[n - 1].
            for j in range(i + 1, n + 1):
                temp = sum(A[i:j])
                if temp > maximum:
                    maximum = temp
        return maximum

    def find_maximum_n2(self, A):
        """Theta(n^2): for each start index, extend the end with a running sum."""
        maximum = 0
        for i in range(0, len(A)):
            temp = 0
            for j in range(i, len(A)):
                temp += A[j]
                if temp > maximum:
                    maximum = temp
        return maximum

    def find_maximum_n(self, A):
        """Theta(n) single pass (Kadane), dropping any negative running prefix."""
        maximum = 0
        temp = 0
        for i in range(0, len(A)):
            temp += A[i]
            if temp > maximum:
                maximum = temp
            elif temp < 0:
                # maximum >= 0, so a negative prefix never beats the empty one.
                temp = 0
        return maximum
            
    
def main():
    """Print the three MaximumSubarray results for two sample arrays.

    ``B`` is all-negative, so with empty subarrays allowed every method prints 0.
    """
    A = [13, -3, -25, 20, -3, -16, -23, 18, 20, -7, 12, -5, -22, 15, -4, 7]
    m = MaximumSubarray()
    print(m.find_maximum_n3(A))
    print(m.find_maximum_n2(A))
    print(m.find_maximum_n(A))
    # Every element needs its own comma: "-6 -5" would evaluate to -11.
    B = [-6, -5, -2, -1, -3]
    print(m.find_maximum_n3(B))
    print(m.find_maximum_n2(B))
    print(m.find_maximum_n(B))


if __name__ == '__main__':
    main()

3. Strassen算法

矩陣相乘(矩陣爲方陣,且行數列數爲2的冪次)的第一個暴力算法和第二個分治算法花了好長時間才寫完,Strassen算法還要寫個一兩天,真是太菜了…

# -*- coding: utf-8 -*-
def square_matrix_multiply(A, B):
    """Return the product of two n x n matrices (lists of rows), in Theta(n^3).

    Naive triple loop. Each output entry C[i][j] is accumulated from 0 by
    adding A[i][k] * B[k][j] for k = 0 .. n-1 in increasing order.
    """
    n = len(A)
    product = [[0] * n for _ in range(n)]
    for i, a_row in enumerate(A):
        out_row = product[i]
        # i-k-j order: the k-th term is added to every out_row[j] before term k+1.
        for k in range(n):
            a_ik = a_row[k]
            b_row = B[k]
            for j in range(n):
                out_row[j] += a_ik * b_row[j]
    return product


def square_matrix_multiply_recursive(A, B):
    """Multiply two n x n matrices by simple divide and conquer.

    Each matrix is split into four n/2 x n/2 quadrants. Each output quadrant
    is the sum of two recursive quadrant products (eight products in total).

    The quadrant split needs n to be a power of two. Any other size is
    zero-padded up to the next power of two and the result is cropped back
    to n x n. An empty matrix yields ``[]``.
    """
    n = len(A)
    if n == 0:
        return []
    if n == 1:
        return [[A[0][0] * B[0][0]]]

    size = 1
    while size < n:
        size *= 2
    if size != n:
        # Padding with zero rows/columns leaves the top-left n x n block of
        # the product unchanged, because the extra terms are all 0.
        pad = size - n
        A_pad = [list(row) + [0] * pad for row in A] + [[0] * size for _ in range(pad)]
        B_pad = [list(row) + [0] * pad for row in B] + [[0] * size for _ in range(pad)]
        C_pad = square_matrix_multiply_recursive(A_pad, B_pad)
        return [row[:n] for row in C_pad[:n]]

    mid = n // 2

    def quadrants(M):
        # Split M into (top-left, top-right, bottom-left, bottom-right).
        top, bottom = M[:mid], M[mid:]
        return ([r[:mid] for r in top], [r[mid:] for r in top],
                [r[:mid] for r in bottom], [r[mid:] for r in bottom])

    def add(X, Y):
        # Element-wise sum of two equally sized matrices.
        return [[x + y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]

    A11, A12, A21, A22 = quadrants(A)
    B11, B12, B21, B22 = quadrants(B)

    rec = square_matrix_multiply_recursive
    C11 = add(rec(A11, B11), rec(A12, B21))
    C12 = add(rec(A11, B12), rec(A12, B22))
    C21 = add(rec(A21, B11), rec(A22, B21))
    C22 = add(rec(A21, B12), rec(A22, B22))

    # Stitch quadrants back together: top rows, then bottom rows.
    return [l + r for l, r in zip(C11, C12)] + [l + r for l, r in zip(C21, C22)]
        

def main():
    """Print naive and recursive products for a 2x2 and a 4x4 example."""
    examples = (
        ([[1, 2], [2, 3]],
         [[1, 1], [2, 2]]),
        ([[1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]],
         [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]]),
    )
    for left, right in examples:
        print(square_matrix_multiply(left, right))
        print(square_matrix_multiply_recursive(left, right))


if __name__ == '__main__':
    main()

Strassen算法終於寫完了,好像也沒那麼難哈哈哈

# -*- coding: utf-8 -*-
def square_matrix_multiply_strassen(A, B):
    """Multiply two n x n matrices with Strassen's algorithm.

    After splitting both matrices into quadrants, ten sums/differences
    S1..S10 are formed. These feed seven recursive products P1..P7, from
    which the four output quadrants are assembled with additions only.

    The quadrant split needs n to be a power of two. Any other size is
    zero-padded up to the next power of two and the result is cropped back
    to n x n. An empty matrix yields ``[]``.
    """
    n = len(A)
    if n == 0:
        return []
    if n == 1:
        return [[A[0][0] * B[0][0]]]

    size = 1
    while size < n:
        size *= 2
    if size != n:
        # Padding with zero rows/columns leaves the top-left n x n block of
        # the product unchanged, because the extra terms are all 0.
        pad = size - n
        A_pad = [list(row) + [0] * pad for row in A] + [[0] * size for _ in range(pad)]
        B_pad = [list(row) + [0] * pad for row in B] + [[0] * size for _ in range(pad)]
        C_pad = square_matrix_multiply_strassen(A_pad, B_pad)
        return [row[:n] for row in C_pad[:n]]

    mid = n // 2

    def quadrants(M):
        # Split M into (top-left, top-right, bottom-left, bottom-right).
        top, bottom = M[:mid], M[mid:]
        return ([r[:mid] for r in top], [r[mid:] for r in top],
                [r[:mid] for r in bottom], [r[mid:] for r in bottom])

    def add(X, Y):
        # Element-wise X + Y.
        return [[x + y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]

    def sub(X, Y):
        # Element-wise X - Y.
        return [[x - y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]

    A11, A12, A21, A22 = quadrants(A)
    B11, B12, B21, B22 = quadrants(B)

    # S1..S10
    S1 = sub(B12, B22)
    S2 = add(A11, A12)
    S3 = add(A21, A22)
    S4 = sub(B21, B11)
    S5 = add(A11, A22)
    S6 = add(B11, B22)
    S7 = sub(A12, A22)
    S8 = add(B21, B22)
    S9 = sub(A11, A21)
    S10 = add(B11, B12)

    # P1..P7: the only recursive multiplications (seven instead of eight).
    strassen = square_matrix_multiply_strassen
    P1 = strassen(A11, S1)
    P2 = strassen(S2, B22)
    P3 = strassen(S3, B11)
    P4 = strassen(A22, S4)
    P5 = strassen(S5, S6)
    P6 = strassen(S7, S8)
    P7 = strassen(S9, S10)

    C11 = add(sub(add(P5, P4), P2), P6)   # P5 + P4 - P2 + P6
    C12 = add(P1, P2)                     # P1 + P2
    C21 = add(P3, P4)                     # P3 + P4
    C22 = sub(sub(add(P5, P1), P3), P7)   # P5 + P1 - P3 - P7

    # Stitch quadrants back together: top rows, then bottom rows.
    return [l + r for l, r in zip(C11, C12)] + [l + r for l, r in zip(C21, C22)]
        

def main():
    """Print Strassen products for a 2x2 and a 4x4 example."""
    examples = (
        ([[1, 2], [2, 3]],
         [[1, 1], [2, 2]]),
        ([[1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]],
         [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]]),
    )
    for left, right in examples:
        print(square_matrix_multiply_strassen(left, right))


if __name__ == '__main__':
    main()
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章