# 圖片下載地址：
# 鏈接：https://pan.baidu.com/s/1gvvr5ovcYT1pQTy0umpzrA
# 提取碼：bh0t
import os
import tensorflow as tf
from matplotlib import pyplot as plt
import numpy as np
from sklearn.utils import shuffle
from PIL import Image
%matplotlib inline
from tqdm import tqdm
# Report the TensorFlow and NumPy versions, one per line.
print(tf.__version__, np.__version__, sep="\n")
# Read the files in a folder tree and label each by its containing folder.
def load_sample(sample_dir):
    """Collect file paths under *sample_dir* and label them by folder name.

    Walks ``sample_dir`` recursively. Every file found is labelled with the
    name of the directory that directly contains it. Distinct label strings
    are sorted and mapped to integer ids ``0..k-1`` in that sorted order.

    Args:
        sample_dir: Root directory to walk. Sub-folder names become labels.

    Returns:
        ``((file_paths, label_ids), label_names)``:
        - ``file_paths`` is an array of path strings.
        - ``label_ids`` is an int64 array aligned with ``file_paths``.
        - ``label_names`` is the sorted array of distinct folder names, so
          ``label_names[label_ids[i]]`` is the folder of ``file_paths[i]``.
    """
    print("加載圖片數據")
    file_name_list = []
    labels_names = []
    for dir_path, dir_names, file_names in os.walk(sample_dir):
        # os.walk's listing order is filesystem-dependent. Sorting dir_names
        # in place fixes the descent order; the file names are sorted below,
        # so the output order is reproducible.
        dir_names.sort()
        # Use the platform's own separator handling instead of a hard-coded
        # "\\": splitting on a backslash only worked on Windows and turned
        # the whole path into the label on POSIX systems.
        label_name = os.path.basename(os.path.normpath(dir_path))
        for file_name in sorted(file_names):
            file_name_list.append(os.path.join(dir_path, file_name))
            labels_names.append(label_name)
    lab = sorted(set(labels_names))
    labdict = {name: idx for idx, name in enumerate(lab)}
    labels = [labdict[name] for name in labels_names]
    # Explicit dtype so an empty result is still an integer array.
    return (
        (np.asarray(file_name_list), np.asarray(labels, dtype=np.int64)),
        np.asarray(lab),
    )
# Read the image file names and their labels.
# Dataset root; its sub-folder names become the labels. A plain relative
# name (no trailing "\\") is portable across Windows and POSIX.
data_dir = 'man_woman'
# os.walk silently yields nothing for a missing directory, which would give
# an empty dataset without any error; fail loudly instead.
if not os.path.isdir(data_dir):
    raise FileNotFoundError(f"dataset directory not found: {data_dir!r}")
(images, labels), labelsnames = load_sample(data_dir)  # file names and labels
print(len(images), images)  # file names
print(len(labels), labels)  # integer labels
print(labelsnames)  # label strings
def makeTFRec(filenames, labels, output_dir="tfrecord_dir",
              filename="mydata.tfrecords", size=(256, 256), mode="RGB"):
    """Write images and their integer labels into a single TFRecord file.

    Each image is opened with PIL, optionally converted to *mode*, resized to
    *size*, and stored as raw pixel bytes (``"img_raw"``) next to its integer
    label (``"label"``) in a ``tf.train.Example``.

    Args:
        filenames: Sequence of image file paths.
        labels: Sequence of integer labels, aligned with ``filenames``.
        output_dir: Directory for the output file; created if missing.
        filename: Name of the TFRecord file inside ``output_dir``.
        size: ``(width, height)`` passed to ``PIL.Image.Image.resize``.
        mode: PIL mode each image is converted to before encoding, so every
            record has the same byte length. ``None`` keeps each image's own
            mode.

    Returns:
        The full path of the TFRecord file that was written.

    Raises:
        ValueError: If ``filenames`` and ``labels`` differ in length.
    """
    if len(filenames) != len(labels):
        raise ValueError(
            f"filenames ({len(filenames)}) and labels ({len(labels)}) "
            "must have the same length")
    # makedirs(exist_ok=True) avoids the exists()/mkdir() race.
    os.makedirs(output_dir, exist_ok=True)
    filename_fullpath = os.path.join(output_dir, filename)
    with tf.io.TFRecordWriter(filename_fullpath) as writer:
        for image_path, label in tqdm(zip(filenames, labels), total=len(labels)):
            # Context manager closes the file handle after each image; the
            # converted/resized copy is independent of the source file.
            with Image.open(image_path) as src:
                img = src.convert(mode) if mode is not None else src.copy()
                img = img.resize(size)
            img_raw = img.tobytes()  # raw pixel bytes of the resized image
            features = tf.train.Features(feature={
                "label": tf.train.Feature(
                    int64_list=tf.train.Int64List(value=[int(label)])),
                "img_raw": tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[img_raw])),
            })
            # Wrap the label and image bytes in an Example and serialize it.
            example = tf.train.Example(features=features)
            writer.write(example.SerializeToString())
    return filename_fullpath
# Serialize every loaded image and its label into the TFRecord file.
makeTFRec(filenames=images, labels=labels)