1 介绍
本文以CASIA-Webface为例，介绍如何利用TensorFlow读取tfrecords格式的数据。
2 导入包
import mxnet as mx
import argparse
import PIL.Image
import io
import numpy as np
import cv2
import tensorflow as tf
import os
3 主函数
if __name__ == '__main__':
    # Sanity check of the tfrecords input pipeline: iterate over the dataset
    # and write the first image of each batch to 'test.jpg'.
    # Relies on parse_args() and parse_function, which are not defined in
    # this block and must exist in the script before it runs.
    args = parse_args()
    config = tf.ConfigProto(allow_soft_placement=True)

    # Training dataset pipeline.
    tfrecords_f = os.path.join(args.tfrecords_file_path, 'tran.tfrecords')
    dataset = tf.data.TFRecordDataset(tfrecords_f)
    dataset = dataset.map(parse_function)
    # Shuffle using a 30000-element buffer.
    dataset = dataset.shuffle(buffer_size=30000)
    # Each iteration yields a batch of 32 images.
    dataset = dataset.batch(32)
    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()

    # The with-block guarantees the session is closed on exit.
    with tf.Session(config=config) as sess:
        for _ in range(1000):
            # Re-initialising the iterator starts a new pass over the data.
            sess.run(iterator.initializer)
            while True:
                try:
                    images, _ = sess.run(next_element)
                except tf.errors.OutOfRangeError:
                    # Leave the inner loop once the pass is exhausted;
                    # without this break every further sess.run raises again
                    # and the loop never terminates.
                    print("End of dataset")
                    break
                # parse_function scales pixels with (x - 127.5) * 0.0078125
                # and emits BGR channel order. Undo the scaling so OpenCV,
                # which expects BGR uint8, writes the colours correctly.
                sample = images[0] / 0.0078125 + 127.5
                sample = np.clip(sample, 0, 255).astype(np.uint8)
                cv2.imwrite('test.jpg', sample)
4 解析函数
def parse_function(example_proto):
    """Decode one serialized tf.train.Example into an (image, label) pair.

    The record must contain a JPEG-encoded image under 'image_raw' and an
    int64 label under 'label'.

    Args:
        example_proto: scalar tf.string tensor holding a serialized Example.

    Returns:
        img: float32 tensor of shape (112, 112, 3), channels in BGR order,
            scaled as (x - 127.5) * 0.0078125 and randomly flipped
            left/right.
        label: int64 scalar tensor.
    """
    features = {
        'image_raw': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64),
    }
    features = tf.parse_single_example(example_proto, features)
    # You can do more image distortion here for training data.
    # channels=3 forces a 3-channel decode so that grayscale JPEGs do not
    # break the (112, 112, 3) reshape below.
    img = tf.image.decode_jpeg(features['image_raw'], channels=3)
    img = tf.reshape(img, shape=(112, 112, 3))
    # decode_jpeg yields RGB; reversing the channel axis gives BGR.
    img = tf.reverse(img, axis=[-1])
    img = tf.cast(img, dtype=tf.float32)
    # Normalisation: (x - 127.5) / 128, i.e. roughly [-1, 1].
    img = tf.subtract(img, 127.5)
    img = tf.multiply(img, 0.0078125)
    # Data augmentation.
    img = tf.image.random_flip_left_right(img)
    label = tf.cast(features['label'], tf.int64)
    return img, label
5 关于数据
本文例子中用到的是CASIA-Webface的tfrecords数据。
欢迎大家关注和转发本公众号：facefinetune