聲明:本文僅介紹如何生成K折交叉驗證所需的數據集(TFRecords),並沒有提供用K折數據進行訓練的具體過程。爲避免有人罵我,這裏附上一個K折訓練的例子。個人感覺挺簡單的,是用tf.keras實現的,和Keras很像。可惜我Keras太弱,雖然模型能看懂,但很多邏輯還是有點雲裏霧裏,比如……不說也罷。
附鏈接:https://blog.csdn.net/coolyuan/article/details/104276183
(1)K折交叉驗證生成TFRecords
import os
import tensorflow as tf
import numpy as np
from PIL import Image
from sklearn.model_selection import StratifiedKFold
# --------------------------------------- Parameters ---------------------------------------
# region 1. Parameters
"""上次生成數據信息
0 Beach
1 Farmland
2 Mountain
3 River
4 Bridge
5 Forest
6 Meadow
樣本總數:7068
訓練集數量:702
測試集總量:702
驗證集總量:572
"""
# ^ Console snapshot from an earlier run (class index -> name, then split sizes),
#   kept for reference only. NOTE(review): the training figure equals the test
#   figure, which looks like a mislabeled print rather than a real training
#   count -- verify before relying on these numbers.

# Root folder of the raw images; each class lives in its own sub-directory.
orig_picture = 'D:/1/XiongAnDatasets/AID_1'
# Class sub-directories to include. This is a set, so its iteration order is not
# guaranteed to be identical between interpreter runs (string hash randomisation).
classes = {'Bridge', 'Meadow', 'River', 'Mountain', 'Beach', 'Farmland', 'Forest'}
# Target image size after resizing; new_channels is the intended channel count
# (not read by the code shown here).
new_height, new_width, new_channels = 200, 200, 3
# Output file-name prefixes; the fold number and ".tfrecords" are appended later.
TF_train = "D:/1/tf_file/train_"
TF_test = "D:/1/tf_file/test_"
TF_valid = "D:/1/tf_file/valid_"
# Number of folds for K-fold cross-validation.
n_splits = 10
# Split ratio used when dividing the data (the name is a misspelling of "ratio").
radio = 0.9
# Sample counters, overwritten by the main script.
All_examples = Train_examples = Test_examples = Valid_examples = 0
# endregion
# --------------------------------------- Functions ---------------------------------------
# region 2. Functions
"""說明
(1)get_files(): 數據讀取並打亂
(2)create_record(): 製作TFRecords
"""
# region(1)讀取數據(順序已經打亂)
def get_files(root=None, class_names=None, seed=None):
    """Collect every image path under ``root/<class>/`` and shuffle the result.

    Class sub-directories are numbered in *sorted* name order, so the
    class-name -> label mapping is the same on every run. Iterating the
    ``classes`` set directly would make that mapping depend on string hash
    randomisation.

    Args:
        root: directory holding one sub-directory per class. Defaults to the
            module-level ``orig_picture``.
        class_names: iterable of class sub-directory names. Defaults to the
            module-level ``classes``.
        seed: optional integer seed for the shuffle. ``None`` uses NumPy's
            global random state.

    Returns:
        ``(image_list, label_list)``: two equally long lists shuffled with the
        same permutation. Paths are built as ``root + '/' + class + '/' + file``.
        Labels are decimal strings (``'0'``, ``'1'``, ...); convert them with
        ``int()``.
    """
    if root is None:
        root = orig_picture
    if class_names is None:
        class_names = classes

    paths = []
    labels = []
    for index, name in enumerate(sorted(class_names)):
        # Print the mapping so the meaning of each numeric label can be recorded.
        print(index, name)
        class_path = root + '/' + name
        # Sorting makes the collected order (and so a seeded shuffle) reproducible.
        for pic in sorted(os.listdir(class_path)):
            pic_path = class_path + '/' + pic
            if os.path.isfile(pic_path):  # skip nested directories
                paths.append(pic_path)
                labels.append(str(index))

    # One shared permutation keeps each path aligned with its label.
    rng = np.random if seed is None else np.random.RandomState(seed)
    order = rng.permutation(len(paths))
    image_list = [paths[i] for i in order]
    label_list = [labels[i] for i in order]
    return image_list, label_list
# endregion
# region(3)製作TFRecords數據
def create_record(img_list, lab_list, path):
    """Serialize aligned (image, label) pairs into a single TFRecord file.

    Each image is opened, converted to RGB, and resized to
    ``(new_width, new_height)``. Its raw pixel bytes are stored under the
    ``"image"`` key and its label as an int64 under ``"label"``.

    Args:
        img_list: sequence of image file paths.
        lab_list: sequence of labels aligned with ``img_list``. Each label must
            be accepted by ``int()`` (ints or digit strings).
        path: output ``.tfrecords`` file path.

    Returns:
        None.

    Raises:
        ValueError: if ``img_list`` and ``lab_list`` differ in length.
    """
    if len(img_list) != len(lab_list):
        raise ValueError(
            "img_list and lab_list must have the same length, got %d and %d"
            % (len(img_list), len(lab_list)))
    # tf.io.TFRecordWriter replaces tf.python_io.TFRecordWriter, which no
    # longer exists in TensorFlow 2.x.
    with tf.io.TFRecordWriter(path) as writer:
        for img_path, label in zip(img_list, lab_list):
            # Close every source file promptly instead of leaking one handle per image.
            with Image.open(img_path) as src:
                # Force RGB so every record holds new_height * new_width * 3 bytes,
                # even for grayscale/RGBA/palette inputs. LANCZOS is the filter the
                # removed Image.ANTIALIAS alias pointed to.
                img = src.convert("RGB").resize((new_width, new_height), Image.LANCZOS)
            image_val = img.tobytes()  # raw pixel bytes
            label_val = int(label)
            example = tf.train.Example(features=tf.train.Features(feature={
                "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_val])),
                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label_val])),
            }))
            writer.write(example.SerializeToString())
    return None
# endregion
# endregion
# ---------------------------------------執行---------------------------------------------------
def split_train_valid(train_index, ratio):
    """Split one fold's training indices into disjoint training/validation parts.

    The first ``int(len(train_index) * ratio)`` indices stay in training. All
    remaining indices, including the last one, form the validation set. The two
    parts never overlap, and together they cover ``train_index`` exactly.

    Args:
        train_index: 1-D sequence of sample indices for the current fold.
        ratio: fraction of ``train_index`` to keep for training (e.g. 0.9).

    Returns:
        ``(fit_index, valid_index)`` as NumPy integer arrays.
    """
    train_index = np.asarray(train_index, dtype=int)
    n_fit = int(len(train_index) * ratio)
    return train_index[:n_fit], train_index[n_fit:]


if __name__ == '__main__':
    # region (1) Split the dataset
    # region (1.1) Prepare images and labels
    images_list, labels_list = get_files()
    All_examples = len(images_list)
    # Convert once; every fold fancy-indexes these arrays with its index arrays.
    images_arr = np.array(images_list)
    labels_arr = np.array(labels_list)
    # endregion
    # TFRecordWriter does not create missing directories, so create them up front.
    for prefix in (TF_train, TF_test, TF_valid):
        out_dir = os.path.dirname(prefix)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
    # region (1.2) Stratified K-fold split
    skf = StratifiedKFold(n_splits=n_splits)
    for split_i, (train_index, test_index) in enumerate(
            skf.split(images_list, labels_list), start=1):
        # Index with this fold's train_index (not a prefix of the whole list),
        # so training never overlaps the test fold. Carve validation off the
        # tail of the fold's training indices and remove it from training.
        fit_index, valid_index = split_train_valid(train_index, radio)
        X_train, Y_train = images_arr[fit_index], labels_arr[fit_index]
        X_valid, Y_valid = images_arr[valid_index], labels_arr[valid_index]
        X_test, Y_test = images_arr[test_index], labels_arr[test_index]
        Train_examples = len(X_train)
        Test_examples = len(X_test)
        Valid_examples = len(X_valid)
        # region (1.2.1) Write one train/test/valid TFRecord set per fold
        create_record(X_train, Y_train, TF_train + str(split_i) + ".tfrecords")
        create_record(X_test, Y_test, TF_test + str(split_i) + ".tfrecords")
        create_record(X_valid, Y_valid, TF_valid + str(split_i) + ".tfrecords")
        print(str(split_i) + ".tfrecords文件生成成功!")
        # endregion
    # endregion
    # The counts below describe the last fold only.
    print("樣本總數:\n", All_examples)
    print("訓練集數量:\n", Train_examples)
    print("測試集總量:\n", Test_examples)
    print("驗證集總量:\n", Valid_examples)
    print("(n_splits折劃分的結果(訓練集、驗證集和測試集的數量)並不一定每次都相同(可以通過"
          "將上述打印縮進到上面的for循環中查看)。但是上下最多錯5個左右,所以在訓練時,"
          "建議train_batch_size>20且驗證集和測試集的batch_size=1,即每次取出來一個,"
          "預測結果對比真實值,如果相同,true_num+=1,"
          "最終準確率爲:true_num/len(驗證集或者測試集的總數)")
    print("每次生成時,類別以及類別對應的編號會被重新打亂,我在運行時控制檯的開頭將該信息進行了打印輸出,"
          "建議記錄下來。方便後期驗證。")
    # endregion
(2)非交叉驗證(簡單留出法)生成TFRecords
只需修改一下主函數即可:
if __name__ == '__main__':
    # region (1) Split the dataset (simple hold-out, no cross-validation)
    # region (1.1) Prepare images and labels
    images_list, labels_list = get_files()
    All_examples = len(images_list)
    # endregion
    images_arr = np.array(images_list)
    labels_arr = np.array(labels_list)
    # The first `radio` share of the (already shuffled) samples goes to
    # training+validation and the remainder to testing. Taking the test set
    # from the head of the list would make it a subset of the training slice.
    n_train_valid = int(All_examples * radio)
    X_train, X_test = images_arr[:n_train_valid], images_arr[n_train_valid:]
    Y_train, Y_test = labels_arr[:n_train_valid], labels_arr[n_train_valid:]
    # Move the tail of the training slice into validation (and out of training);
    # slicing to the end keeps the last sample instead of dropping it with `:-1`.
    n_fit = int(len(X_train) * radio)
    X_train, X_valid = X_train[:n_fit], X_train[n_fit:]
    Y_train, Y_valid = Y_train[:n_fit], Y_train[n_fit:]
    Train_examples = len(X_train)
    Test_examples = len(X_test)
    Valid_examples = len(X_valid)