使用tf.data.Dataset.from_tensor_slices五步加載數據集

前言:

最近在學習tf2
數據加載感覺蠻方便的
這裏記錄下使用 tf.data.Dataset.from_tensor_slices 進行加載數據集.
使用tf2做mnist(kaggle)的代碼

思路

Step0: 準備要加載的numpy數據
Step1: 使用 tf.data.Dataset.from_tensor_slices() 函數進行加載
Step2: 使用 shuffle() 打亂數據
Step3: 使用 map() 函數進行預處理
Step4: 使用 batch() 函數設置 batch size
Step5: 根據需要 使用 repeat() 設置是否循環迭代數據集

代碼

import tensorflow as tf
from tensorflow import keras

def load_dataset():
	# Step0 準備數據集, 可以是自己動手豐衣足食, 也可以從 tf.keras.datasets 加載需要的數據集(獲取到的是numpy數據) 
	# 這裏以 mnist 爲例
	(x, y), (x_test, y_test) = keras.datasets.mnist.load_data()
	
	# Step1 使用 tf.data.Dataset.from_tensor_slices 進行加載
	db_train = tf.data.Dataset.from_tensor_slices((x, y))
	db_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))
	
	# Step2 打亂數據
	db_train.shuffle(1000)
	db_test.shuffle(1000)
	
	# Step3 預處理 (預處理函數在下面)
	db_train.map(preprocess)
	db_test.map(preprocess)

	# Step4 設置 batch size 一次喂入64個數據
	db_train.batch(64)
	db_test.batch(64)

	# Step5 設置迭代次數(迭代2次) test數據集不需要emmm
	db_train.repeat(2)

	return db_train, db_test

def preprocess(labels, images):
	'''
	最簡單的預處理函數:
		轉numpy爲Tensor、分類問題需要處理label爲one_hot編碼、處理訓練數據
	'''
	# 把numpy數據轉爲Tensor
	labels = tf.cast(labels, dtype=tf.int32)
	# labels 轉爲one_hot編碼
	labels = tf.one_hot(labels, depth=10)
	# 順手歸一化
	images = tf.cast(images, dtype=tf.float32) / 255
	return labels, images
  1. one_hot 編碼: 小姐姐給你解釋去 (我在使用自帶的fit函數進行訓練的時候,發現報錯維度不正確,原來是不需要one_hot編碼)

  2. shuffle()函數的數值: 源碼鏈接, 內容我貼圖了
    函數定義源碼
    我找到一個比較好的解釋: 簡書真是好東西

  3. 我發現 自己的數據使用tf.data.Dataset.from_tensor_slices(x, y)加載時, 一定要x在前y在後。。。沒仔細看函數說明,否則會導致bug的emmm

  4. 使用了該函數之後, fit的時候是不支持 validation_split 這個參數提供的功能的~

總結

五個步驟很重要 比較簡單的方式加載數據 當然還有其他方法加載 之後再說叭
此外, 建議讀讀api tf.data.Dataset 裏好東西太多了~

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章