利用閒暇時間,通過SSD-TensorFlow和網絡資源上總結兩種生成和讀取TFRecord的方法,代碼測試通過。
軟件平臺:PyCharm 2019.2 + Tensorflow 1.13.1 + cuda 10.0 + cudnn 7.6.5
GPU: 1080Ti
不囉嗦了,直接上代碼,有問題可發送郵件 <[email protected]>, 或直接留言討論。
TFRecord生成:
import os
import tensorflow as tf
from PIL import Image # note: PIL's Image is used below to open the jpeg files
import xml.etree.ElementTree as ET
# Dataset layout (VOC style): images and their XML annotations live in
# sibling directories under the dataset root.
JpgFilePath='H:\\11_DataSet\\QD\\JPEGImages\\'
XmlFilePath='H:\\11_DataSet\\QD\\Annotations\\'
# Dataset root path
filepath = 'H:\\11_DataSet\\QD'
# Writer for the output TFRecord file; closed after the main loop below.
writer= tf.python_io.TFRecordWriter("H:\\11_DataSet\\QD\\qd.tfrecord")
# Class-name -> (integer class id, text label) mapping, VOC_LABELS style.
VOC_LABELS = {
'B01': (0, 'B01'),
'D01': (1, 'D01'),
'G01': (2, 'G01'),
'W01': (3, 'W01'),
'W02': (4, 'W02'),
'W03': (5, 'W03'),
'W04': (6, 'W04'),
'T01': (7, 'T01'),
'R01': (8, 'R01'),
'F01': (9, 'F01'),
'S01': (10, 'S01'),
'G02': (11, 'G02'),
'G03': (12, 'G03'),
'W05': (13, 'W05'),
'I18': (14, 'I18'),
'I01': (15, 'I01'),
'I02': (16, 'I02'),
'I03': (17, 'I03'),
'I04': (18, 'I04'),
'I99': (19, 'I99'),
}
# Iterate every image in JpgFilePath, pair it with its VOC-style XML
# annotation, and write one tf.train.Example per image to the TFRecord.
for img_name in os.listdir(JpgFilePath):
    img_path = JpgFilePath + img_name  # full path of this image
    img = Image.open(img_path)
    print(img_path)
    # NOTE(review): tobytes() stores RAW decoded pixels, not a JPEG
    # bitstream, even though 'image/format' below claims JPEG.  The
    # matching reader must use tf.decode_raw, not tf.image.decode_jpeg.
    image_data = img.tobytes()
    # Matching annotation file: same basename, '.xml' suffix.
    # (splitext is safer than img_name[:-4] for non-3-char extensions.)
    xml_name = os.path.splitext(img_name)[0]
    xml_path = XmlFilePath + xml_name + '.xml'
    print(xml_path)
    tree = ET.parse(xml_path)
    root = tree.getroot()
    # Image shape as stored in the XML: [height, width, depth].
    size = root.find('size')
    shape = [int(size.find('height').text),
             int(size.find('width').text),
             int(size.find('depth').text)]
    # Collect all object annotations for this image.
    bboxes = []
    labels = []
    labels_text = []
    difficult = []
    truncated = []
    for obj in root.findall('object'):
        label = obj.find('name').text
        labels.append(int(VOC_LABELS[label][0]))
        labels_text.append(label.encode('ascii'))
        # BUG FIX: an ElementTree Element with no children is falsy, so
        # `if obj.find('difficult'):` never took the true branch even when
        # the tag exists.  Test the find() result against None explicitly.
        difficult_node = obj.find('difficult')
        difficult.append(int(difficult_node.text) if difficult_node is not None else 0)
        truncated_node = obj.find('truncated')
        truncated.append(int(truncated_node.text) if truncated_node is not None else 0)
        # Bounding box normalized to [0, 1] as (ymin, xmin, ymax, xmax).
        bbox = obj.find('bndbox')
        bboxes.append((float(bbox.find('ymin').text) / shape[0],
                       float(bbox.find('xmin').text) / shape[1],
                       float(bbox.find('ymax').text) / shape[0],
                       float(bbox.find('xmax').text) / shape[1]))
    print(labels)
    # Transpose the list of 4-tuples into four per-coordinate lists
    # (handles the no-objects case with empty lists).
    if bboxes:
        ymin, xmin, ymax, xmax = (list(coord) for coord in zip(*bboxes))
    else:
        ymin, xmin, ymax, xmax = [], [], [], []
    image_format = b'JPEG'
    # Assemble the Example proto in SSD-TensorFlow's expected key layout.
    example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': tf.train.Feature(int64_list=tf.train.Int64List(value=[shape[0]])),
        'image/width': tf.train.Feature(int64_list=tf.train.Int64List(value=[shape[1]])),
        'image/channels': tf.train.Feature(int64_list=tf.train.Int64List(value=[shape[2]])),
        'image/shape': tf.train.Feature(int64_list=tf.train.Int64List(value=shape)),
        'image/object/bbox/xmin': tf.train.Feature(float_list=tf.train.FloatList(value=xmin)),
        'image/object/bbox/xmax': tf.train.Feature(float_list=tf.train.FloatList(value=xmax)),
        'image/object/bbox/ymin': tf.train.Feature(float_list=tf.train.FloatList(value=ymin)),
        'image/object/bbox/ymax': tf.train.Feature(float_list=tf.train.FloatList(value=ymax)),
        'image/object/bbox/label': tf.train.Feature(int64_list=tf.train.Int64List(value=labels)),
        'image/object/bbox/label_text': tf.train.Feature(bytes_list=tf.train.BytesList(value=labels_text)),
        'image/object/bbox/difficult': tf.train.Feature(int64_list=tf.train.Int64List(value=difficult)),
        'image/object/bbox/truncated': tf.train.Feature(int64_list=tf.train.Int64List(value=truncated)),
        'image/format': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_format])),
        'image/encoded': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_data]))}))
    writer.write(example.SerializeToString())  # serialize the Example to bytes
writer.close()
TFRecord讀取:
import os
import tensorflow as tf
from PIL import Image

# Where decoded images are saved, and the TFRecord produced by the writer.
storepath = 'H:\\11_DataSet\\QD\\'
filename = 'H:\\11_DataSet\\QD\\qd.tfrecord'

# TF1 input pipeline: filename queue -> record reader -> parsed features.
filename_queue = tf.train.string_input_producer([filename])
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)  # (key, serialized Example)
features = tf.parse_single_example(
    serialized_example,
    features={
        'image/shape': tf.FixedLenFeature([3], tf.int64),
        'image/height': tf.FixedLenFeature([1], tf.int64),
        'image/width': tf.FixedLenFeature([1], tf.int64),
        'image/encoded': tf.FixedLenFeature((), tf.string),
    })
# The writer stored RAW pixel bytes (Image.tobytes), so decode_raw is the
# correct inverse here; tf.image.decode_jpeg would fail on these records.
image = tf.decode_raw(features['image/encoded'], tf.uint8)
shape = features['image/shape']    # [height, width, depth], already int64
height = features['image/height']  # redundant int64 casts removed
width = features['image/width']
# FIX: reshape with the per-example shape stored in the record instead of
# a hard-coded [1080, 1920, 3], so images of any size round-trip correctly.
image = tf.reshape(image, tf.cast(shape, tf.int32))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    try:
        # Pull the first 10 examples and dump them back to disk as jpegs.
        for i in range(10):
            print("loop %d\n" % i)
            example, s, h, w = sess.run([image, shape, height, width])
            print(s)
            print(h)
            print(w)
            # Image here is PIL's Image, imported above.
            img = Image.fromarray(example, 'RGB')
            img.save(storepath + str(i) + '.jpg')
    finally:
        # Always stop the queue-runner threads, even if a step raises.
        coord.request_stop()
        coord.join(threads)