Tensorflow相關命令

1.tf.split和tf.unstack()

tf.split(value,num_or_size_splits,axis=0,num=None,name=’split’)
函數參數:
value:要分割的 Tensor。
num_or_size_splits:指示沿 split_dim 分割數量的 0-D 整數 Tensor 或包含沿 split_dim 每個輸出張量大小的 1-D 整數 Tensor ;如果爲一個標量,那麼它必須均勻分割 value.shape[axis];否則沿分割維度的大小總和必須與該 value 相匹配。
axis:A 0-D int32 Tensor;表示分割的尺寸;必須在[-rank(value), rank(value))範圍內;默認爲0。
num:可選的,用於指定無法從 size_splits 的形狀推斷出的輸出數。
name:操作的名稱(可選)。
函數返回值:
如果 num_or_size_splits 是標量,返回 num_or_size_splits Tensor對象;
如果 num_or_size_splits 是一維張量,則返回由 value 分割產生的 num_or_size_splits.get_shape[0] Tensor對象

import tensorflow as tf

A = [[1, 2, 3], [4, 5, 6]]
a0 = tf.split(A, num_or_size_splits=3, axis=1)
a1 = tf.unstack(A, num=3,axis=1)
a2 = tf.split(A, num_or_size_splits=2, axis=0)
a3 = tf.unstack(A, num=2,axis=0)
with tf.Session() as sess:
    print(sess.run(a0))
    print(sess.run(a1))
    print(sess.run(a2))
    print(sess.run(a3))

返回結果:
[array([[1],[4]]), array([[2], [5]]), array([[3],[6]])]
[array([1, 4]), array([2, 5]), array([3, 6])]
[array([[1, 2, 3]]), array([[4, 5, 6]])]
[array([1, 2, 3]), array([4, 5, 6])]

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章