今天寫了如下的代碼,用來測試手勢識別的神經網絡算法準確性:
from skimage import io,transform
import tensorflow as tf
import numpy as np
import os
path = "./Images/"
dict = {0:'palm',1:'l',2:'fist',3:'fist_move',4:'thumb',5:'index',6:'ok',7:'palm_move',8:'c',9:'down'}
w=240
h=320
c=1
def read_one_image(path):
img = io.imread(path)
img = transform.resize(img,(w,h))
return np.asarray(img)
with tf.Session() as sess:
data = []
for images in os.listdir(path):
#data1 = read_one_image(os.path.join(path,images))
data1 = read_one_image(path + images)
data.append(data1)
#data = np.array(data).reshape(-1,w,h,1)
saver = tf.train.import_meta_graph('./modelSave/model.ckpt.meta')
saver.restore(sess,tf.train.latest_checkpoint('./modelSave/'))
graph = tf.get_default_graph()
x = graph.get_tensor_by_name("x:0")
feed_dict = {x:data}
logits = graph.get_tensor_by_name("logits_eval:0")
classification_result = sess.run(logits,feed_dict)
# 打印出預測矩陣
print(classification_result)
# 打印出預測矩陣每一行最大值的索引
print(tf.argmax(classification_result, 1).eval())
# 根據索引通過字典對應花的分類
output = []
output = tf.argmax(classification_result, 1).eval()
for i in range(len(output)):
print("第",i+1,"個手勢預測:"+dict[output[i]])
可是運行的時候卻報錯:ValueError: Cannot feed value of shape (9, 240, 320) for Tensor 'x:0', which has shape '(?, 240, 320, 1)'
原來是data的維度和Tensor x不匹配,添加紅色部分的那一句data = np.array(data).reshape(-1,w,h,1) ,
就可以正常運行了,輸出的結果如下:
...
[1 4 6 9 0 0 1 2 4]
第 1 個手勢預測:l
第 2 個手勢預測:thumb
第 3 個手勢預測:ok
第 4 個手勢預測:down
第 5 個手勢預測:palm
第 6 個手勢預測:palm
第 7 個手勢預測:l
第 8 個手勢預測:fist
第 9 個手勢預測:thumb