製作自己的數據集tfrecord格式

最近接觸TensorFlow,需要訓練自己的數據集,看到很多博客資料,瞭解到TensorFlow中自帶的tfrecord文件,但是自己具體實現起來發現自己的情況與資料的一些不太一樣,所以把自己遇到的問題歸納整理出來。新手一枚,水平有限,有許多問題的解決可能僅限於解決,代碼並有優化,有些思路可能走了彎路,希望能跟大家交流。

1.問題1:對於多分類情況,怎麼確定標籤?

(1)多分類:大多資料中給出的是針對兩種分類的情況,採用的是直接用class={class1 , class2}這種格式,但是對於很多類的話,依次寫出類別有點麻煩,那麼可以採用先定義一個列表classes = []來存儲目錄中所有的分類,比如對於字符識別,那麼classes = {1,2,3,4...,A,B,C},然後用for index, name in enumerate(classes)將對應文件夾的名字與整數一一對應起來。

其中enumerate是python中的一個函數,目的會將index和classes中的name對應起來,比如classes = {1,2,3,A,...}那麼index = {1, 2, 3, 4,..}並且與classes中的1,2,3 ,A這些對應。

爲什麼要這樣做?因爲在存入tfrecord的時候,標籤一般用的是整型,當目錄文件中包含A,B,或者字符串的時候要將其變爲整型,我嘗試過讀入tfrecord的時候用tobyte格式,也就是直接用字符串的形式讀入,但是會報錯,也可能是我知識水平不夠,沒有找到正確的方法。

classes = []
    for class1 in os.listdir(cwd):
        classes.append(class1)
    for index, name in enumerate(classes):
        class_path = cwd + name + '\\'
        for img_name in os.listdir(class_path):
            img_path = class_path + img_name  # 每一個圖片的地址

2.問題二:如何在讀入的時候分數據集和測試集(其中測試集佔50%)
(1)我採用的是在逐層訪問文件夾的時候用兩個字典(一對多)存入圖片的標籤和對應圖片地址。
    for class1 in os.listdir(cwd):
        classes.append(class1)
    for index, name in enumerate(classes):
        class_path = cwd + name + '\\'
        for img_name in os.listdir(class_path):
            img_path = class_path + img_name  # 每一個圖片的地址
            m += 1
            if(m % 5) == 0:
                len_testing_dataset += 1
                testing_dataset[index].append(img_path)
            else:
                len_training_dataset += 1
                training_dataset[index].append(img_path)

3.問題三:數據格式的變換
(1)從tfrecord中讀出的數據格式是tensor格式,我之前跟着教程構建的帶有計算圖的CNN,它輸入的數據格式和mnist數據集是一
樣的,那麼要將輸出的tensor格式轉化爲與mnist數據集一樣的格式,並且標籤採用one-hot編碼格式
def to_one_hot(classes, label):
    num_classes = len(classes)
    # print(num_classes)
    # print("label-----------",label)
    label_arr = np.zeros((num_classes))
    # print("label_arr---------",label_arr)
    label_arr[label] += 1.0
    # print("after change label_arr",label_arr)
    return label_arr

def importimg(imagepath,m,classes):
    #imagepath爲讀入的圖片tfrecord的地址
    #imagepath = "data_train.tfrecords"
    # min_after_dequeue = 15
    # batch_size = 1
    # capacity = min_after_dequeue + 3 * batch_size
    # print(imagepath)
    # print("m------------",m)
    print("開始讀入數據----------------------------------")
    filename_queue = tf.train.string_input_producer([imagepath]) #讀入流中
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)   #返回文件名和文件
    features = tf.parse_single_example(serialized_example,
                                       features={
                                           'label': tf.FixedLenFeature([], tf.int64),
                                           'img_raw': tf.FixedLenFeature([], tf.string),
                                       })  # 取出包含image和label的feature對象
    image = tf.decode_raw(features['img_raw'], tf.uint8)
    # print("從tfrecord文件中讀取數據image", image)
    image = tf.reshape(image, [-1])
    # print("after reshape of image-----------------",image)
    label = tf.cast(features['label'], tf.int32)  # 在流中拋出label張量
    with tf.Session() as sess:  # 開始一個會話
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        labels = []
        images = []
        for i in range(m):
            print("第", i, "個",imagepath,"數據正在讀取中")
            image1, label1 = sess.run([image, label])  # 在會話中取出image和label
            image = tf.cast(image1, tf.float32)
            label_arr = to_one_hot(classes, label1)
            labels.append(label_arr)
            images.append(image1)
            labels_arr = np.array(labels)
            images_arr = np.array(images)
        # print("labels_arr------------",labels_arr)
        # print("images_arr------------",images_arr)
        coord.request_stop()
        coord.join(threads)
    return images_arr, labels_arr

總的代碼:
import os
import tensorflow as tf
from PIL import Image
from collections import defaultdict
from itertools import groupby
#import matplotlib.pyplot as plt
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import numpy as np

#讀圖片地址,CNN,已經測試正確
def read_image(cwd):
    #m記錄樣本數
    m = 0
    classes = []
    len_testing_dataset = 0
    len_training_dataset = 0
    training_dataset = defaultdict(list)
    testing_dataset = defaultdict(list)
    for class1 in os.listdir(cwd):
        classes.append(class1)
    for index, name in enumerate(classes):
        class_path = cwd + name + '\\'
        for img_name in os.listdir(class_path):
            img_path = class_path + img_name  # 每一個圖片的地址
            m += 1
            if(m % 5) == 0:
                len_testing_dataset += 1
                testing_dataset[index].append(img_path)
            else:
                len_training_dataset += 1
                training_dataset[index].append(img_path)
    print("training_dataset testing_dataset END ------------------------------------------------------")
    return m, classes, training_dataset, testing_dataset, len_testing_dataset, len_training_dataset

# m, classes, training_dataset, testing_dataset, len_testing_dataset, len_training_dataset = read_image(
#      'E:\datafortest\Testlib1\\'
# )


#CNN,寫數據,已經測試正確
def write_data(dataset, newfilepath):
    writer = tf.python_io.TFRecordWriter(newfilepath)  # 要生成的文件
    for label, img in dataset.items():
        for img_path in img:
            print("img_path------------",img_path)
            img = Image.open(img_path)
            img = img.resize((15, 15))
            img_raw = img.tobytes()  # 將圖片轉化爲二進制格式,uint8
            #img_decode = img_raw.decode('utf-8')
            #print(img_decode)
            example = tf.train.Example(features=tf.train.Features(feature={
                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
                'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
            }))  # example對象對label和image數據進行封裝
            writer.write(example.SerializeToString())  # 序列化爲字符串
    print("成功存入tfrecord文件")
    writer.close()
    # 返回總的類別數,和所有的類別標號

#CNN,寫數據,已經測試正確
def write_mul_data(dataset, newfilepath, record_location):
    writer = None
    current_index = 0
    for label, img in dataset.items():
        for img_path in img:
            print("img_path------------", img_path)
            #每隔10000個就存入一個文件
            if current_index % 10000 == 0:
                if writer:
                    writer.close()
                record_filename = "{record_location} - {current_index}.tfrecords".format(
                    record_location = record_location,
                    current_index = current_index
                )
                print(record_filename + "----------------------------")
            current_index += 1
            image_file = tf.read_file(newfilepath)
            try:
                image = tf.image.decode_jpeg(newfilepath)
            except:
                print(image_file)
                continue






# write_data(training_dataset,"train_set.tfrecords")

def to_one_hot(classes, label):
    num_classes = len(classes)
    # print(num_classes)
    # print("label-----------",label)
    label_arr = np.zeros((num_classes))
    # print("label_arr---------",label_arr)
    label_arr[label] += 1.0
    # print("after change label_arr",label_arr)
    return label_arr

#CNN,這個方法就是將tensor張量轉化爲images轉化爲int數組和label轉化爲ont-hot編碼
def importimg(imagepath,m,classes):
    #imagepath爲讀入的圖片tfrecord的地址
    #imagepath = "data_train.tfrecords"
    # min_after_dequeue = 15
    # batch_size = 1
    # capacity = min_after_dequeue + 3 * batch_size
    # print(imagepath)
    # print("m------------",m)
    print("開始讀入數據----------------------------------")
    filename_queue = tf.train.string_input_producer([imagepath]) #讀入流中
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)   #返回文件名和文件
    features = tf.parse_single_example(serialized_example,
                                       features={
                                           'label': tf.FixedLenFeature([], tf.int64),
                                           'img_raw': tf.FixedLenFeature([], tf.string),
                                       })  # 取出包含image和label的feature對象
    image = tf.decode_raw(features['img_raw'], tf.uint8)
    # print("從tfrecord文件中讀取數據image", image)
    image = tf.reshape(image, [-1])
    # print("after reshape of image-----------------",image)
    label = tf.cast(features['label'], tf.int32)  # 在流中拋出label張量
    with tf.Session() as sess:  # 開始一個會話
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        labels = []
        images = []
        for i in range(m):
            print("第", i, "個",imagepath,"數據正在讀取中")
            image1, label1 = sess.run([image, label])  # 在會話中取出image和label
            image = tf.cast(image1, tf.float32)
            label_arr = to_one_hot(classes, label1)
            labels.append(label_arr)
            images.append(image1)
            labels_arr = np.array(labels)
            images_arr = np.array(images)
        # print("labels_arr------------",labels_arr)
        # print("images_arr------------",images_arr)
        coord.request_stop()
        coord.join(threads)
    return images_arr, labels_arr

參考資料:http://blog.csdn.net/xierhacker/article/details/72357651


發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章