Python的加速模塊numba

關於numba的介紹有很多,就是一個可以把大量重複代碼即時編譯爲機器碼來加快程序運行速度的庫。優點是快,方便,但缺陷也很明顯,如很多類型不兼容,使用時不太靈活,必須把f方法內包含的所有方法加上裝飾器,在數值計算量小時反而會減慢速度等等...

import librosa
from numba import njit
import time

path = "/Users/birenjianmo/Desktop/learn/librosa/input/1你好.wav"
y, sr = librosa.load(path)
yy = y.tolist()

def logtime(f):
    def setf(*args, **kw):
        number = 100
        start = time.time()
        for i in range(number):
            r = f(*args, **kw)
        end = time.time()
        print( "%s 耗時:%s s" %  (f.__name__, end - start/number) )
        return r
    return setf

 
def gettopdata(y):
    data = []
    for i in range(1,len(y)-1):
        if y[i]>0 and y[i-1]<y[i]<y[i+1]:
            data.append(i)
    return data
 

@logtime
@njit(fastmath=True)
def findall_numba(a):
    result = []
    data = []
    for i in range(1,len(a)):
        data.append( a[i]-a[i-1] )
    
    for pos in range(5, len(data)//3):
        for i in range(len(data)-pos):
            if abs(data[i]  - data[i+pos]) < abs(data[i]) * 0.05:
                getp = True
                p = [a[i]]
                r = 1
    
                while r < (len(data)-i)//pos:
                    
                    for j in range(pos):
                        if abs(data[i+j] - data[i+pos*r+j]) > abs(data[i+j]) * 0.05:
                            getp = False
                            break
    
                    if not getp:
                        break
                    else:
                        p.append(a[i+pos*r])
                        r += 1
    
                if len(p) > 4:
                    result.append(p)
 
    return result

@logtime
def findall(a):
    result = []
    data = []
    for i in range(1,len(a)):
        data.append( a[i]-a[i-1] )
    
    for pos in range(5, len(data)//3):
        for i in range(len(data)-pos):
            if abs(data[i]  - data[i+pos]) < abs(data[i]) * 0.05:
                getp = True
                p = [a[i]]
                r = 1
    
                while r < (len(data)-i)//pos:
                    
                    for j in range(pos):
                        if abs(data[i+j] - data[i+pos*r+j]) > abs(data[i+j]) * 0.05:
                            getp = False
                            break
    
                    if not getp:
                        break
                    else:
                        p.append(a[i+pos*r])
                        r += 1
    
                if len(p) > 4:
                    result.append(p)
 
    return result


if __name__ == '__main__':
    a = gettopdata(yy)
    findall(a)
    findall_numba(a)

運行100次的平均時間 ,使用numba可以提升差不多7倍的速度

findall 耗時:1.180213589668274 s
findall_numba 耗時:0.17544831991195678 s

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章