動態規劃之矩陣鏈乘法問題

將窮舉所有情況的指數級運行時間降爲上限爲n^3

代碼:

#-*- coding: utf-8 -*-
import sys, time
def gk(i,j):
    return str(i)+','+str(j)
def matrix_chain_order(p):
    n = len(p)-1
    m, s = {}, {}
    for i in xrange(1, n+1):
        m[gk(i, i)] = 0
    for l in xrange(2, n+1):
        for i in xrange(1, n-l+2):
            j = i+l-1
            m[gk(i, j)] = sys.maxint
            for k in xrange(i, j):
                q = m[gk(i, k)]+m[gk(k+1, j)]+p[i-1]*p[k]*p[j]
                if q<m[gk(i, j)]:
                    m[gk(i, j)] = q
                    s[gk(i, j)] = k
    return m, s
def get_optimal_parens(s, i, j):
    res = ''
    if i == j:
        return "A"+str(j)
    else:
        res += "("
        res += get_optimal_parens(s, i, s[gk(i, j)])
        res += get_optimal_parens(s, s[gk(i, j)]+1, j)
        res +=  ")"
        return res
def main():
    p = [30,35,15,5,10,20,25,5,16,34,28,19,66,34,78,55,23]
    m, s = matrix_chain_order(p)
    print '總共的乘法次數是:', m[gk(1, len(p)-1)]
    print '矩陣鏈乘法的方案是', get_optimal_parens(s, 1, len(p)-1)
if __name__ == '__main__':
    b = time.time()
    main()
    print 'total run time is:', time.time()-b

結果爲:

>>>
總共的乘法次數是: 84115
矩陣鏈乘法的方案是 ((A1(A2(A3(A4(A5(A6A7))))))((((((((A8A9)A10)A11)A12)A13)A14)A15)A16))
total run time is: 0.0369999408722

樸素的遞歸算法是:

#-*- coding: utf-8 -*-
import sys, time
''' 計算第i到第j個矩陣鏈相乘的乘法次數 '''
def recursive_matrix_chain(p, i, j):
    if i == j:
        return 0
    m = sys.maxint
    for k in xrange(i, j):
        q = recursive_matrix_chain(p, i, k) + recursive_matrix_chain(p, k+1, j) + p[i-1]*p[k]*p[j]
        if q < m:
            m = q
    return m
def main():
    p = [30,35,15,5,10,20,25,5,16,34,28,19,66,34,78,55,23]
    print recursive_matrix_chain(p, 1, len(p)-1)
if __name__ == '__main__':
    b = time.time()
    main()
    print 'total run time is:', time.time()-b

結果爲:

>>>
84115
total run time is: 11.9000000954

這種樸素遞歸的方式的時間運行複雜度是2^n,指數級的,因爲這個問題滿足動態規劃的要求:最優子結構,子問題無關且重疊,所以這種樸素的遞歸方式可以用加入備忘的方式來實現動態規劃的算法:

#-*- coding: utf-8 -*-
import sys, time
gk = lambda i,j:str(i)+','+str(j)
MAX = sys.maxint
def memoized_matrix_chain(p):
    n = len(p)-1
    m = {}
    for i in xrange(1, n+1):
        for j in xrange (i, n+1):
            m[gk(i, j)] = MAX
    return lookup_chain(m, p, 1, n)
def lookup_chain(m, p, i, j):
    if m[gk(i, j)] < MAX:
        return m[gk(i, j)]
    if i == j:
        m[gk(i, j)] = 0
    else:
        for k in xrange(i, j):
            q = lookup_chain(m, p, i, k) + lookup_chain(m, p, k+1, j) + p[i-1]*p[k]*p[j]
            if q < m[gk(i, j)]:
                m[gk(i, j)] = q
    return m[gk(i, j)]
def main():
    p = [30,35,15,5,10,20,25,5,16,34,28,19,66,34,78,55,23]
    print memoized_matrix_chain(p)
if __name__ == '__main__':
    b = time.time()
    main()
    print 'total run time is:', time.time()-b

結果爲:

>>>
84115
total run time is: 0.0209999084473


發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章