前言:
最近在學習tf2
數據加載感覺蠻方便的
這裏記錄下使用 tf.data.Dataset.from_tensor_slices
進行加載數據集.
使用tf2做mnist(kaggle)的代碼
思路
Step0: 準備要加載的numpy數據
Step1: 使用 tf.data.Dataset.from_tensor_slices()
函數進行加載
Step2: 使用 shuffle()
打亂數據
Step3: 使用 map()
函數進行預處理
Step4: 使用 batch()
函數設置 batch size
值
Step5: 根據需要 使用 repeat()
設置是否循環迭代數據集
代碼
import tensorflow as tf
from tensorflow import keras
def load_dataset():
# Step0 準備數據集, 可以是自己動手豐衣足食, 也可以從 tf.keras.datasets 加載需要的數據集(獲取到的是numpy數據)
# 這裏以 mnist 爲例
(x, y), (x_test, y_test) = keras.datasets.mnist.load_data()
# Step1 使用 tf.data.Dataset.from_tensor_slices 進行加載
db_train = tf.data.Dataset.from_tensor_slices((x, y))
db_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))
# Step2 打亂數據
db_train.shuffle(1000)
db_test.shuffle(1000)
# Step3 預處理 (預處理函數在下面)
db_train.map(preprocess)
db_test.map(preprocess)
# Step4 設置 batch size 一次喂入64個數據
db_train.batch(64)
db_test.batch(64)
# Step5 設置迭代次數(迭代2次) test數據集不需要emmm
db_train.repeat(2)
return db_train, db_test
def preprocess(labels, images):
'''
最簡單的預處理函數:
轉numpy爲Tensor、分類問題需要處理label爲one_hot編碼、處理訓練數據
'''
# 把numpy數據轉爲Tensor
labels = tf.cast(labels, dtype=tf.int32)
# labels 轉爲one_hot編碼
labels = tf.one_hot(labels, depth=10)
# 順手歸一化
images = tf.cast(images, dtype=tf.float32) / 255
return labels, images
-
one_hot 編碼: 小姐姐給你解釋去 (我在使用自帶的fit函數進行訓練的時候,發現報錯維度不正確,原來是不需要one_hot編碼)
-
shuffle()函數的數值: 源碼鏈接, 內容我貼圖了
我找到一個比較好的解釋: 簡書真是好東西 -
我發現 自己的數據使用tf.data.Dataset.from_tensor_slices(x, y)加載時, 一定要x在前y在後。。。沒仔細看函數說明,否則會導致bug的emmm
-
使用了該函數之後, fit的時候是不支持
validation_split
這個參數提供的功能的~
總結
五個步驟很重要 比較簡單的方式加載數據 當然還有其他方法加載 之後再說叭
此外, 建議讀讀api tf.data.Dataset
裏好東西太多了~