包装用于TensorFlow操作的Python函数

最近使用tensoflow构建神经网络模型时遇到一个问题:
我们知道,tensorflow是一种计算图模型,即用图的形式来表示计算过程的一种模型。程序一般分为图的构建和图的执行两个阶段。图的构建阶段即图的定义阶段,会利用占位符tf.placeholder声明数据的格式和位置,用于传入数据到计算图中,执行阶段才会传入数据运行图模型。

我在构建网络模型时需要用到非神经网络的函数,这个函数暂时在tensorflow里并未有专门的实现,而是python的pywt库的一个离散小波变换函数,这个函数的输入并不是tensor,而是array。但是在构建图的阶段只能有占位符表示的tensor数据,这就导致我在构建图的过程中调用pywt.dwtn()函数传入tensor数据时直接报错:
TypeError: Input must be a numeric array-like.

一开始想着怎么能将tensor转换为array类型呢,一种转换方式是用data.eval(),但是.eval()必须在session里面才能用,会报错:ValueError: Cannot evaluate tensor using eval(): No default session is registered. Use with sess.as_default() or pass an explicit session to eval(session=sess).
如果在这里使用一个session又会和执行阶段的session冲突。

后来查资料查到包装用于TensorFlow操作的Python函数,https://www.w3cschool.cn/tensorflow_python/tensorflow_python-7eil2g75.html

采用的解决办法是:
定义一个函数:
def wavelet3D(self, in, reuse):
coffes = pywt.dwtn(np.array(in), wavelet=’haar’, mode=’symmetric’,axes=None)
return coffes
调用函数时:
coffes2 = tf.py_func(self.wavelet3D,[xt, reuse],[tf.float32])
self.wavelet3D是定义的函数名,[xt,reuse]是输入,[tf.float32]是输出的数据类型。
这样才可以。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章