非tf.Variable類型的張量需要人爲設置記錄梯度信息

import tensorflow as tf 

# 構建待優化變量
x = tf.constant(1.)
w1 = tf.constant(2.)
b1 = tf.constant(1.)
w2 = tf.constant(2.)
b2 = tf.constant(1.)


with tf.GradientTape(persistent=True) as tape:
	# 非tf.Variable類型的張量需要人爲設置記錄梯度信息
	tape.watch([w1, b1, w2, b2])
	# 構建2層網絡
	y1 = x * w1 + b1	
	y2 = y1 * w2 + b2

# 獨立求解出各個導數
dy2_dy1 = tape.gradient(y2, [y1])[0]
dy1_dw1 = tape.gradient(y1, [w1])[0]
dy2_dw1 = tape.gradient(y2, [w1])[0]

# 驗證鏈式法則
print(dy2_dy1 * dy1_dw1)
print(dy2_dw1)

輸出爲:

#tf.Tensor(2.0, shape=(), dtype=float32)
#tf.Tensor(2.0, shape=(), dtype=float32)

 

如一開始就已經是tf.variable,則不需要tape。watch;如下:

import tensorflow as tf

w = tf.Variable(1.0)
b = tf.Variable(2.0)
x = tf.Variable(3.0)

with tf.GradientTape() as t1:
  with tf.GradientTape() as t2:
    y = x * w + b
  dy_dw, dy_db = t2.gradient(y, [w, b])
d2y_dw2 = t1.gradient(dy_dw, w)

print(dy_dw)
print(dy_db)
print(d2y_dw2)

assert dy_dw.numpy() == 3.0
assert d2y_dw2 is None

 

 

 

 

 

 

 

 

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章