Cython 並行優化示例

一、優化前

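在 Jupyter Notebook 中先執行 `%load_ext Cython` 載入 Cython 擴展，之後即可用 `%%cython` 儲存格直接編譯 Cython 代碼。（`%%cython --a` 建議寫為 `-a`，即 `--annotate`，用於生成帶標注的 HTML 報告；在較新版本的 Cython 中，`--a` 可能因同時匹配 `--annotate` 與 `--annotate-fullc` 而報錯。）注意：下面這個儲存格沒有傳入任何 OpenMP 編譯/鏈接選項，因此 `prange` 會被編譯成普通的串列迴圈，並不會真正並行。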

%%cython --a
import numpy as np
cimport numpy as np

from cython import boundscheck, wraparound
from cython.parallel import parallel, prange

@boundscheck(False)
@wraparound(False)
def cytest1(double[:,::1] a, double[:,::1] b, int ix1, int ix2, int iz1, int iz2):
    """Serial central difference along axis 0, written into ``b``.

    For every ix in [ix1, ix2) and iz in [iz1, iz2)::

        b[ix, iz] = 0.5 * (a[ix+1, iz] - a[ix-1, iz])

    ``a`` and ``b`` must be C-contiguous 2-D float64 buffers. Bounds
    checking and negative-index wraparound are disabled for speed, so the
    index window is validated once up front. A window that would read or
    write outside the buffers raises IndexError instead of corrupting memory.

    Returns ``b``.
    """
    cdef int ix
    cdef int iz

    # Only validate when the loops will actually touch memory.
    if ix1 < ix2 and iz1 < iz2:
        # Reads a[ix-1] and a[ix+1], writes b[ix]:
        # need 1 <= ix1 and ix2 <= len(a) - 1, ix2 <= len(b).
        if ix1 < 1 or ix2 > a.shape[0] - 1 or ix2 > b.shape[0]:
            raise IndexError(
                "ix range [%d, %d) out of bounds for the stencil" % (ix1, ix2))
        if iz1 < 0 or iz2 > a.shape[1] or iz2 > b.shape[1]:
            raise IndexError("iz range [%d, %d) out of bounds" % (iz1, iz2))

    for ix in range(ix1, ix2):
        for iz in range(iz1, iz2):
            b[ix, iz] = 0.5*(a[ix+1, iz] - a[ix-1, iz])
    return b


@boundscheck(False)
@wraparound(False)
def cytest2(double[:,::1] a, double[:,::1] b, int ix1, int ix2, int iz1, int iz2):
    """Parallel (prange) version of the central difference along axis 0.

    For every ix in [ix1, ix2) and iz in [iz1, iz2)::

        b[ix, iz] = 0.5 * (a[ix+1, iz] - a[ix-1, iz])

    Rows are distributed over threads with ``prange`` and the GIL is
    released for the loop. The loop only runs in parallel when the module is
    compiled and linked with OpenMP flags. Without them, ``prange`` compiles
    to an ordinary serial loop.

    Bounds checking and wraparound are disabled, so the index window is
    validated once up front (while the GIL is still held). An out-of-range
    window raises IndexError instead of corrupting memory.

    Returns ``b``.
    """
    cdef int ix
    cdef int iz

    # Validate before entering the nogil region: raising needs the GIL.
    if ix1 < ix2 and iz1 < iz2:
        # Reads a[ix-1] and a[ix+1], writes b[ix].
        if ix1 < 1 or ix2 > a.shape[0] - 1 or ix2 > b.shape[0]:
            raise IndexError(
                "ix range [%d, %d) out of bounds for the stencil" % (ix1, ix2))
        if iz1 < 0 or iz2 > a.shape[1] or iz2 > b.shape[1]:
            raise IndexError("iz range [%d, %d) out of bounds" % (iz1, iz2))

    with nogil, parallel():
        # Each thread writes a disjoint set of rows b[ix, :], so no races.
        for ix in prange(ix1, ix2):
            for iz in range(iz1, iz2):
                b[ix, iz] = 0.5*(a[ix+1, iz] - a[ix-1, iz])

    return b

測試代碼

# Benchmark: time Nit calls of each kernel on the same nx x nz array.
import time

import numpy as np  # import explicitly; don't rely on the %%cython cell leaking `np`


nx, nz = 1024, 1024

a = np.random.rand(nx, nz)
b = np.zeros_like(a)

Nit = 1000
for name, func in (('cytest1', cytest1), ('cytest2', cytest2)):
    # One untimed warm-up call so one-off costs (e.g. OpenMP thread start-up
    # on the first prange call) are not charged to the measurement.
    func(a, b, 5, nx-5, 0, nz)
    # perf_counter is monotonic and high-resolution; time.time() follows the
    # wall clock and can jump if the system clock is adjusted.
    ti = time.perf_counter()
    for _ in range(Nit):
        func(a, b, 5, nx-5, 0, nz)
    print('{} : {:.3f} s.'.format(name, time.perf_counter() - ti))

# Results for the first %%cython cell. It is compiled without OpenMP flags,
# so prange runs serially and both versions take about the same time:
# cytest1 : 1.239 s.
# cytest2 : 1.233 s.

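若不在 Jupyter 中使用，也可以把上面的 Cython 代碼存為 `stencil.pyx`，用以下 `setup.py` 編譯（GCC 下啟用 OpenMP 的選項是 `-fopenmp`），然後執行 `python setup.py build_ext --inplace`。setup.py: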

# Build script for the `stencil` extension (presumably the Cython code above
# saved as stencil.pyx). Build with:  python setup.py build_ext --inplace
try:
    # distutils was removed from the standard library in Python 3.12;
    # setuptools provides drop-in `setup` / `Extension`.
    from setuptools import setup, Extension
except ImportError:
    from distutils.core import setup
    from distutils.extension import Extension
from Cython.Distutils import build_ext

import numpy as np


ext_modules = [
    Extension(
        "stencil",
        ["stencil.pyx"],
        # If stencil.pyx does `cimport numpy` (as the notebook cell above
        # does), the generated C includes numpy/arrayobject.h, so NumPy's
        # header directory must be on the include path.
        include_dirs=[np.get_include()],
        libraries=["m"],  # libm on POSIX; typically must be dropped for MSVC
        # GCC/Clang flags. -fopenmp is what makes prange actually run in
        # parallel (MSVC uses /openmp instead). -march=native ties the binary
        # to the build machine's CPU; -ffast-math relaxes IEEE float rules.
        extra_compile_args=["-O3", "-ffast-math", "-march=native", "-fopenmp"],
        extra_link_args=["-fopenmp"],
    )
]

setup(
    name="stencil",
    cmdclass={"build_ext": build_ext},
    ext_modules=ext_modules,
)

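二、優化後

這一版相對於上一版有兩處改動。一是給 `%%cython` 傳入 OpenMP 編譯與鏈接選項，使 `prange` 真正並行。二是把差分核心換成計算量更大的 `sin`。注意 `/openmp` 是 MSVC（Windows）的選項，在 Linux/macOS 上使用 GCC 時應改為 `--compile-args=-fopenmp --link-args=-fopenmp`。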

%%cython --compile-args=/openmp --link-args=/openmp 
from cython.parallel import parallel, prange
from cython import boundscheck, wraparound
from libc.math cimport sin

@boundscheck(False)
@wraparound(False)
def cytest1(double[:,::1] a, double[:,::1] b, int ix1, int ix2, int iz1, int iz2):
    """Serial kernel: b[ix, iz] = sin(a[ix+1, iz]).

    Applies the kernel for every ix in [ix1, ix2) and iz in [iz1, iz2).
    ``a`` and ``b`` must be C-contiguous 2-D float64 buffers. Bounds checking
    and wraparound are disabled for speed, so the index window is validated
    once up front. An out-of-range window raises IndexError instead of
    reading or writing outside the buffers.

    Returns ``b``.
    """
    cdef int ix
    cdef int iz

    # Only validate when the loops will actually touch memory.
    if ix1 < ix2 and iz1 < iz2:
        # Reads a[ix+1], writes b[ix]: need 0 <= ix1 and ix2 <= len(a) - 1.
        if ix1 < 0 or ix2 > a.shape[0] - 1 or ix2 > b.shape[0]:
            raise IndexError(
                "ix range [%d, %d) out of bounds for the kernel" % (ix1, ix2))
        if iz1 < 0 or iz2 > a.shape[1] or iz2 > b.shape[1]:
            raise IndexError("iz range [%d, %d) out of bounds" % (iz1, iz2))

    for ix in range(ix1, ix2):
        for iz in range(iz1, iz2):
            b[ix, iz] = sin(a[ix+1, iz])
    return b


@boundscheck(False)
@wraparound(False)
def cytest2(double[:,::1] a, double[:,::1] b, int ix1, int ix2, int iz1, int iz2):
    """Parallel (prange) kernel: b[ix, iz] = sin(a[ix+1, iz]).

    Applies the kernel for every ix in [ix1, ix2) and iz in [iz1, iz2).
    Rows are distributed over threads with ``prange`` and the GIL is released
    for the loop. The loop only runs in parallel when compiled and linked
    with OpenMP flags.

    Bounds checking and wraparound are disabled, so the index window is
    validated once up front (while the GIL is still held). An out-of-range
    window raises IndexError instead of corrupting memory.

    Returns ``b``.
    """
    cdef int ix
    cdef int iz

    # Validate before entering the nogil region: raising needs the GIL.
    if ix1 < ix2 and iz1 < iz2:
        # Reads a[ix+1], writes b[ix].
        if ix1 < 0 or ix2 > a.shape[0] - 1 or ix2 > b.shape[0]:
            raise IndexError(
                "ix range [%d, %d) out of bounds for the kernel" % (ix1, ix2))
        if iz1 < 0 or iz2 > a.shape[1] or iz2 > b.shape[1]:
            raise IndexError("iz range [%d, %d) out of bounds" % (iz1, iz2))

    with nogil, parallel():
        # Each thread writes a disjoint set of rows b[ix, :], so no races.
        for ix in prange(ix1, ix2):
            for iz in range(iz1, iz2):
                b[ix, iz] = sin(a[ix+1, iz])

    return b
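# 測試結果（沿用上面的測試代碼）：
# cytest1 : 6.151 s.
# cytest2 : 1.496 s.

啟用 OpenMP 後，在作者的機器上並行版本約快 4 倍（6.151 / 1.496 ≈ 4.1）。`sin` 屬於計算密集型運算，適合並行。第一版的差分則不同：每個元素只做一次減法和一次乘法，瓶頸很可能在記憶體頻寬，因此即使啟用 OpenMP，加速也可能有限。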
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章