理解Theano的Scan函數

1 Scan是幹什麼的

函數scan是Theano中迭代的一般形式,所以可以用於類似循環(looping)的場景。
如果你熟悉Reduction和map兩個函數,這兩個都是scan的特殊形式,即將某函數依次作用一個序列的每個元素上。
函數scan的輸入也是一些序列(一維數組,或者多維數組,以第一維爲leading dimension),將某個函數作用於輸入序列上,得到每一步輸出的結果。
和Reduction和map兩個函數不同之處在於,scan在計算的時候,可以訪問以前n步的輸出結果,所以比較適合RNN網絡。

2 爲什麼要使用scan

看起來scan完全可以用for… loop來代替,然而scan有其自身的優點:

  • 由於Theano是使用符號代數的,迭代的次數就自然成爲符號代數的一部分。也就是說迭代次數也會體現在構造符號代數的圖中。
    (Theano用一個圖來表示符號代數)

  • 由於上面一條,可以直接用Theano計算梯度。

  • 優化減少CPU和GPU之間的數據傳輸,比Python Loop稍微快一點。

  • 說不定以後還會有符號代數的其他優點,例如自動優化 y = x/x*x。

3 大概參數說明

函數scan調用的一般形式的一個例子大概是這樣:

results, updates = theano.scan(fn = lambda y, p, x_tm2, x_tm1,A: y+p+x_tm2+xtm1+A,
sequences=[Y, P[::-1]], 
outputs_info=[dict(initial=X, taps=[-2, -1])]), 
non_sequences=A)

*參數fn是一個你需要計算的函數,一般用函數定義(比較簡單的可以用lambda來定義),參數是有順序要求的,先是sequences的參數(y,p),然後是output_info的參數(x_tm2,x_tm1),然後是no_sequences的參數(A)。

*sequences就是需要迭代的序列,序列的第一個維度(leading dimension)就是需要迭代的次數。所以,Y和P[::-1]的第一維大小應該相同,如果不同的話,就會取最小的。

*outputs_info描述了需要用到前幾次迭代輸出的結果,dict(initial=X, taps=[-2, -1])表示使用前一次和前兩次輸出的結果。如果當前迭代輸出爲x(t),則計算中使用了x(t-1)和x(t-2)。 所以output_info中的dictionary個數應該和fn的輸出個數是對應的。

*non_sequences描述了非序列的輸入,即A是一個固定的輸入,每次迭代加的A都是相同的。如果Y是一個向量,A就是一個常數。總之,A比Y少一個維度。

4 舉例

計算 Ak , 大材小用一下

k = T.iscalar("k")
A = T.vector("A")

# Symbolic description of the result
result, updates = theano.scan(fn=lambda prior_result, A: prior_result * A,
outputs_info=T.ones_like(A),
non_sequences=A,
n_steps=k)

# We only care about A**k, but scan has provided us with A**1 through A**k.
# Discard the values that we don't care about. Scan is smart enough to
# notice this and not waste memory saving them.
final_result = result[-1]
# compiled function that returns A**k
power = theano.function(inputs=[A,k], outputs=final_result, updates=updates)

print power(range(10),2)
print power(range(10),4)

輸出:

[  0.   1.   4.   9.  16.  25.  36.  49.  64.  81.]
[  0.00000000e+00   1.00000000e+00   1.60000000e+01   8.10000000e+01
   2.56000000e+02   6.25000000e+02   1.29600000e+03   2.40100000e+03
   4.09600000e+03   6.56100000e+03]

計算 Computing tanh(x(t).dot(W)+b)

X = T.matrix("X")
W = T.matrix("W")
b_sym = T.vector("b_sym")

results, updates = theano.scan(lambda v: T.tanh(T.dot(v, W) + b_sym), sequences=X)
compute_elementwise = theano.function(inputs=[X, W, b_sym], outputs=[results])

# test values
x = np.eye(2, dtype=theano.config.floatX)
w = np.ones((2, 2), dtype=theano.config.floatX)
b = np.ones((2), dtype=theano.config.floatX)
b[1] = 2
print compute_elementwise(x, w, b)[0]
# comparison with numpy
print np.tanh(x.dot(w) + b)

輸出:

[[ 0.96402758  0.99505475]
 [ 0.96402758  0.99505475]]
[[ 0.96402758  0.99505475]
 [ 0.96402758  0.99505475]]
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章