Tensorflow中修改tensor的方法
在TensorFlow中tensor是不能直接修改數值的,如:
import tensorflow as tf
tensor_1 = tf.constant([x for x in range(1,10)])
# tensor_1 是一個數值爲1到9的張量,希望把中間第五個數值改爲0
tensor_1[4] = 0
這時就會報錯,錯誤類型是:
TypeError: 'Tensor' object does not support item assignment
下面總結兩種方法來修改tensor:
1、生成一個新的tensor
# 方法一 : 運用concat函數
tensor_1 = tf.constant([x for x in range(1,10)])
# 將原來的張量拆分爲3部分,修改位置前的部分,要修改的部分和修改位置之後的部分
i = 4
part1 = tensor_1[:i]
part2 = tensor_1[i+1:]
val = tf.constant([0])
new_tensor = tf.concat([part1,val,part2], axis=0)
2、使用one_hot進行加減運算
# 方法二:使用one_hot來進行加減運算
tensor_1 = tf.constant([x for x in range(1,10)])
i = 4
# 生成一個one_hot張量,長度與tensor_1相同,修改位置爲1
shape = tensor_1.get_shape().as_list()
one_hot = tf.one_hot(i,shape[0],dtype=tf.int32)
# 做一個減法運算,將one_hot爲一的變爲原張量該位置的值進行相減
new_tensor = tensor_1 - tensor_1[i] * one_hot
3、使用TensorFlow自帶的assign()函數(修改的tensor必須爲變量(Variable))
import tensorflow as tf
#create a Variable
x=tf.Variable(initial_value=[[1,1],[1,1]],dtype=tf.float32,validate_shape=False)
init_op=tf.global_variables_initializer()
update=tf.assign(x,[[1,2],[1,2]])
with tf.Session() as session:
session.run(init_op)
session.run(update)
x=session.run(x)
print(x)
tensorflow使用assign(variable,new_value)來更改變量的值,但是真正作用在garph中,必須要調用gpu或者cpu運行這個更新過:session.run(update) 。
參考