Converting MNIST data to raw images, then to TFRecord format

1. Converting the MNIST data to raw images

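import os

import cv2
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data


def convert_mnist_img(data, save_path):
    # cv2.imwrite returns False instead of raising when the directory is missing,
    # so create the output directory first.
    os.makedirs(save_path, exist_ok=True)
    for i in range(data.images.shape[0]):
        # Images are flat float vectors in [0, 1]; restore 28x28x1 uint8 pixels.
        img = data.images[i].reshape([28, 28, 1])
        img = (img * 255).astype(np.uint8)
        label = data.labels[i]
        # cv2.imshow('image', img)
        # cv2.waitKey(500)
        # The file name encodes the label and the sample index: <label>_<index>.jpg
        # Note: JPEG is lossy, so the saved pixels can differ slightly from the originals.
        filename = os.path.join(save_path, '{}_{}.jpg'.format(label, i))
        cv2.imwrite(filename, img)

if __name__ == '__main__':
    mnist = input_data.read_data_sets('./data', source_url='http://yann.lecun.com/exdb/mnist/')
    convert_mnist_img(mnist.train, 'img_train')
    print('convert training data to image complete')
    convert_mnist_img(mnist.test, 'img_test')
    print('convert test data to image complete')
    convert_mnist_img(mnist.validation, 'img_validation')
    print('convert validation data to image complete')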

This saves the images of the training, validation, and test sets into three separate directories.
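As a quick sanity check on the saved files, the sketch below counts images per digit by parsing the `<label>_<index>.jpg` file names that `convert_mnist_img` produces. The helper name `count_images_per_label` is hypothetical and not part of the original post; it only assumes the three directory names used above.

```python
import os
from collections import Counter


def count_images_per_label(image_dir):
    """Count saved images per digit, relying on the '<label>_<index>.jpg' naming."""
    counts = Counter()
    for name in os.listdir(image_dir):
        stem, ext = os.path.splitext(name)
        if ext.lower() != '.jpg':
            continue  # ignore anything that is not one of the saved images
        label = int(stem.split('_')[0])  # the part before the underscore is the label
        counts[label] += 1
    return counts


if __name__ == '__main__':
    for split in ('img_train', 'img_test', 'img_validation'):
        if not os.path.isdir(split):
            continue  # directory not generated yet
        counts = count_images_per_label(split)
        print(split, sum(counts.values()), dict(sorted(counts.items())))
```

If one of the totals does not match the size of the corresponding split, the conversion was probably interrupted or pointed at the wrong directory.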

2. Converting the image data to TFRecord files

import os

import cv2
import tensorflow as tf


def convert_img_tfrecords(data_path, record_dir):
    # Despite its name, record_dir is the path of the output .tfrecords file.
    writer = tf.python_io.TFRecordWriter(record_dir)
    for file in os.listdir(data_path):
        img = cv2.imread(os.path.join(data_path, file), cv2.IMREAD_GRAYSCALE)
        if img is None:
            continue  # cv2.imread returns None for files it cannot decode
        # Store the raw uint8 pixels; the image shape is not recorded.
        img_raw = img.tobytes()
        # The label is the part of the file name before the underscore.
        label = int(file.split('_')[0])
        example = tf.train.Example(features=tf.train.Features(feature={
            'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
            'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
        }))
        writer.write(example.SerializeToString())
    writer.close()

if __name__ == '__main__':
    convert_img_tfrecords('./img_validation', 'validation_img.tfrecords')
    print('convert validation image to tfrecords complete')
    convert_img_tfrecords('./img_test', 'test_img.tfrecords')
    print('convert test image to tfrecords complete')
    convert_img_tfrecords('./img_train', 'train_img.tfrecords')
    print('convert train image to tfrecords complete')

This produces three TFRecord files, one each for the training, validation, and test sets.
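To confirm that each file holds as many records as there were images, the records can be counted without TensorFlow. The sketch below assumes the documented TFRecord framing (an 8-byte little-endian length, a 4-byte CRC of the length, the payload, and a 4-byte CRC of the payload) and only walks that framing; it does not verify the checksums. The function name `count_tfrecords` is hypothetical.

```python
import os
import struct


def count_tfrecords(path):
    """Count records in a TFRecord file by walking its framing (CRCs are not checked)."""
    count = 0
    with open(path, 'rb') as f:
        while True:
            header = f.read(8)  # uint64, little-endian payload length
            if not header:
                break  # clean end of file
            if len(header) < 8:
                raise ValueError('truncated length header')
            (length,) = struct.unpack('<Q', header)
            length_crc = f.read(4)   # CRC of the length, skipped
            payload = f.read(length)  # serialized tf.train.Example
            payload_crc = f.read(4)  # CRC of the payload, skipped
            if len(length_crc) < 4 or len(payload) < length or len(payload_crc) < 4:
                raise ValueError('truncated record')
            count += 1
    return count


if __name__ == '__main__':
    for name in ('train_img.tfrecords', 'validation_img.tfrecords', 'test_img.tfrecords'):
        if os.path.exists(name):
            print(name, count_tfrecords(name))
```

Each count should equal the number of readable images in the matching directory, since the writer emits one record per image.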

3. Parsing the TFRecord files

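import cv2
import numpy as np
import tensorflow as tf


def read_record(record_dir):
    # record_dir is the path of a .tfrecords file written in step 2.
    for serialized_exam in tf.python_io.tf_record_iterator(record_dir):
        example = tf.train.Example()
        example.ParseFromString(serialized_exam)

        image = example.features.feature['img_raw'].bytes_list.value[0]
        label = example.features.feature['label'].int64_list.value[0]
        # np.fromstring is deprecated for binary data; np.frombuffer is the replacement.
        image = np.frombuffer(image, dtype=np.uint8)
        # The shape is not stored in the record, so restore it explicitly.
        image = image.reshape([28, 28, 1])

        cv2.imshow('image', image)
        cv2.waitKey(1000)

        print(image.shape, label)
    cv2.destroyAllWindows()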

Parsing a TFRecord file this way lets you check that it was written correctly.
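The check works because the image is stored as a flat run of bytes and rebuilt with a known dtype and shape. A minimal NumPy-only sketch of that round trip follows; `encode_image` and `decode_image` are hypothetical helper names, and the 28x28x1 shape is the one assumed throughout this post.

```python
import numpy as np


def encode_image(img):
    """Flatten an image to raw uint8 bytes, as stored in the 'img_raw' feature."""
    return img.astype(np.uint8).tobytes()


def decode_image(raw, shape=(28, 28, 1)):
    """Rebuild the array from raw bytes; the shape must be known because it is not stored."""
    return np.frombuffer(raw, dtype=np.uint8).reshape(shape)


if __name__ == '__main__':
    # Synthetic 28x28 test pattern standing in for an MNIST digit.
    original = (np.arange(28 * 28) % 256).astype(np.uint8).reshape(28, 28)
    restored = decode_image(encode_image(original))
    print(restored.shape, np.array_equal(restored[..., 0], original))
```

Note that `np.frombuffer` returns a read-only view of the bytes; call `.copy()` on the result if the image needs to be modified.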

For actual training, this can be combined with tf.train.string_input_producer and tf.train.Coordinator() so that a queue produces batches of data for the training loop.
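A minimal sketch of such a queue-based pipeline is shown below. It is one possible arrangement, not necessarily the author's: it assumes the feature names `label` and `img_raw` and the 28x28x1 shape from the steps above, and it imports `tensorflow.compat.v1` so that the TF 1.x queue API is also available on TensorFlow 2.x. The helper names `read_and_decode` and `fetch_batches` and the default batch size and queue sizes are hypothetical.

```python
import os

# Assumption: TensorFlow 1.14+ or 2.x. On older 1.x releases use
# `import tensorflow as tf` and drop the disable_eager_execution() call.
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # queue runners only work in graph mode


def read_and_decode(filename_queue):
    """Read one serialized Example from the file queue and decode it."""
    reader = tf.TFRecordReader()
    _, serialized = reader.read(filename_queue)
    features = tf.parse_single_example(serialized, features={
        'label': tf.FixedLenFeature([], tf.int64),
        'img_raw': tf.FixedLenFeature([], tf.string),
    })
    image = tf.decode_raw(features['img_raw'], tf.uint8)
    image = tf.reshape(image, [28, 28, 1])      # shape is not stored in the file
    image = tf.cast(image, tf.float32) / 255.0  # scale pixels back to [0, 1]
    label = tf.cast(features['label'], tf.int32)
    return image, label


def fetch_batches(record_path, steps=3, batch_size=64, min_after_dequeue=1000):
    """Build the queue pipeline and pull `steps` shuffled batches from it."""
    # With num_epochs left at None the file is cycled indefinitely.
    filename_queue = tf.train.string_input_producer([record_path])
    image, label = read_and_decode(filename_queue)
    images, labels = tf.train.shuffle_batch(
        [image, label], batch_size=batch_size,
        capacity=min_after_dequeue + 3 * batch_size,
        min_after_dequeue=min_after_dequeue)

    batches = []
    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        # Start the threads that fill the filename and example queues.
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for _ in range(steps):
                batches.append(sess.run([images, labels]))
        finally:
            # Ask the queue threads to stop and wait for them to finish.
            coord.request_stop()
            coord.join(threads)
    return batches


if __name__ == '__main__':
    path = 'train_img.tfrecords'  # file written in step 2
    if os.path.exists(path):
        for image_batch, label_batch in fetch_batches(path):
            print(image_batch.shape, label_batch[:10])
```

In a real training loop the `sess.run` call would also execute the training op, with `images` and `labels` feeding the model directly rather than being collected into a Python list.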

 
