包裝用於TensorFlow操作的Python函數

最近使用tensoflow構建神經網絡模型時遇到一個問題:
我們知道,tensorflow是一種計算圖模型,即用圖的形式來表示計算過程的一種模型。程序一般分爲圖的構建和圖的執行兩個階段。圖的構建階段即圖的定義階段,會利用佔位符tf.placeholder聲明數據的格式和位置,用於傳入數據到計算圖中,執行階段纔會傳入數據運行圖模型。

我在構建網絡模型時需要用到非神經網絡的函數,這個函數暫時在tensorflow裏並未有專門的實現,而是python的pywt庫的一個離散小波變換函數,這個函數的輸入並不是tensor,而是array。但是在構建圖的階段只能有佔位符表示的tensor數據,這就導致我在構建圖的過程中調用pywt.dwtn()函數傳入tensor數據時直接報錯:
TypeError: Input must be a numeric array-like.

一開始想着怎麼能將tensor轉換爲array類型呢,一種轉換方式是用data.eval(),但是.eval()必須在session裏面才能用,會報錯:ValueError: Cannot evaluate tensor using eval(): No default session is registered. Use with sess.as_default() or pass an explicit session to eval(session=sess).
如果在這裏使用一個session又會和執行階段的session衝突。

後來查資料查到包裝用於TensorFlow操作的Python函數,https://www.w3cschool.cn/tensorflow_python/tensorflow_python-7eil2g75.html

採用的解決辦法是:
定義一個函數:
def wavelet3D(self, in, reuse):
coffes = pywt.dwtn(np.array(in), wavelet=’haar’, mode=’symmetric’,axes=None)
return coffes
調用函數時:
coffes2 = tf.py_func(self.wavelet3D,[xt, reuse],[tf.float32])
self.wavelet3D是定義的函數名,[xt,reuse]是輸入,[tf.float32]是輸出的數據類型。
這樣纔可以。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章