第4章 分治策略
1. 最大子數組問題
P40。認真讀一下4.1節,有一定算法基礎就可以看懂。這個問題的全稱是最大連續子數組問題。那麼,最大子數組究竟是什麼呢?
要想弄清楚最大(連續)子數組是什麼,首先需要明白(連續)子數組。(連續)子數組是數組中連續的幾個元素組成的數組。此時會有兩種不同理解,子數組可以爲空和子數組不能爲空。
- 子數組可以爲空,意思是把空數組也看成子數組。子數組如果爲空,認爲它的和就是0。所以在這種理解下,該問題的結果不可能爲負, 最小值爲0。對於元素全負的數組,最大子數組爲0。一般以該種理解爲主,因爲在集合論中,一個集合的子集可以爲空。
- 子數組不能爲空,意思是不把空數組看成子數組。那麼,子數組至少含有一個元素。在這種理解下,對於元素全負的數組,該問題的結果爲負。
對於既有正數也有負數或只有正數的數組,兩種理解的結果相同。只有對於元素全負的數組,這兩種理解的結果纔不同。
書中的方法採取的是第2種理解方法,它對於元素全負的數組會輸出一個負數結果。
def find_maximum_subarray(A, low, high):
if low == high:
return low, high, A[low]
else:
mid = (low + high)// 2
left_low, left_high, left_sum = find_maximum_subarray(A, low, mid)
right_low, right_high, right_sum = find_maximum_subarray(A, mid+1, high)
cross_low, cross_high, cross_sum = find_max_crossing_subarray(A, low, mid, high)
if left_sum >= right_sum and left_sum >= cross_sum:
return left_low, left_high, left_sum
elif right_sum >= left_sum and right_sum >= cross_sum:
return right_low, right_high, right_sum
else:
return cross_low, cross_high, cross_sum
def find_max_crossing_subarray(A, low, mid, high):
import math
left_sum = -math.inf
sum1 = 0
left_ind = 0
for i in range(mid, low-1, -1):
sum1 += A[i]
if sum1 > left_sum:
left_sum = sum1
left_ind = i
right_sum = -math.inf
sum2 = 0
right_ind = 0
for j in range(mid+1, high+1):
sum2 += A[j]
if sum2 > right_sum:
right_sum = sum2
right_ind = j
return left_ind, right_ind, left_sum + right_sum
def main():
A = [13, -3, -25, 20, -3, -16, -23, 18, 20, -7, 12, -5, -22, 15, -4, 7]
print(find_maximum_subarray(A, 0, len(A)-1))
B = [-6 -5, -2, -1, -3]
print(find_maximum_subarray(B, 0, len(B)-1))
if __name__ == '__main__':
main()
2. 最大子數組問題的其他解法
P42練習4.1-2。題目要求暴力求解,運行時間爲。下面代碼中第一個函數符合題目要求,第二個函數是線性時間複雜度。
def find_maximum_subarray_n2(nums):
import math
maximum = -math.inf
for i in range(0, len(nums)):
temp = 0
for j in range(i, len(nums)):
temp += nums[j]
if temp > maximum:
maximum = temp
return maximum
def find_maximum_subarray_n(nums):
import math
maximum = -math.inf
temp = 0
for n in nums:
temp += n
if temp > maximum:
maximum = temp
if temp < 0:
temp = 0
return maximum
def main():
A = [13, -3, -25, 20, -3, -16, -23, 18, 20, -7, 12, -5, -22, 15, -4, 7]
print(find_maximum_subarray_n2(A))
print(find_maximum_subarray_n(A))
B = [-6 -5, -2, -1, -3]
print(find_maximum_subarray_n2(B))
print(find_maximum_subarray_n(B))
if __name__ == '__main__':
main()
P42練習4.1-4和4.1-5。此時採用第1種理解方法,允許子數組爲空。給出三種方法,時間複雜度分別是、、。
class MaximumSubarray:
def find_maximum_n3(self, A):
maximum = 0
for i in range(0, len(A)):
for j in range(i+1, len(A)):
temp = sum(A[i:j])
if temp > maximum:
maximum = temp
return maximum
def find_maximum_n2(self, A):
maximum = 0
for i in range(0, len(A)):
temp = 0
for j in range(i, len(A)):
temp += A[j]
if temp > maximum:
maximum = temp
return maximum
def find_maximum_n(self, A):
maximum = 0
temp = 0
for i in range(0, len(A)):
temp += A[i]
if temp > maximum:
maximum = temp
elif temp < 0:
temp = 0
return maximum
def main():
A = [13, -3, -25, 20, -3, -16, -23, 18, 20, -7, 12, -5, -22, 15, -4, 7]
m = MaximumSubarray()
print(m.find_maximum_n3(A))
print(m.find_maximum_n2(A))
print(m.find_maximum_n(A))
B = [-6 -5, -2, -1, -3]
print(m.find_maximum_n3(B))
print(m.find_maximum_n2(B))
print(m.find_maximum_n(B))
if __name__ == '__main__':
main()
3. Strassen算法
矩陣相乘(矩陣爲方陣,且行數列數爲2的冪次)的第一個暴力算法和第二個分治算法花了好長時間才寫完,Strassen算法還要寫個一兩天,真是太菜了…
# -*- coding: utf-8 -*-
def square_matrix_multiply(A, B):
n = len(A)
res = []
for i in range(n):
temp = [0] * n
for j in range(n):
for k in range(n):
temp[j] += A[i][k] * B[k][j]
res.append(temp)
return res
def square_matrix_multiply_recursive(A, B):
n = len(A)
if n == 1:
return [[A[0][0] * B[0][0]]]
else:
mid = n // 2
A11 = [A[i][0:mid] for i in range(mid)] # Split A and B
A12 = [A[i][mid:] for i in range(mid)]
A21 = [A[i][0:mid] for i in range(mid,n)]
A22 = [A[i][mid:] for i in range(mid,n)]
B11 = [B[i][0:mid] for i in range(mid)]
B12 = [B[i][mid:] for i in range(mid)]
B21 = [B[i][0:mid] for i in range(mid,n)]
B22 = [B[i][mid:] for i in range(mid,n)]
C11_1 = square_matrix_multiply_recursive(A11, B11) # Compute the partitions of C
C11_2 = square_matrix_multiply_recursive(A12, B21)
C12_1 = square_matrix_multiply_recursive(A11, B12)
C12_2 = square_matrix_multiply_recursive(A12, B22)
C21_1 = square_matrix_multiply_recursive(A21, B11)
C21_2 = square_matrix_multiply_recursive(A22, B21)
C22_1 = square_matrix_multiply_recursive(A21, B12)
C22_2 = square_matrix_multiply_recursive(A22, B22)
C11 = [[C11_1[i][j] + C11_2[i][j] for j in range(mid)] for i in range(mid)]
C12 = [[C12_1[i][j] + C12_2[i][j] for j in range(mid)] for i in range(mid)]
C21 = [[C21_1[i][j] + C21_2[i][j] for j in range(mid)] for i in range(mid)]
C22 = [[C22_1[i][j] + C22_2[i][j] for j in range(mid)] for i in range(mid)]
C = [C11[i] + C12[i] for i in range(mid)] + [C21[i] + C22[i] for i in range(mid)]
return C
def main():
A = [[1,2], [2,3]]
B = [[1,1], [2,2]]
print(square_matrix_multiply(A, B))
print(square_matrix_multiply_recursive(A, B))
A = [[1,2,3,4], [2,3,4,5], [3,4,5,6], [4,5,6,7]]
B = [[1,1,1,1], [2,2,2,2], [3,3,3,3], [4,4,4,4]]
print(square_matrix_multiply(A, B))
print(square_matrix_multiply_recursive(A, B))
if __name__ == '__main__':
main()
Strassen算法終於寫完了,好像也沒那麼難哈哈哈
# -*- coding: utf-8 -*-
def square_matrix_multiply_strassen(A, B):
n = len(A)
if n == 1:
return [[A[0][0] * B[0][0]]]
else:
mid = n // 2
A11 = [A[i][0:mid] for i in range(mid)] # Split A and B
A12 = [A[i][mid:] for i in range(mid)]
A21 = [A[i][0:mid] for i in range(mid,n)]
A22 = [A[i][mid:] for i in range(mid,n)]
B11 = [B[i][0:mid] for i in range(mid)]
B12 = [B[i][mid:] for i in range(mid)]
B21 = [B[i][0:mid] for i in range(mid,n)]
B22 = [B[i][mid:] for i in range(mid,n)]
S1 = [[B12[i][j] - B22[i][j] for j in range(mid)] for i in range(mid)] # Compute S1-S10
S2 = [[A11[i][j] + A12[i][j] for j in range(mid)] for i in range(mid)]
S3 = [[A21[i][j] + A22[i][j] for j in range(mid)] for i in range(mid)]
S4 = [[B21[i][j] - B11[i][j] for j in range(mid)] for i in range(mid)]
S5 = [[A11[i][j] + A22[i][j] for j in range(mid)] for i in range(mid)]
S6 = [[B11[i][j] + B22[i][j] for j in range(mid)] for i in range(mid)]
S7 = [[A12[i][j] - A22[i][j] for j in range(mid)] for i in range(mid)]
S8 = [[B21[i][j] + B22[i][j] for j in range(mid)] for i in range(mid)]
S9 = [[A11[i][j] - A21[i][j] for j in range(mid)] for i in range(mid)]
S10 = [[B11[i][j] + B12[i][j] for j in range(mid)] for i in range(mid)]
P1 = square_matrix_multiply_strassen(A11, S1) # Compute P1-P7
P2 = square_matrix_multiply_strassen(S2, B22)
P3 = square_matrix_multiply_strassen(S3, B11)
P4 = square_matrix_multiply_strassen(A22, S4)
P5 = square_matrix_multiply_strassen(S5, S6)
P6 = square_matrix_multiply_strassen(S7, S8)
P7 = square_matrix_multiply_strassen(S9, S10)
C11 = [[P5[i][j] + P4[i][j] - P2[i][j] + P6[i][j] for j in range(mid)] for i in range(mid)]
C12 = [[P1[i][j] + P2[i][j] for j in range(mid)] for i in range(mid)]
C21 = [[P3[i][j] + P4[i][j] for j in range(mid)] for i in range(mid)]
C22 = [[P5[i][j] + P1[i][j] - P3[i][j] - P7[i][j] for j in range(mid)] for i in range(mid)]
C = [C11[i] + C12[i] for i in range(mid)] + [C21[i] + C22[i] for i in range(mid)]
return C
def main():
A = [[1,2], [2,3]]
B = [[1,1], [2,2]]
print(square_matrix_multiply_strassen(A, B))
A = [[1,2,3,4], [2,3,4,5], [3,4,5,6], [4,5,6,7]]
B = [[1,1,1,1], [2,2,2,2], [3,3,3,3], [4,4,4,4]]
print(square_matrix_multiply_strassen(A, B))
if __name__ == '__main__':
main()