TensorFlow雜記 - 生成和讀取TFRecord(一)

利用閒暇時間,參考SSD-TensorFlow的實現和網絡上的資料,總結了兩種生成和讀取TFRecord的方法,代碼均已測試通過。

軟件平臺:PyCharm 2019.2 + TensorFlow 1.13.1 + CUDA 10.0 + cuDNN 7.6.5

GPU: 1080Ti

 

不囉嗦了,直接上代碼,有問題可發送郵件 <[email protected]>, 或直接留言討論。

 

TFRecord生成:

import os
import tensorflow as tf 
from PIL import Image  # 注意Image,後面會用到
import xml.etree.ElementTree as ET

 
# Input/output locations for the TFRecord writer (Windows-style paths).
# Every path is derived from the dataset root, so only one line needs editing.
filepath = 'H:\\11_DataSet\\QD'                 # dataset root
JpgFilePath = filepath + '\\JPEGImages\\'       # directory of source images
XmlFilePath = filepath + '\\Annotations\\'      # directory of VOC XML annotations
# Output TFRecord file; every generated Example is written through this writer.
writer = tf.python_io.TFRecordWriter(filepath + '\\qd.tfrecord')

# Class name -> (integer label id, class name), used to encode each annotated
# object's <name>. The id of a class is its position in _VOC_CLASS_NAMES.
_VOC_CLASS_NAMES = (
    'B01', 'D01', 'G01', 'W01', 'W02', 'W03', 'W04', 'T01', 'R01', 'F01',
    'S01', 'G02', 'G03', 'W05', 'I18', 'I01', 'I02', 'I03', 'I04', 'I99',
)
VOC_LABELS = {name: (index, name) for index, name in enumerate(_VOC_CLASS_NAMES)}

# Convert every image in JpgFilePath, together with its same-named Pascal-VOC
# XML annotation in XmlFilePath, into one tf.train.Example written to `writer`.
# NOTE: pixels are stored as raw, uncompressed bytes (PIL `tobytes()`), not as
# an encoded JPEG, so readers must use tf.decode_raw + reshape.
try:
    for img_name in os.listdir(JpgFilePath):
        img_path = JpgFilePath + img_name  # full path of this image
        print(img_path)

        # Jpeg: decode with PIL and keep the raw pixel bytes; the `with` block
        # releases the file handle once the bytes have been copied out.
        with Image.open(img_path) as img:
            image_data = img.tobytes()

        # Xml: the annotation shares the image's base name.
        xml_name = os.path.splitext(img_name)[0]  # file name without extension
        xml_path = XmlFilePath + xml_name + '.xml'
        print(xml_path)
        root = ET.parse(xml_path).getroot()

        # Image shape as recorded in the annotation: [height, width, depth].
        size = root.find('size')
        shape = [int(size.find('height').text),
                 int(size.find('width').text),
                 int(size.find('depth').text)]

        # Per-object annotations, kept as parallel lists (one entry per object).
        labels = []
        labels_text = []
        difficult = []
        truncated = []
        ymin = []
        xmin = []
        ymax = []
        xmax = []
        for obj in root.findall('object'):
            label = obj.find('name').text
            labels.append(int(VOC_LABELS[label][0]))
            labels_text.append(label.encode('ascii'))

            # Compare with None explicitly: an Element's truth value depends on
            # its number of children, so a leaf such as <difficult>1</difficult>
            # is falsy and `if obj.find(...)` would always record 0.
            difficult_node = obj.find('difficult')
            difficult.append(int(difficult_node.text) if difficult_node is not None else 0)
            truncated_node = obj.find('truncated')
            truncated.append(int(truncated_node.text) if truncated_node is not None else 0)

            # Box corners divided by the annotated height/width (normalized coords).
            bbox = obj.find('bndbox')
            ymin.append(float(bbox.find('ymin').text) / shape[0])
            xmin.append(float(bbox.find('xmin').text) / shape[1])
            ymax.append(float(bbox.find('ymax').text) / shape[0])
            xmax.append(float(bbox.find('xmax').text) / shape[1])
        print(labels)

        # image/encoded holds raw pixel bytes, so tag it 'RAW' instead of 'JPEG'
        # to keep the format field consistent with the payload.
        image_format = b'RAW'
        example = tf.train.Example(features=tf.train.Features(feature={
                'image/height': tf.train.Feature(int64_list=tf.train.Int64List(value=[shape[0]])),
                'image/width': tf.train.Feature(int64_list=tf.train.Int64List(value=[shape[1]])),
                'image/channels': tf.train.Feature(int64_list=tf.train.Int64List(value=[shape[2]])),
                'image/shape': tf.train.Feature(int64_list=tf.train.Int64List(value=shape)),
                'image/object/bbox/xmin': tf.train.Feature(float_list=tf.train.FloatList(value=xmin)),
                'image/object/bbox/xmax': tf.train.Feature(float_list=tf.train.FloatList(value=xmax)),
                'image/object/bbox/ymin': tf.train.Feature(float_list=tf.train.FloatList(value=ymin)),
                'image/object/bbox/ymax': tf.train.Feature(float_list=tf.train.FloatList(value=ymax)),
                'image/object/bbox/label': tf.train.Feature(int64_list=tf.train.Int64List(value=labels)),
                'image/object/bbox/label_text': tf.train.Feature(bytes_list=tf.train.BytesList(value=labels_text)),
                'image/object/bbox/difficult': tf.train.Feature(int64_list=tf.train.Int64List(value=difficult)),
                'image/object/bbox/truncated': tf.train.Feature(int64_list=tf.train.Int64List(value=truncated)),
                'image/format': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_format])),
                'image/encoded': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_data]))}))

        writer.write(example.SerializeToString())  # serialize to a byte string
finally:
    # Flush and close the output file even if one image/annotation fails.
    writer.close()

TFRecord讀取:

import os 
import tensorflow as tf 
from PIL import Image  

# Directory where decoded sample images are saved.
storepath = 'H:\\11_DataSet\\QD\\'

# TFRecord file produced by the writer script.
filename = 'H:\\11_DataSet\\QD\\qd.tfrecord'

# Queue of input file names feeding the record reader (TF1 queue-based input).
filename_queue = tf.train.string_input_producer([filename])

reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)  # (record key, serialized Example)
# Parse only the fields needed here: image shape, height, width and the pixels.
features = tf.parse_single_example(serialized_example,
                                   features={
                                       'image/shape': tf.FixedLenFeature([3], tf.int64),
                                       'image/height': tf.FixedLenFeature([1], tf.int64),
                                       'image/width': tf.FixedLenFeature([1], tf.int64),
                                       'image/encoded': tf.FixedLenFeature((), tf.string),
                                   })

# image/encoded is decoded as raw uint8 bytes; reshape them with the record's
# own [height, width, depth] instead of a hard-coded 1080x1920x3, so images of
# any size can be read.
image = tf.decode_raw(features['image/encoded'], tf.uint8)
image = tf.reshape(image, features['image/shape'])
# The features are already int64 (see the FixedLenFeature specs), so no cast.
shape = features['image/shape']
height = features['image/height']
width = features['image/width']

# Pull 10 records through the queue pipeline, print their stored shape fields,
# and save each decoded image as <storepath><i>.jpg.
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    try:
        for i in range(10):
            print("loop %d\n" % i)
            # Fetch the decoded image and its stored shape/height/width.
            example, s, h, w = sess.run([image, shape, height, width])
            print(s)
            print(h)
            print(w)

            # 'RGB' assumes the stored depth is 3 — matches the 3-channel data
            # this script was written for; other depths would need another mode.
            img = Image.fromarray(example, 'RGB')
            img.save(storepath + str(i) + '.jpg')  # save the decoded image
    finally:
        # Always stop and join the queue-runner threads, even if a read or save
        # raises, so they are not left running when the session closes.
        coord.request_stop()
        coord.join(threads)


 

發佈了33 篇原創文章 · 獲贊 8 · 訪問量 1萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章