之前我們讀取圖片數據時，是直接通過 TensorFlow 的 mnist 類下載並加載 MNIST 數據集。但更多時候，我們希望使用本地的圖片進行訓練，因此下面先把本地圖片轉換成 TFRecords 文件。
import input_data
# Load MNIST through the local `input_data` helper module with one-hot labels;
# presumably it downloads/reads the dataset under 'MNIST_data' — TODO confirm
# against input_data.py.
# NOTE(review): `mnist` is not used by the TFRecords conversion code visible
# here; it looks like a leftover from the earlier tutorial that read MNIST
# directly. Confirm nothing else relies on it before removing.
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
def _int64_feature(value):
    """Wrap one integer as a ``tf.train.Feature`` holding a length-1 int64 list."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)
def _bytes_feature(value):
    """Wrap one byte string as a ``tf.train.Feature`` holding a length-1 bytes list."""
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
# Convert every image under dirPath/<digit>/ into a TFRecords file using
# TensorFlow ops for decoding and resizing. The sub-directory name is the label.
#
# `os` and `tf` are otherwise only imported further down the file (after this
# script), which made this section fail with NameError when run top to bottom.
import os

import tensorflow as tf

dirPath = "F:/byxStudy/img/mnist0-9/"
sess = tf.InteractiveSession()
# The graph contains no tf.Variable, so this is a no-op; kept for parity.
init = tf.global_variables_initializer()
sess.run(init)

# Output TFRecords file.
tfrecord_filename = 'F:\\byxStudy\\img\\mnist0_9_train.tfrecords'

# Build the read -> decode -> resize pipeline ONCE and feed each file path
# through a placeholder. Creating these ops inside the loop would add new
# nodes to the default graph for every image, growing memory without bound
# and slowing down every subsequent iteration.
path_placeholder = tf.placeholder(tf.string, shape=[])
image_value = tf.read_file(path_placeholder)
# Only JPEG input is handled here; channels=1 yields a single-channel image.
img = tf.image.decode_jpeg(image_value, channels=1)
img = tf.image.resize_images(img, (28, 28), method=0)
# Bilinear resizing (method=0) returns float32. Round and cast back to uint8
# so 'image_raw' stores one byte per pixel, the same encoding the cv2-based
# writer in this file produces for the same record schema.
img = tf.saturate_cast(tf.round(img), tf.uint8)

# Sorted so the record order is deterministic across runs/platforms.
fileSubDirList = sorted(os.listdir(dirPath))
with tf.python_io.TFRecordWriter(tfrecord_filename) as writer:
    # Iterate over the sub-folders; each folder name is the integer label.
    for fileSubDir in fileSubDirList:
        sub_dir_path = os.path.join(dirPath, fileSubDir)
        # Skip stray files next to the label folders; listing them would raise.
        if not os.path.isdir(sub_dir_path):
            continue
        fileSubDirSubDir = sorted(os.listdir(sub_dir_path))
        for filePath in fileSubDirSubDir:
            pic2 = sess.run(
                img,
                feed_dict={path_placeholder: os.path.join(sub_dir_path, filePath)},
            )
            image_raw = pic2.tobytes()
            example = tf.train.Example(
                features=tf.train.Features(
                    feature={
                        'height': _int64_feature(pic2.shape[0]),
                        'width': _int64_feature(pic2.shape[1]),
                        'depth': _int64_feature(pic2.shape[2]),
                        'label': _int64_feature(int(fileSubDir)),
                        'image_raw': _bytes_feature(image_raw),
                    }))
            writer.write(example.SerializeToString())
import cv2
import os
import tensorflow as tf
def _int64_feature(value):
    """Return a ``tf.train.Feature`` whose int64 list contains just ``value``."""
    single = [value]
    feature = tf.train.Feature(int64_list=tf.train.Int64List(value=single))
    return feature
def _bytes_feature(value):
    """Return a ``tf.train.Feature`` whose bytes list contains just ``value``."""
    single = [value]
    feature = tf.train.Feature(bytes_list=tf.train.BytesList(value=single))
    return feature
def _read_gray(path):
    """Read the image at ``path`` as a single-channel (grayscale) array.

    ``cv2.imread`` reports every failure by returning ``None`` instead of
    raising, and it cannot open images under Chinese-named (non-ASCII)
    directories, a known limitation of ``cv2.imread`` on Windows. In that
    case, read the raw bytes with NumPy (which opens the file through Python
    and so handles any path) and decode them in memory with ``cv2.imdecode``.

    Raises:
        ValueError: if the file cannot be decoded as an image either way.
    """
    import numpy as np

    image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        image = cv2.imdecode(np.fromfile(path, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)
    if image is None:
        raise ValueError("cannot read image: %s" % path)
    return image


# Convert every image under dir_path/<digit>/ into a TFRecords file using
# OpenCV for decoding and resizing. The sub-directory name is the label.
dir_path = "F:/byxStudy/img/mnist0-9/"
tfrecord_filename = 'F:\\byxStudy\\img\\mnist0_9_train.tfrecords'
# Sorted so the record order is deterministic across runs/platforms.
dir_list = sorted(os.listdir(dir_path))
with tf.python_io.TFRecordWriter(tfrecord_filename) as writer:
    # Iterate over the sub-folders; each folder name is the integer label.
    for sub_dir in dir_list:
        sub_dir_full = os.path.join(dir_path, sub_dir)
        # Skip stray files next to the label folders; listing them would raise.
        if not os.path.isdir(sub_dir_full):
            continue
        sub_dir_list = sorted(os.listdir(sub_dir_full))
        for filePath in sub_dir_list:
            # Read as a single-channel image (with a fallback for non-ASCII paths).
            img = _read_gray(os.path.join(sub_dir_full, filePath))
            # NOTE(review): OpenCV recommends INTER_AREA for shrinking; INTER_CUBIC
            # is kept so existing records stay byte-identical.
            img = cv2.resize(img, (28, 28), interpolation=cv2.INTER_CUBIC)
            image_raw = img.tobytes()
            example = tf.train.Example(
                features=tf.train.Features(
                    feature={
                        'height': _int64_feature(img.shape[0]),
                        'width': _int64_feature(img.shape[1]),
                        'depth': _int64_feature(1),
                        'label': _int64_feature(int(sub_dir)),
                        'image_raw': _bytes_feature(image_raw),
                    }))
            writer.write(example.SerializeToString())