【Tensorflow】Tensorflow中修改tensor的方法

                                             Tensorflow中修改tensor的方法

在TensorFlow中tensor是不能直接修改數值的,如:

import tensorflow as tf
tensor_1 = tf.constant([x for x in range(1,10)])
# tensor_1 是一個數值爲1到9的張量,希望把中間第五個數值改爲0 
tensor_1[4] = 0 

這時就會報錯,錯誤類型是:

TypeError: 'Tensor' object does not support item assignment

下面總結兩種方法來修改tensor:

1、生成一個新的tensor

# 方法一 : 運用concat函數
tensor_1 = tf.constant([x for x in range(1,10)])
# 將原來的張量拆分爲3部分,修改位置前的部分,要修改的部分和修改位置之後的部分
i = 4
part1 = tensor_1[:i]
part2 = tensor_1[i+1:]
val = tf.constant([0])
new_tensor = tf.concat([part1,val,part2], axis=0)

2、使用one_hot進行加減運算

# 方法二:使用one_hot來進行加減運算
tensor_1 = tf.constant([x for x in range(1,10)])
i = 4
# 生成一個one_hot張量,長度與tensor_1相同,修改位置爲1
shape = tensor_1.get_shape().as_list()
one_hot = tf.one_hot(i,shape[0],dtype=tf.int32)
# 做一個減法運算,將one_hot爲一的變爲原張量該位置的值進行相減
new_tensor = tensor_1 - tensor_1[i] * one_hot

3、使用TensorFlow自帶的assign()函數(修改的tensor必須爲變量(Variable))

import tensorflow as tf

#create a Variable
x=tf.Variable(initial_value=[[1,1],[1,1]],dtype=tf.float32,validate_shape=False)

init_op=tf.global_variables_initializer()
update=tf.assign(x,[[1,2],[1,2]])

with tf.Session() as session:
    session.run(init_op)
    session.run(update)
    x=session.run(x)
    print(x)

tensorflow使用assign(variable,new_value)來更改變量的值,但是真正作用在garph中,必須要調用gpu或者cpu運行這個更新過:session.run(update) 。

 

參考

1、tensorflow更改變量的值

2、Tensorflow小技巧整理:修改張量特定元素的值

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章