tf.linalg.svd踩坑

np.linalg.svd具體形式。注意u,s,v的位置 

>>> import numpy as np

>>> M = np.mat([[1,2,3,4],[5,6,7,8],[2,3,4,5]])
>>> u, s, v = np.linalg.svd(M)

>>> print(M.shape)
(3, 4)
>>> print(u.shape)
(3, 3)
>>> print(s.shape)
(3,)   # 根據矩陣乘法,這裏實際表示的對角陣shape是(3, 4)
>>> print(v.shape)
(4, 4)

# 驗證一下公式
>>> print(u.dot(np.column_stack((diag(s), np.zeros(3))).dot(v)))
[[ 1.  2.  3.  4.]
 [ 5.  6.  7.  8.]
 [ 2.  3.  4.  5.]]

 首先請注意它返回值中的usv矩陣順序,這是新手很容易被坑的地方。

>>> import tensorflow as tf
>>> tf.InteractiveSession()
>>> M = tf.constant([[1,2,3,4],[5,6,7,8],[2,3,4,5]], dtype=tf.float32)
>>> s, u, v = tf.svd(M)

>>> print(M.shape)
(3, 4)
>>> print(u.shape)
(3, 3)
>>> print(s.shape)
(3,)   # 根據矩陣乘法,這裏實際表示的對角陣shape是(3, 3)
>>> print(v.shape)
(4, 3)

排查了一下午,才找到這個問題,因爲tensorflow的某些向量是不能直接顯示出來的。所以只能一點一點排查。希望對新手有所幫助。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章