TensorFlow: How to Write an L2 Loss

Lately I've really been writing up only very basic little topics, in the hope of steadily building up a solid foundation.

During training we usually define all kinds of losses, and one of the simplest is the L2 distance. So how do we write it?
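Written out, the quantity the code below computes is the squared L2 distance, summed over each sample and then averaged over a batch of N samples:

loss = (1/N) · Σ_i Σ_{h,w,c} (pred_i[h,w,c] − gt_i[h,w,c])²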

Suppose we have two feature tensors: one is our prediction, call it pred, and the other is the ground truth we were given, call it gt:

import numpy as np
import tensorflow as tf

# two random integer tensors of shape [batch, H, W, C] = [2, 5, 5, 3]
pred = np.random.randint(1, 5, [2, 5, 5, 3])
print("pred", pred.transpose(0, 3, 1, 2))  # transpose to [N, C, H, W] just for readability
gt = np.random.randint(2, 8, [2, 5, 5, 3])
print("gt", gt.transpose(0, 3, 1, 2))

# element-wise squared difference, shape [2, 5, 5, 3]
tf_square = tf.square(pred - gt)
# sum over the channel axis; keepdims=True keeps shape [2, 5, 5, 1]
tf_reduce_sum = tf.reduce_sum(tf_square, axis=3, keepdims=True)
# sum over the remaining axes, giving one value per sample, shape [2]
tf_reduce_sum_2 = tf.reduce_sum(tf_reduce_sum, axis=[1, 2, 3])
# average over the batch axis, giving a scalar
tf_reduce_mean = tf.reduce_mean(tf_reduce_sum_2, axis=0)
sess = tf.Session()
print("tf_square: ", np.squeeze(sess.run(tf_square)))
print("tf_reduce_sum: ", np.squeeze(sess.run(tf_reduce_sum)))
print("tf_reduce_sum_2: ", np.squeeze(sess.run(tf_reduce_sum_2)))
print("tf_reduce_mean: ", np.squeeze(sess.run(tf_reduce_mean)))

# use the tensorflow op (tf.nn.l2_loss needs a float tensor, hence the cast)
res = tf.nn.l2_loss((pred - gt).astype(np.float32))
print("tf.nn.l2_loss: ", np.squeeze(sess.run(res)))

I've written it out in a rather long-winded way; it could actually be done in one line, but that might be harder to follow. See the sketch just below.
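For reference, here is the same computation collapsed into one line, plus a plain-NumPy sanity check; this is a minimal sketch that reuses the pred, gt, and sess defined above:

# one-line TensorFlow version: square, sum over H/W/C, then mean over the batch
loss = tf.reduce_mean(tf.reduce_sum(tf.square(pred - gt), axis=[1, 2, 3]))
print("one-liner: ", sess.run(loss))

# the same computation in plain NumPy, for checking
print("numpy check: ", np.mean(np.sum((pred - gt) ** 2, axis=(1, 2, 3))))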

Below is the printed output. Playing with a small example like this is really all it takes to understand each of these ops thoroughly.

pred [[[[2 1 1 4 1]
   [3 1 4 2 3]
   [2 3 2 3 2]
   [2 2 4 4 3]
   [3 1 2 4 4]]

  [[2 2 2 1 3]
   [4 2 2 2 2]
   [3 2 4 3 2]
   [2 2 4 2 3]
   [3 3 2 3 2]]

  [[3 2 1 4 2]
   [3 3 2 3 2]
   [2 4 4 3 1]
   [2 3 1 3 1]
   [3 2 1 4 4]]]


 [[[4 3 2 1 2]
   [2 3 1 2 3]
   [4 4 4 2 2]
   [3 1 4 4 1]
   [4 1 1 3 4]]

  [[2 3 2 4 2]
   [3 3 2 4 3]
   [3 1 4 3 1]
   [3 3 3 4 3]
   [4 4 2 2 1]]

  [[3 4 4 2 2]
   [3 3 2 1 2]
   [4 2 1 2 1]
   [1 2 4 2 3]
   [1 2 1 2 3]]]]
gt  (a second 2×3×5×5 block of random integers, omitted)
tf_square:  [[[[ 0  4  4]
   [16  0  0]
   [ 9  0 36]
   [ 0  9  4]
   [ 9 16  9]]

  [[ 0  4  1]
   [25  9  1]
   [ 1  9  4]
   [25  1  1]
   [ 9  9 25]]

  [[25  0 25]
   [ 9  4  9]
   [25  1  4]
   [ 1  1  0]
   [ 4 25 25]]

  [[ 9  1  1]
   [16 25  1]
   [ 0  0 25]
   [ 4 16  9]
   [ 1  4  1]]

  [[ 1  1  0]
   [ 4  9  4]
   [25 25  1]
   [ 1  1  9]
   [ 1  1  4]]]


 [[[ 0  1  0]
   [ 4  9  4]
   [16 16  1]
   [ 9  4  1]
   [ 9  1  9]]

  [[25  0  9]
   [ 4  4  0]
   [ 9 16  1]
   [ 0  0 36]
   [16 16  0]]

  [[ 0  4  9]
   [ 1  4  1]
   [ 0  1 16]
   [ 9  9  1]
   [ 0 36 16]]

  [[ 9  1 25]
   [16  0  1]
   [ 4 16  1]
   [ 0  1  4]
   [36  0  9]]

  [[ 4  4 25]
   [25  4  4]
   [25  9 25]
   [16  0 25]
   [ 4  9  9]]]]
tf_reduce_sum:  [[[ 8 16 45 13 34]
  [ 5 35 14 27 43]
  [50 22 30  2 54]
  [11 42 25 29  6]
  [ 2 17 51 11  6]]

 [[ 1 17 33 14 19]
  [34  8 26 36 32]
  [13  6 17 19 52]
  [35 17 21  5 45]
  [33 33 59 41 22]]]
tf_reduce_sum_2:  [598 638]
tf_reduce_mean:  618
tf.nn.l2_loss:  618.0

As you can see, what took us several lines gets done in one line with the tf op! One caveat, though: tf.nn.l2_loss(t) actually computes sum(t ** 2) / 2, half the sum of squared entries over the whole tensor, not the per-sample batch mean. The two results agree here (618) only because the batch size is 2, so dividing the total by 2 happens to coincide with averaging over the batch.
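If you want the batch-mean version via the built-in op for an arbitrary batch size, one way is to undo the built-in factor of 1/2 and divide by the batch size instead. A minimal sketch; diff and batch_size are names introduced here, and sess is the session from above:

diff = (pred - gt).astype(np.float32)
batch_size = diff.shape[0]
# tf.nn.l2_loss returns sum(diff ** 2) / 2 over all elements,
# so multiply by 2 and divide by the batch size to get the per-sample mean
loss_mean = 2.0 * tf.nn.l2_loss(diff) / batch_size
print("batch-mean via l2_loss: ", sess.run(loss_mean))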

(⊙v⊙) Mm, that's today's little summary; my level is limited and my skills are modest.
