1 介紹
本文以CASIA-Webface爲例，介紹如何利用TensorFlow讀取tfrecords格式的數據。
2 導入包
import mxnet as mx
import argparse
import PIL.Image
import io
import numpy as np
import cv2
import tensorflow as tf
import os
3 主函數
if __name__ == '__main__':
    # Demo driver: stream the CASIA-Webface tfrecords file through a tf.data
    # pipeline and dump one sample image per batch for visual inspection.
    # NOTE(review): `parse_args` is not shown in this article; it is assumed to
    # be defined elsewhere (argparse is imported) and to provide
    # `tfrecords_file_path`. `parse_function` must also already be defined
    # when this block executes.
    args = parse_args()
    config = tf.ConfigProto(allow_soft_placement=True)

    # Training dataset pipeline.
    tfrecords_f = os.path.join(args.tfrecords_file_path, 'tran.tfrecords')
    dataset = tf.data.TFRecordDataset(tfrecords_f)
    dataset = dataset.map(parse_function)
    # Shuffle using a buffer of 30000 records.
    dataset = dataset.shuffle(buffer_size=30000)
    # Each iteration yields a batch of 32 images.
    dataset = dataset.batch(32)
    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()

    # `with` guarantees the session is closed on normal exit or on error.
    with tf.Session(config=config) as sess:
        for _epoch in range(1000):
            # Re-initialise the iterator to start another pass over the data.
            sess.run(iterator.initializer)
            while True:
                try:
                    images, labels = sess.run(next_element)
                except tf.errors.OutOfRangeError:
                    # Without this `break` the loop would spin forever once
                    # the dataset is exhausted.
                    print("End of dataset")
                    break
                # parse_function yields BGR float images scaled as
                # (x - 127.5) * 0.0078125; invert that to recover uint8 pixels.
                sample = images[0] / 0.0078125 + 127.5
                sample = np.clip(sample, 0, 255).astype(np.uint8)
                # cv2.imwrite expects BGR, matching parse_function's output.
                cv2.imwrite('test.jpg', sample)
4 解析函數
def parse_function(example_proto):
    """Decode one serialized tf.Example record into an (image, label) pair.

    Args:
        example_proto: scalar string tensor holding a serialized tf.Example
            with an 'image_raw' string feature (JPEG bytes) and an int64
            'label' feature.

    Returns:
        img: float32 tensor of shape (112, 112, 3), channels in BGR order,
            normalised as (x - 127.5) * 0.0078125 (i.e. divided by 128) and
            randomly flipped left-right.
        label: int64 scalar tensor.
    """
    features = {
        'image_raw': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64),
    }
    features = tf.parse_single_example(example_proto, features)
    # You can do more image distortion here for training data.
    # channels=3 forces a 3-channel (RGB) result even for grayscale JPEGs,
    # which would otherwise decode to 1 channel and break the reshape below.
    img = tf.image.decode_jpeg(features['image_raw'], channels=3)
    img = tf.reshape(img, shape=(112, 112, 3))
    # RGB -> BGR by reversing the channel axis.
    img = tf.reverse(img, axis=[-1])
    img = tf.cast(img, dtype=tf.float32)
    # Normalisation: (x - 127.5) / 128.
    img = tf.subtract(img, 127.5)
    img = tf.multiply(img, 0.0078125)
    # Simple data augmentation.
    img = tf.image.random_flip_left_right(img)
    label = tf.cast(features['label'], tf.int64)
    return img, label
5 關於數據
本文例子中使用的是CASIA-Webface的tfrecords格式數據。
歡迎大家關注和轉發本公衆號:facefinetune