tensorflow 一些tips

1. tensorflow不支持不同秩的張量和矩陣相乘 

例如:shape(3,2,3)和 (3,4)的tensor不能相乘 ,以前可以用batch_matmul() 新的版本中已被刪除

但是神經網絡中有時需要這類的操作,因爲第一個維度往往是樣本數維度。以下方法可以實現上述操作:

利用tf.scan() 實現

        # A維度(batch_size,dim_1,dim_2) B維度 (dim_2,dim_3) ——>返回(batch_size,dim_1,dim_3)
        def batch_matmul(A, B):
            # self.units 和self.fre_dim 對應(dim_1,dim_3) 
            initializer = tf.ones([self.units,self.fre_dim], dtype=self.dtype)
            C = tf.scan(lambda a, x: tf.matmul(x, B), A, initializer)
            return C

2.一個方便劃分訓練和測試集的函數

from sklearn.model_selection import train_test_split
 train_X, test_X, train_y, test_y = train_test_split(dataX, dataY,
                                                        test_size=0.2,
                                                        random_state= np.random.seed(1000))

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章