讓代碼飛起來——高性能Julia學習筆記

最近有個項目是計算密集型的, 最開始用 TS 快速實現了算法原型, 後來改用 Go 重寫, Go 的 goroutine 用起來還是蠻爽的, 很容易把所有 cpu core 跑滿。 不過隨着代碼逐漸複雜, 感覺用 Go 還是沒有動態語言寫起來爽, 性能也沒有達到極致, 跟 C/C++/Rust 還是有一定差距,似乎對 GPU 和 SIMD 支持也不太好(不敢說對 Go 精通, 可能是我沒找到合適的打開方式吧)。 一開始打算用 Rust 嘗試一下, 之前用過一下, 性能確實可以(畢竟沒有 runtime/GC),結果前段時間 Julia 發佈 1.0,看了一下語法、性能等各方面都很適合, 遂決定用 Julia 寫。

本文記錄一下學習 Julia HPC 方面的經歷。

 

使用 Julia

網上有很多 Julia 的教程, 推薦幾個大家自己去看看吧:

 

High Performance Computing Julia

主要參考《Julia High Performance》,書裏用的 Julia0.4, 很多代碼已經跑不起來了, 我參考官網文檔修改了一下。 後面有時間會根據官網文檔再整理一些 HPC 相關的資料。

測試機器爲:

julia> versioninfo()
Julia Version 1.0.1
Commit 0d713926f8 (2018-09-29 19:05 UTC)
Platform Info:
  OS: macOS (x86_64-apple-darwin14.5.0)
  CPU: Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-6.0.0 (ORCJIT, haswell)

 

有多快

參考官網 Benchmark

benchmarks.svguploading.4e448015.gif轉存失敗重新上傳取消Julia benchmark

讓我驚訝的是, LuaJIT 居然比 Rust 還快!!!

 

爲什麼如此快

Julia 設計之初就重點考慮了速度, 它的快很大程度上來源於 LLVM, JIT 以及類型設計。 Julia 有類型推斷,編譯的時候會根據不同 type 生成不同的特定代碼,叫 code specialization,然後運行的時候會根據參數類型選擇最適合的方法,即 Multiple dispatch。

 

性能分析工具

沒有分析的優化都是耍流氓!

性能分析工具主要有如下幾個:

 

@time

@time sqrt.(rand(1000));

輸出結果:

julia> @time sqrt.(rand(1000));
  0.057799 seconds (187.82 k allocations: 9.542 MiB, 9.82% gc time)

julia> @time sqrt.(rand(1000));
  0.000022 seconds (8 allocations: 16.063 KiB)

這裏;是爲了不輸入結果值。 因爲rand(1000)返回的是一個 vector, sqrt.即可作用於 vector 中的每一個元素, 類似 numpy 中的 broadcast 吧。

注意, 第一次運行的時候 Julia 會編譯代碼, 所以時間會長很多, 應該以第二次以及之後的爲準!

 

@timev

用法同@time, 功能增加了 memory, 以及時間精確到 ns。

julia> @timev sqrt.(rand(1000));
  0.000025 seconds (8 allocations: 16.063 KiB)
elapsed time (ns): 25340
bytes allocated:   16448
pool allocs:       6
non-pool GC allocs:2

 

Julia profiler

用 profiler 可以分析出那段代碼佔用了最多時間。

using Profile
using Statistics
function testfunc()
  x = rand(1000000)
  y = std(x)
  return y
end
@profile testfunc()

Profile 會採樣蒐集 profile 信息,獲取信息:

Profile.print();

Profile 的輸出不便於分析, 可以採用 ProfileView 輸出火焰圖(說實話跟 Go 的火焰圖差遠了):

Pkg.add("ProfileView");
using ProfileView
ProfileView.view()

 

BenchmarkTools

https://github.com/JuliaCI/BenchmarkTools.jl

julia> Pkg.add("BenchmarkTools")

julia> using BenchmarkTools

julia> @benchmark sqrt.(rand(1000))
BenchmarkTools.Trial:
  memory estimate:  15.88 KiB
  allocs estimate:  2
  --------------
  minimum time:     4.935 μs (0.00% GC)
  median time:      5.598 μs (0.00% GC)
  mean time:        7.027 μs (11.17% GC)
  maximum time:     229.810 μs (97.54% GC)
  --------------
  samples:          10000
  evals/sample:     7

可以看到,包括內存佔用、內存分配次數,運行時間統計等, 我們後面很多實驗都是用@benchmark。

 

Types

Julia 的類型是在 runtime 檢查的, 但是在 compile time 會生成不同類型的方法。

Julia 中, 函數是一個抽象概念, 一個函數名下可能對應多個具體實現,即方法,比如如下代碼函數 f 有 2 個方法:

julia> function f(x::Int64)
           x
       end
f (generic function with 1 method)

julia> function f(x::String)
           "string"
       end
f (generic function with 2 methods)

運行的時候, 會根據所有參數個數、類型, 選擇最 match 的方法執行, 所以叫 Multiple dispatch。 對比一般的 OO 語言, 只是根據 receiver 來決定 dispatch 哪個方法, 所以叫”single dispatch”。

Julia 中 type 也會形成 hierarchy, 如下圖: julia-type-hierarchy.pnguploading.4e448015.gif轉存失敗重新上傳取消Julia type hierarchy

Julia 中,concrete type 不能有 subtypes, 也就是 final 的!

Any是所有的超類, Nothing是所有的子類, Nothing只有一個實例: nothing

Julia 中的參數類型(類似 Java 的泛型)可以是 value, 比如 Array 的類型爲 Array{T, N}, 其中 N 是具體的數字,表示數組的維數:

julia> typeof([1,2])
Array{Int64,1}

Julia 的類型推斷不是基於著名的 Hindley-Milner 算法(ML 系語言用的類型推斷算法, 比如 Scala), 它只會盡力推, 最後在 runtime 如果找不到 match 的方法就會報錯。

 

Type-stability

Type-stability指的是函數返回值類型只取決於參數類型, 而跟參數的具體值無關。 下面的函數就不符合:

function trunc(x)
  if x < 0
    return 0
  else
    return x
  end
end
julia > trunc(-1.5) |> typeof
Int64

julia > trunc(1.5) |> typeof
Float64

|>是 pipeline 操作符, 把前一個操作的結果傳入下一個操作, 類似於 linux 的|, 可以方便的將f(g(h(j(x))))改寫成可讀性更高的x |> j |> h |> g |> f

要修復trunctype-stability問題, 可以用 zero 方法:

function trunc_fixed(x)
  if x < 0
    return zero(x)
  else
    return x
  end
end
julia> -1.5 |> trunc_fixed |> typeof
Float64

julia> 1.5 |> trunc_fixed |> typeof
Float64

如果函數是 type-unstable 的話, Julia 編譯器沒法編譯出特定類型的優化的代碼, 我們來測試一下:

julia> @benchmark trunc(-2.5)
BenchmarkTools.Trial:
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     0.020 ns (0.00% GC)
  median time:      0.030 ns (0.00% GC)
  mean time:        0.031 ns (0.00% GC)
  maximum time:     8.802 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1000

julia> @benchmark trunc_fixed(-2.5)
BenchmarkTools.Trial:
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     0.020 ns (0.00% GC)
  median time:      0.032 ns (0.00% GC)
  mean time:        0.031 ns (0.00% GC)
  maximum time:     8.843 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1000

1.0 上似乎沒有太大區別了,書裏的 0.4 版本顯示 trunc_fixed 效率是 trunc 的兩倍多, 說明 Julia 本身也在不停地改進以及用新版本是很重要的!

如何識別 type-stability 問題呢? Julia 提供了一個@code_warntype宏:

julia> @code_warntype trunc(2.5)
Body::Union{Float64, Int64}
2 1 ─ %1 = π (0.0, Float64)                                                                                                                                           │╻  <
  │   %2 = (Base.lt_float)(x, %1)::Bool                                                                                                                               ││╻  <
  │   %3 = π (0.0, Float64)                                                                                                                                           ││
  │   %4 = (Base.eq_float)(x, %3)::Bool                                                                                                                               ││╻  ==
  │   %5 = (Base.and_int)(%4, true)::Bool                                                                                                                             ││╻  &
  │   %6 = (Base.and_int)(%5, false)::Bool                                                                                                                            │││
  │   %7 = (Base.or_int)(%2, %6)::Bool                                                                                                                                ││╻  |
  └──      goto #3 if not %7                                                                                                                                          │
3 2 ─      return 0                                                                                                                                                   │
5 3 ─      return x                                                                                                                                                   │

julia> @code_warntype trunc_fixed(2.5)
Body::Float64
2 1 ─ %1 = π (0.0, Float64)                                                                                                                                           │╻  <
  │   %2 = (Base.lt_float)(x, %1)::Bool                                                                                                                               ││╻  <
  │   %3 = π (0.0, Float64)                                                                                                                                           ││
  │   %4 = (Base.eq_float)(x, %3)::Bool                                                                                                                               ││╻  ==
  │   %5 = (Base.and_int)(%4, true)::Bool                                                                                                                             ││╻  &
  │   %6 = (Base.and_int)(%5, false)::Bool                                                                                                                            │││
  │   %7 = (Base.or_int)(%2, %6)::Bool                                                                                                                                ││╻  |
  └──      goto #3 if not %7                                                                                                                                          │
3 2 ─      return 0.0                                                                                                                                                 │
5 3 ─      return x

可以看到 trunc 的返回值類型是 Union{Float64, Int64}。 另外也可以用@code_llvm@code_native兩個宏看函數最後生成的 LLVM IR 指令和機器碼, 會發現 type-stabe 的函數生成的指令也要少一些。

 

函數和宏

 

全局變量的問題

全局變量是 bad smell, 在 Julia 中還會影響性能, 因爲全局變量可能在任何時候被修改爲任何其他類型, 所以 compiler 沒法優化。

julia> p = 2;
julia> function pow_array(x::Vector{Float64})
         s = 0.0
         for y in x
           s = s + y^p
         end
         return s
       end
pow_array (generic function with 1 method)

julia> t = rand(100000);

julia> @benchmark pow_array(t)
BenchmarkTools.Trial:
  memory estimate:  4.58 MiB
  allocs estimate:  300000
  --------------
  minimum time:     7.385 ms (0.00% GC)
  median time:      8.052 ms (0.00% GC)
  mean time:        8.261 ms (2.76% GC)
  maximum time:     50.044 ms (85.05% GC)
  --------------
  samples:          604
  evals/sample:     1

可以將全局變量修改爲const即可:

julia> const p2 = 2
2

julia> function pow_array2(x::Vector{Float64})
         s = 0.0
         for y in x
           s = s + y^p2
         end
         return s
       end
pow_array2 (generic function with 1 method)

julia> @benchmark pow_array2(t)
BenchmarkTools.Trial:
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     81.324 μs (0.00% GC)
  median time:      83.629 μs (0.00% GC)
  mean time:        87.973 μs (0.00% GC)
  maximum time:     185.029 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1

性能差距將近 100 倍!!!內存佔用和分配也有很大差別。 用@code_warntype 可以看出兩者的差別, pow_array 返回值是 Any 類型, 而 pow_array2 是 Float64 類型, 可見 pow_array2 是 type-stable 的。

Julia 中的 const 可以修改值!但是不能修改類型!

 

inline

Julia 使用的 LLVM 編譯器, 大部分編譯優化都是 LLVM 的功勞, 不過 inline 是在 LLVM 之前做的。 Julia 有一套啓發式規則, 將“值得 inline”的函數 inline。 inline 會增大 code 的大小,需要權衡。

julia> trunc2(x) = x < 0 ? zero(x) : x
trunc2 (generic function with 1 method)

julia> function sqrt_sin(x)
         y = trunc2(x)
         return sin(sqrt(y)+1)
       end
sqrt_sin (generic function with 1 method)

julia> @code_typed sqrt_sin(-1)
CodeInfo(
2 1 ─ %1  = (Base.slt_int)(x, 0)::Bool                                                                                                                           │╻╷   trunc2
  └──       goto #3 if not %1                                                                                                                                    ││
  2 ─       goto #4                                                                                                                                              ││
  3 ─       goto #4                                                                                                                                              ││
  4 ┄ %5  = φ (#2 => 0, #3 => _2)::Int64                                                                                                                         │
3 │   %6  = (Base.sitofp)(Float64, %5)::Float64                                                                                                                  │╻╷╷╷ sqrt
  │   %7  = (Base.lt_float)(%6, 0.0)::Bool                                                                                                                       ││╻    sqrt
  └──       goto #6 if not %7                                                                                                                                    │││
  5 ─       invoke Base.Math.throw_complex_domainerror(:sqrt::Symbol, %6::Float64)::Union{}                                                                      │││
  └──       $(Expr(:unreachable))::Union{}                                                                                                                       │││
  6 ─ %11 = (Base.Math.sqrt_llvm)(%6)::Float64                                                                                                                   │││
  └──       goto #7                                                                                                                                              │││
  7 ─       goto #8                                                                                                                                              ││
  8 ─ %14 = (Base.add_float)(%11, 1.0)::Float64                                                                                                                  ││╻    +
  │   %15 = invoke Main.sin(%14::Float64)::Float64                                                                                                               │
  └──       return %15                                                                                                                                           │
) => Float64

從@code_typed 宏的結果可以看出,並沒有調用 trun2,而是直接將 trunc2 的代碼 inline 了。 可以用julia --inline=no啓動 REPL, 結果會是:

julia> @code_typed sqrt_sin(-1)
CodeInfo(
2 1 ─ %1 = (Main.trunc2)(x)::Int64                                                                                                                                          │
3 │   %2 = (Main.sqrt)(%1)::Float64                                                                                                                                         │
  │   %3 = (%2 + 1)::Float64                                                                                                                                                │
  │   %4 = (Main.sin)(%3)::Float64                                                                                                                                          │
  └──      return %4                                                                                                                                                        │
) => Float64

注意, 禁用 inline 會嚴重影響性能, 只在特殊情況下(比如 debugging 或者 code coverage analysis)纔打開。

關掉 inline:

julia> @benchmark sqrt_sin(-1)
BenchmarkTools.Trial:
  memory estimate:  1.45 KiB
  allocs estimate:  77
  --------------
  minimum time:     3.114 μs (0.00% GC)
  median time:      3.410 μs (0.00% GC)
  mean time:        3.749 μs (4.07% GC)
  maximum time:     1.232 ms (99.33% GC)
  --------------
  samples:          10000
  evals/sample:     9

打開 inline:

julia> @benchmark sqrt_sin(-1)
BenchmarkTools.Trial:
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     8.434 ns (0.00% GC)
  median time:      8.682 ns (0.00% GC)
  mean time:        9.083 ns (0.00% GC)
  maximum time:     35.598 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     999

可以看到, 性能差距巨大!

有時候根據 Julia 的默認規則不會 inline, 可以在 function 定義之前手動加上@inline 則可以, 不過最好是經過 profile 之後確定是代碼熱點纔去做。

julia> function f2(x)
         a=x*5
         a=a*5
         d = a
         a=a*5
         b=a+3
         b=b+3
         b=b+3
         c=a-4
         d=b/c
       end
f2 (generic function with 1 method)

julia> g(x) = f2(2x)
g (generic function with 1 method)

julia> @code_llvm g(3)
define double @julia_g_35286(i64) {
top:
  %1 = mul i64 %0, 250
  %2 = add i64 %1, 9
  %3 = add i64 %1, -4
  %4 = sitofp i64 %2 to double
  %5 = sitofp i64 %3 to double
  %6 = fdiv double %4, %5
  ret double %6
}

我們可以看到 LLVM 生成的代碼中,第一行是%1 = mul i64 %0, 250, 原因在於 f2 中, a 乘以 3 個 5, 並且 g 的定義中有一個 2 倍, 所以 LLVM 直接優化成了乘以2 * 5 * 5 * 5 = 250。 如果沒有 inline 的話,顯然是做不到的(5*5*5應該還是可以優化成*125)。

 

macros

macros 就是在 compile time 用 code 生成 code, 能提前做一些事情, 這樣在 runtime 的時候就能少做一些, 性能自然高了。

書中舉的例子測試無效, 所以暫時不放了。

 

named parameters

有時候 function 參數很多, 可以用 named parameters,可以提高代碼可讀性。 但是性能有點點影響, 大概是 50% 吧。 影響不是太大, 所以建議代碼可讀性和可維護性優先, 只在 performance-sensitive 的內部循環纔不使用。

julia> named_param(x; y=1, z=1)  =  x^y + x^z
named_param (generic function with 1 method)

julia> pos_param(x,y,z) = x^y + x^z
pos_param (generic function with 1 method)

julia> @benchmark named_param(4, y = 2, z = 3)
BenchmarkTools.Trial:
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     6.535 ns (0.00% GC)
  median time:      6.997 ns (0.00% GC)
  mean time:        7.228 ns (0.00% GC)
  maximum time:     43.993 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1000

julia> @benchmark pos_param(4, 2, 3)
BenchmarkTools.Trial:
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     4.367 ns (0.00% GC)
  median time:      4.500 ns (0.00% GC)
  mean time:        4.727 ns (0.00% GC)
  maximum time:     38.259 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1000
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章