最近接觸TensorFlow,需要訓練自己的數據集,看到很多博客資料,瞭解到TensorFlow中自帶的tfrecord文件,但是自己具體實現起來發現自己的情況與資料的一些不太一樣,所以把自己遇到的問題歸納整理出來。新手一枚,水平有限,有許多問題的解決可能僅限於解決,代碼並有優化,有些思路可能走了彎路,希望能跟大家交流。
1.問題1:對於多分類情況,怎麼確定標籤?
(1)多分類:大多資料中給出的是針對兩種分類的情況,採用的是直接用class={class1 , class2}這種格式,但是對於很多類的話,依次寫出類別有點麻煩,那麼可以採用先定義一個列表classes = []來存儲目錄中所有的分類,比如對於字符識別,那麼classes = {1,2,3,4...,A,B,C},然後用for index, name in enumerate(classes)將對應文件夾的名字與整數一一對應起來。
其中enumerate是python中的一個函數,目的會將index和classes中的name對應起來,比如classes = {1,2,3,A,...}那麼index = {1, 2, 3, 4,..}並且與classes中的1,2,3 ,A這些對應。
爲什麼要這樣做?因爲在存入tfrecord的時候,標籤一般用的是整型,當目錄文件中包含A,B,或者字符串的時候要將其變爲整型,我嘗試過讀入tfrecord的時候用tobyte格式,也就是直接用字符串的形式讀入,但是會報錯,也可能是我知識水平不夠,沒有找到正確的方法。
classes = []
for class1 in os.listdir(cwd):
classes.append(class1)
for index, name in enumerate(classes):
class_path = cwd + name + '\\'
for img_name in os.listdir(class_path):
img_path = class_path + img_name # 每一個圖片的地址
2.問題二:如何在讀入的時候分數據集和測試集(其中測試集佔50%)
(1)我採用的是在逐層訪問文件夾的時候用兩個字典(一對多)存入圖片的標籤和對應圖片地址。
for class1 in os.listdir(cwd):
classes.append(class1)
for index, name in enumerate(classes):
class_path = cwd + name + '\\'
for img_name in os.listdir(class_path):
img_path = class_path + img_name # 每一個圖片的地址
m += 1
if(m % 5) == 0:
len_testing_dataset += 1
testing_dataset[index].append(img_path)
else:
len_training_dataset += 1
training_dataset[index].append(img_path)
3.問題三:數據格式的變換
(1)從tfrecord中讀出的數據格式是tensor格式,我之前跟着教程構建的帶有計算圖的CNN,它輸入的數據格式和mnist數據集是一
樣的,那麼要將輸出的tensor格式轉化爲與mnist數據集一樣的格式,並且標籤採用one-hot編碼格式
def to_one_hot(classes, label):
num_classes = len(classes)
# print(num_classes)
# print("label-----------",label)
label_arr = np.zeros((num_classes))
# print("label_arr---------",label_arr)
label_arr[label] += 1.0
# print("after change label_arr",label_arr)
return label_arr
def importimg(imagepath,m,classes):
#imagepath爲讀入的圖片tfrecord的地址
#imagepath = "data_train.tfrecords"
# min_after_dequeue = 15
# batch_size = 1
# capacity = min_after_dequeue + 3 * batch_size
# print(imagepath)
# print("m------------",m)
print("開始讀入數據----------------------------------")
filename_queue = tf.train.string_input_producer([imagepath]) #讀入流中
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue) #返回文件名和文件
features = tf.parse_single_example(serialized_example,
features={
'label': tf.FixedLenFeature([], tf.int64),
'img_raw': tf.FixedLenFeature([], tf.string),
}) # 取出包含image和label的feature對象
image = tf.decode_raw(features['img_raw'], tf.uint8)
# print("從tfrecord文件中讀取數據image", image)
image = tf.reshape(image, [-1])
# print("after reshape of image-----------------",image)
label = tf.cast(features['label'], tf.int32) # 在流中拋出label張量
with tf.Session() as sess: # 開始一個會話
init_op = tf.global_variables_initializer()
sess.run(init_op)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
labels = []
images = []
for i in range(m):
print("第", i, "個",imagepath,"數據正在讀取中")
image1, label1 = sess.run([image, label]) # 在會話中取出image和label
image = tf.cast(image1, tf.float32)
label_arr = to_one_hot(classes, label1)
labels.append(label_arr)
images.append(image1)
labels_arr = np.array(labels)
images_arr = np.array(images)
# print("labels_arr------------",labels_arr)
# print("images_arr------------",images_arr)
coord.request_stop()
coord.join(threads)
return images_arr, labels_arr
總的代碼:
import os
import tensorflow as tf
from PIL import Image
from collections import defaultdict
from itertools import groupby
#import matplotlib.pyplot as plt
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import numpy as np
#讀圖片地址,CNN,已經測試正確
def read_image(cwd):
#m記錄樣本數
m = 0
classes = []
len_testing_dataset = 0
len_training_dataset = 0
training_dataset = defaultdict(list)
testing_dataset = defaultdict(list)
for class1 in os.listdir(cwd):
classes.append(class1)
for index, name in enumerate(classes):
class_path = cwd + name + '\\'
for img_name in os.listdir(class_path):
img_path = class_path + img_name # 每一個圖片的地址
m += 1
if(m % 5) == 0:
len_testing_dataset += 1
testing_dataset[index].append(img_path)
else:
len_training_dataset += 1
training_dataset[index].append(img_path)
print("training_dataset testing_dataset END ------------------------------------------------------")
return m, classes, training_dataset, testing_dataset, len_testing_dataset, len_training_dataset
# m, classes, training_dataset, testing_dataset, len_testing_dataset, len_training_dataset = read_image(
# 'E:\datafortest\Testlib1\\'
# )
#CNN,寫數據,已經測試正確
def write_data(dataset, newfilepath):
writer = tf.python_io.TFRecordWriter(newfilepath) # 要生成的文件
for label, img in dataset.items():
for img_path in img:
print("img_path------------",img_path)
img = Image.open(img_path)
img = img.resize((15, 15))
img_raw = img.tobytes() # 將圖片轉化爲二進制格式,uint8
#img_decode = img_raw.decode('utf-8')
#print(img_decode)
example = tf.train.Example(features=tf.train.Features(feature={
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
})) # example對象對label和image數據進行封裝
writer.write(example.SerializeToString()) # 序列化爲字符串
print("成功存入tfrecord文件")
writer.close()
# 返回總的類別數,和所有的類別標號
#CNN,寫數據,已經測試正確
def write_mul_data(dataset, newfilepath, record_location):
writer = None
current_index = 0
for label, img in dataset.items():
for img_path in img:
print("img_path------------", img_path)
#每隔10000個就存入一個文件
if current_index % 10000 == 0:
if writer:
writer.close()
record_filename = "{record_location} - {current_index}.tfrecords".format(
record_location = record_location,
current_index = current_index
)
print(record_filename + "----------------------------")
current_index += 1
image_file = tf.read_file(newfilepath)
try:
image = tf.image.decode_jpeg(newfilepath)
except:
print(image_file)
continue
# write_data(training_dataset,"train_set.tfrecords")
def to_one_hot(classes, label):
num_classes = len(classes)
# print(num_classes)
# print("label-----------",label)
label_arr = np.zeros((num_classes))
# print("label_arr---------",label_arr)
label_arr[label] += 1.0
# print("after change label_arr",label_arr)
return label_arr
#CNN,這個方法就是將tensor張量轉化爲images轉化爲int數組和label轉化爲ont-hot編碼
def importimg(imagepath,m,classes):
#imagepath爲讀入的圖片tfrecord的地址
#imagepath = "data_train.tfrecords"
# min_after_dequeue = 15
# batch_size = 1
# capacity = min_after_dequeue + 3 * batch_size
# print(imagepath)
# print("m------------",m)
print("開始讀入數據----------------------------------")
filename_queue = tf.train.string_input_producer([imagepath]) #讀入流中
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue) #返回文件名和文件
features = tf.parse_single_example(serialized_example,
features={
'label': tf.FixedLenFeature([], tf.int64),
'img_raw': tf.FixedLenFeature([], tf.string),
}) # 取出包含image和label的feature對象
image = tf.decode_raw(features['img_raw'], tf.uint8)
# print("從tfrecord文件中讀取數據image", image)
image = tf.reshape(image, [-1])
# print("after reshape of image-----------------",image)
label = tf.cast(features['label'], tf.int32) # 在流中拋出label張量
with tf.Session() as sess: # 開始一個會話
init_op = tf.global_variables_initializer()
sess.run(init_op)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
labels = []
images = []
for i in range(m):
print("第", i, "個",imagepath,"數據正在讀取中")
image1, label1 = sess.run([image, label]) # 在會話中取出image和label
image = tf.cast(image1, tf.float32)
label_arr = to_one_hot(classes, label1)
labels.append(label_arr)
images.append(image1)
labels_arr = np.array(labels)
images_arr = np.array(images)
# print("labels_arr------------",labels_arr)
# print("images_arr------------",images_arr)
coord.request_stop()
coord.join(threads)
return images_arr, labels_arr
參考資料:http://blog.csdn.net/xierhacker/article/details/72357651