1. Keras/Tensorflow 2.0 自定義數據集 Dataset

在學習Tensorflow的過程中,發現大多數教程都是基於現有的數據集進行訓練、優化。

例如:MNIST識別教程,一個

(x_train, y_train), (x_test, y_test) = mnist.load_data()

即可獲得訓練、測試數據集。

而在解決實際問題時,我們經常面對的是採集到的原始圖片信息,這些圖片保存在硬盤當中,當模型搭建好以後開始把數據從硬盤加載到內存,然後計算。然而加載數是需要時間的,如果圖片數據比較大,那麼無疑浪費了很多數據讀取的時間。

我們期望的是:將這些圖片信息製作成帶標籤的數據集,並能方便的shuffle、batch,快速、高效的提供給模型進行訓練。

本文以一個花卉識別的例子來展示如何利用Tensorflow的pipeline和緩存技術來方便、快捷的實現一個自定義數據集。

1. 獲取圖片數據:

from __future__ import absolute_import, division, print_function, unicode_literals

import os
import time
import tensorflow as tf
import pathlib
import random
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

data_root_orig = tf.keras.utils.get_file(origin='https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
                                         fname='flower_photos', untar=True)

執行完畢後會在~/.keras/datasets/下保存包含5種花卉圖片的文件夾:

2. 查看圖片:

快速瀏覽幾張圖片,以知道你在處理什麼:

data_root = pathlib.Path(data_root_orig)
all_image_paths = list(data_root.glob('*/*'))
all_image_paths = [str(path) for path in all_image_paths]
random.shuffle(all_image_paths)

image_count = len(all_image_paths)
print('Image count: ', image_count)

plt.figure('image show')
for n in range(3):
	image_path = random.choice(all_image_paths)
	label = image_path.split('/')[-2]
	image = Image.open(image_path)
	print(image.size)

	plt.subplot(1, 3, n+1)
	plt.title(label)
	plt.imshow(image)
plt.show()

3. 確定每張圖片的標籤:

label_names = sorted(item.name for item in data_root.glob('*/') if item.is_dir())
print('label_names: ', label_names) # ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']

label_to_index = dict((name, index) for index, name in enumerate(label_names))
print('label_to_index: ', label_to_index) # {'sunflowers': 3, 'daisy': 0, 'roses': 2, 'tulips': 4, 'dandelion': 1}

all_image_labels = [label_to_index[pathlib.Path(path).parent.name] for path in all_image_paths]
print("First 10 labels indices: ", all_image_labels[:10]) # [2, 2, 2, 2, 3, 4, 1, 1, 3, 2]

3. 讀取和格式化圖片:

主要工作是通過tf.io.read_file將圖片路徑名轉化爲圖片張量,並將每個像素值轉換爲[0 - 1]的範圍(方便訓練)。

def preprocess_image(img_raw):
	img_tensor = tf.image.decode_jpeg(contents=img_raw, channels=3) # can be used for plt.imshow(img_tensor)
	img_final = tf.image.resize(images=img_tensor, size=[192, 192])
	img_final /= 255.0 # normalize to [0,1] range
	return img_final

def load_and_preprocess_image(path):
	img_raw = tf.io.read_file(path) # can't be used for plt.imshow(img_raw)
	return preprocess_image(img_raw)

def load_and_preprocess_from_path_label(path, label):
	return load_and_preprocess_image(path), label

def load_and_preprocess_image(path):
	img_raw = tf.io.read_file(path)

	return preprocess_image(img_raw)

4. 構建Dataset:

ds = tf.data.Dataset.from_tensor_slices((all_image_paths, all_image_labels))
image_label_ds = ds.map(load_and_preprocess_from_path_label)

all_image_paths和all_image_labels這兩個list中,每張圖片和其標籤是一一對應的,因此可以打包爲一個(圖片 - 標籤)組。

tf.data.Dataset.from_tensor_slices返回的ds具有很多實用的方法用來操作數據集,例如:shuffle、batch、repeat等,方便後來加載進模型進行訓練。

5. 加載圖片數據集:

爲了高效的從硬盤加載進內存,我們採用了Tensorflow的緩存技術,並且在圖片數據遠大於內存RAM大小時,仍然可以獲得較高的性能。

BATCH_SIZE = 32

ds = image_label_ds.cache(filename='./cache.tf-data')
ds = ds.shuffle(buffer_size=image_count, reshuffle_each_iteration=True).repeat()
ds = ds.batch(BATCH_SIZE).prefetch(buffer_size=1)

至此,Dataset數據集構建完畢,可以用來高效的訓練模型。

下一篇將以一個遷移學習的例子展示如何利用Dataset來訓練模型。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章