使用 TensorFlow 讀取圖片並寫入 TFRecord

圖片下載地址:
鏈接:https://pan.baidu.com/s/1gvvr5ovcYT1pQTy0umpzrA
提取碼:bh0t

import os
import tensorflow as tf 
from matplotlib import pyplot as plt
import numpy as np
from sklearn.utils import shuffle
from PIL import Image
%matplotlib inline
from tqdm import tqdm

# Report the library versions in use (TensorFlow first, then NumPy)
for _module in (tf, np):
    print(_module.__version__)
# Collect the files under a directory tree together with their labels
def load_sample(sample_dir):
    """Walk *sample_dir* and return every file path with an integer label.

    A file's label name is the last path component of the directory that
    contains it. Label names are sorted and numbered from 0 in that order.

    Args:
        sample_dir: Root directory to walk; each sub-directory acts as one
            class.

    Returns:
        A tuple ``((file_paths, labels), label_names)`` of NumPy arrays:
        ``file_paths`` holds the path of every file found, ``labels`` holds
        the integer label of each file (same order as ``file_paths``), and
        ``label_names`` holds the sorted distinct label names, so
        ``label_names[labels[i]]`` is the label name of ``file_paths[i]``.
    """
    print("加載圖片數據")
    file_name_list = []
    labels_names = []

    for dir_path, dir_names, file_names in os.walk(sample_dir):
        # Sorting in place makes os.walk descend in a reproducible order
        # instead of whatever order the file system happens to report.
        dir_names.sort()

        # Accept both "\\" and "/" as separators: splitting on "\\" alone
        # turns the whole directory path into the label on POSIX systems.
        label_name = dir_path.replace("\\", "/").split("/")[-1]

        for file_name in sorted(file_names):
            # Record the file path and the containing folder's name (label)
            file_name_list.append(os.path.join(dir_path, file_name))
            labels_names.append(label_name)

    lab = sorted(set(labels_names))

    # Map each label name to its index in the sorted list of names
    labdict = {name: index for index, name in enumerate(lab)}
    labels = [labdict[name] for name in labels_names]

    return (np.asarray(file_name_list), np.asarray(labels)), np.asarray(lab)
# Load the file paths and labels
# os.path.join(..., '') appends the current platform's own separator, so the
# value is still 'man_woman\\' on Windows but also resolves on POSIX systems,
# where a literal backslash would be treated as part of the directory name.
data_dir = os.path.join('man_woman', '')  # dataset root directory

(images, labels), labelsnames = load_sample(data_dir)  # load file paths and labels
print(len(images), images)  # file paths
print(len(labels), labels)  # integer labels
print(labelsnames)  # label names
def makeTFRec(filenames, labels):  # Build a TFRecord file from image paths and labels
    """Write every image/label pair to ``tfrecord_dir/mydata.tfrecords``.

    Each image is opened with PIL, resized to 256x256 and stored as the raw
    bytes returned by ``Image.tobytes()`` under the ``"img_raw"`` feature; its
    label is stored as an int64 under the ``"label"`` feature.

    Args:
        filenames: Sequence of image file paths.
        labels: Sequence of integer labels, one per entry of ``filenames``.

    Raises:
        ValueError: If ``filenames`` and ``labels`` differ in length.
    """
    if len(filenames) != len(labels):
        raise ValueError(
            "filenames and labels must have the same length, got "
            f"{len(filenames)} and {len(labels)}"
        )

    output_dir = "tfrecord_dir"
    # exist_ok avoids the check-then-create race of exists() + mkdir()
    os.makedirs(output_dir, exist_ok=True)
    filename_fullpath = os.path.join(output_dir, "mydata.tfrecords")

    with tf.io.TFRecordWriter(filename_fullpath) as writer:
        for image_path, label in tqdm(zip(filenames, labels), total=len(labels)):
            # The context manager closes the image file promptly so handles
            # do not pile up while iterating over a large dataset.
            with Image.open(image_path) as img:
                # NOTE(review): the image mode is not normalised here, so the
                # byte length depends on each file's mode (e.g. RGB vs. L) --
                # confirm the reader expects that.
                img_raw = img.resize((256, 256)).tobytes()  # raw pixel bytes

            features = tf.train.Features(feature={
                "label": tf.train.Feature(
                    # int() turns NumPy integer scalars into plain Python ints
                    int64_list=tf.train.Int64List(value=[int(label)])),
                "img_raw": tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[img_raw])),
            })
            # The Example message wraps the label and image features
            example = tf.train.Example(features=features)

            writer.write(example.SerializeToString())  # serialize to a byte string
# Serialize the loaded image paths and labels into the TFRecord file
makeTFRec(filenames=images, labels=labels)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章