最快計算Mandelbrot的Python代碼

Python作爲動態語言,現在越來越流行,但是在使用中卻未必十全十美,其中運行的性能問題,便是其中之一。當程序中有很多for循環,並且迭代次數很多的情況下,性能問題尤其突出。當然,解決辦法也有很多,比如使用Cython便是一個好的解決辦法,也可以使用一些第三方的Python庫,如,PyOpenCI,PyCuda,Numbia等,但相比之下使用Numbia更爲簡潔,本文使用Numba來計算並生成Mandelbrot集圖片

例子說明

關於什麼是mandelbrot,不屬於本文討論範疇,有興趣參看維基百科,本文是在一個預設值的複合平面內,按照目標圖片的寬高生成複數集合,並對其中每個複數進行Mandelbrot計算迭代,在規定的迭代次數內,檢查其是否爲Mandelbrot集合內的點,返回迭代次數,並根據迭代次數映射到對應的顏色,最後繪製到一張800*800的PNG圖片上,並顯示圖片。以下就是一張生成的圖片。
Mandelbrot

1、一般的Python代碼實現

import numpy as np
from PIL import Image
import time


def mandelbrot(c, maxiter):
    z = c
    for n in range(maxiter):
        if abs(z) > 2:
            return n
        z = z * z + c
    return 0


def mandelbrot_set(xmin, xmax, ymin, ymax, img, maxiter):
    width, height = img.size[0], img.size[1]
    r1 = np.linspace(xmin, xmax, width)
    r2 = np.linspace(ymin, ymax, height)

    [img.putpixel((idx1, idx2),
                  (mandelbrot(complex(r, i), maxiter) << 21) + (mandelbrot(complex(r, i), maxiter) << 10)
                  + mandelbrot(complex(r, i), maxiter) * 8) for idx1, r in enumerate(r1) for idx2, i in enumerate(r2)]


bitmap = Image.new("RGB", (800, 800), "white")

start = time.time()
mandelbrot_set(-2.0, 0.5, -1.25, 1.25, bitmap, 100)
print("執行時間 {} 秒".format(round(time.time() - start, 2)))

bitmap.show()

運行上述代碼:

執行時間 12.37

2、 使用numba進行優化

import time
from src.utils import *
from numba import jit, guvectorize, complex128, int32
import math


def mandelbrot_set(xmin, xmax, ymin, ymax, width, height, maxiter):
    re = np.linspace(xmin, xmax, width, dtype=np.float64)
    im = np.linspace(ymin, ymax, height, dtype=np.float64)
    c = re + im[:, None]*1j

    n3 = mandelbrot_numpy(c, maxiter)

    # To handle row exchange issue.
    rows, row = n3.shape[0], math.floor(n3.shape[0]/2)
    for i in range(row):
        n3[[i, rows - 1 - i], :] = n3[[rows - 1 - i, i], :]

    return n3


@jit(int32(complex128, int32))
def mandelbrot(c, maxiter):
    real = 0
    imag = 0
    for n in range(maxiter):
        nreal = real * real - imag * imag + c.real
        imag = 2 * real * imag + c.imag
        real = nreal
        if real * real + imag * imag > 4.0:
            return n
    return 0


@guvectorize([(complex128[:], int32[:], int32[:])], '(n),()->(n)', target='parallel')
def mandelbrot_numpy(c, maxit, output):
    maxiter = maxit[0]
    for i in range(c.shape[0]):
        output[i] = mandelbrot(c[i], maxiter)


width = 800
height = 800
max_iter = 100

start = time.time()
n = mandelbrot_set(-2.0, 0.5, -1.25, 1.25, width, height, max_iter)
img = get_image(n, create_palette())
print("迭代執行時間 {} 秒".format(round(time.time() - start, 2)))

img.show()
迭代執行時間 0.26

可以看到運行效率提高100倍以上,而且電腦的風扇也不會像分機起飛一樣發出難聽的噪音。其核心概念時使用了向量化和並行計算的概念,你可以在Numbia的註解較容易的實現標量參數向量化計算,類似Numpy裏的通用函數frompyfunc類似。

總結

當你在機器學習或者大規模計算時,涉及較多循環,或者迭代次數很高的時候,多考慮一下是否可以使用第三方Python庫來提高計算效率,可用的手段有Cython,Numba,PyCuda,PyOpenCI等,使用這些手段往往會收到事半功倍的效果。

參考資料

https://numba.pydata.org/numba-doc/dev/user/vectorize.html

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章