Python作爲動態語言,現在越來越流行,但是在使用中卻未必十全十美,其中運行的性能問題,便是其中之一。當程序中有很多for循環,並且迭代次數很多的情況下,性能問題尤其突出。當然,解決辦法也有很多,比如使用Cython便是一個好的解決辦法,也可以使用一些第三方的Python庫,如,PyOpenCI,PyCuda,Numbia等,但相比之下使用Numbia更爲簡潔,本文使用Numba來計算並生成Mandelbrot集圖片
例子說明
關於什麼是mandelbrot,不屬於本文討論範疇,有興趣參看維基百科,本文是在一個預設值的複合平面內,按照目標圖片的寬高生成複數集合,並對其中每個複數進行Mandelbrot計算迭代,在規定的迭代次數內,檢查其是否爲Mandelbrot集合內的點,返回迭代次數,並根據迭代次數映射到對應的顏色,最後繪製到一張800*800的PNG圖片上,並顯示圖片。以下就是一張生成的圖片。
1、一般的Python代碼實現
import numpy as np
from PIL import Image
import time
def mandelbrot(c, maxiter):
z = c
for n in range(maxiter):
if abs(z) > 2:
return n
z = z * z + c
return 0
def mandelbrot_set(xmin, xmax, ymin, ymax, img, maxiter):
width, height = img.size[0], img.size[1]
r1 = np.linspace(xmin, xmax, width)
r2 = np.linspace(ymin, ymax, height)
[img.putpixel((idx1, idx2),
(mandelbrot(complex(r, i), maxiter) << 21) + (mandelbrot(complex(r, i), maxiter) << 10)
+ mandelbrot(complex(r, i), maxiter) * 8) for idx1, r in enumerate(r1) for idx2, i in enumerate(r2)]
bitmap = Image.new("RGB", (800, 800), "white")
start = time.time()
mandelbrot_set(-2.0, 0.5, -1.25, 1.25, bitmap, 100)
print("執行時間 {} 秒".format(round(time.time() - start, 2)))
bitmap.show()
運行上述代碼:
執行時間 12.37 秒
2、 使用numba進行優化
import time
from src.utils import *
from numba import jit, guvectorize, complex128, int32
import math
def mandelbrot_set(xmin, xmax, ymin, ymax, width, height, maxiter):
re = np.linspace(xmin, xmax, width, dtype=np.float64)
im = np.linspace(ymin, ymax, height, dtype=np.float64)
c = re + im[:, None]*1j
n3 = mandelbrot_numpy(c, maxiter)
# To handle row exchange issue.
rows, row = n3.shape[0], math.floor(n3.shape[0]/2)
for i in range(row):
n3[[i, rows - 1 - i], :] = n3[[rows - 1 - i, i], :]
return n3
@jit(int32(complex128, int32))
def mandelbrot(c, maxiter):
real = 0
imag = 0
for n in range(maxiter):
nreal = real * real - imag * imag + c.real
imag = 2 * real * imag + c.imag
real = nreal
if real * real + imag * imag > 4.0:
return n
return 0
@guvectorize([(complex128[:], int32[:], int32[:])], '(n),()->(n)', target='parallel')
def mandelbrot_numpy(c, maxit, output):
maxiter = maxit[0]
for i in range(c.shape[0]):
output[i] = mandelbrot(c[i], maxiter)
width = 800
height = 800
max_iter = 100
start = time.time()
n = mandelbrot_set(-2.0, 0.5, -1.25, 1.25, width, height, max_iter)
img = get_image(n, create_palette())
print("迭代執行時間 {} 秒".format(round(time.time() - start, 2)))
img.show()
迭代執行時間 0.26 秒
可以看到運行效率提高100倍以上,而且電腦的風扇也不會像分機起飛一樣發出難聽的噪音。其核心概念時使用了向量化和並行計算的概念,你可以在Numbia的註解較容易的實現標量參數向量化計算,類似Numpy裏的通用函數frompyfunc類似。
總結
當你在機器學習或者大規模計算時,涉及較多循環,或者迭代次數很高的時候,多考慮一下是否可以使用第三方Python庫來提高計算效率,可用的手段有Cython,Numba,PyCuda,PyOpenCI等,使用這些手段往往會收到事半功倍的效果。