TensorFlow雜記 - 生成和讀取TFRecord(二)

與上一篇不同的TFRecord生成和讀取方法,抽取自SSD-TensorFlow,並做了一定的修改。

使用Slim生成和讀取TFRecord。

軟件平臺:PyCharm 2019.2 + TensorFlow 1.13.1 + CUDA 10.0 + cuDNN 7.6.5

GPU: 1080Ti

 

上代碼。

dataset_common.py

 

"""Provides data for the Pascal VOC Dataset (images + annotations).
"""
import os

import tensorflow as tf
import dataset_utils

slim = tf.contrib.slim

# Mapping from the class code found in each annotation's <object><name> tag
# to a (numeric label, human-readable name) pair. The numeric label is the
# position of the code in the tuple below, so the order of that tuple defines
# the integer ids written into the TFRecords.
VOC_LABELS = {
    code: (index, code)
    for index, code in enumerate((
        'B01', 'D01', 'G01', 'W01', 'W02', 'W03', 'W04', 'T01', 'R01', 'F01',
        'S01', 'G02', 'G03', 'W05', 'I18', 'I01', 'I02', 'I03', 'I04', 'I99',
    ))
}


def get_split(split_name, dataset_dir, file_pattern, reader,
              split_to_sizes, items_to_descriptions, num_classes):
    """Build a slim ``Dataset`` describing one split of the VOC-style TFRecords.

    Args:
      split_name: Name of the split to load; must be a key of
        ``split_to_sizes``.
      dataset_dir: Directory holding the TFRecord files and, optionally, a
        label-map file read through ``dataset_utils``.
      file_pattern: File-name pattern relative to ``dataset_dir``. It is
        %-formatted with ``split_name``, so it should contain one ``%s``.
      reader: TensorFlow reader class; ``None`` selects ``tf.TFRecordReader``.
      split_to_sizes: Mapping from split name to its number of samples.
      items_to_descriptions: Item descriptions forwarded to the ``Dataset``.
      num_classes: Number of classes forwarded to the ``Dataset``.

    Returns:
      A ``slim.dataset.Dataset`` whose decoder yields the items ``image``,
      ``shape``, ``object/bbox``, ``object/label``, ``object/difficult`` and
      ``object/truncated``.

    Raises:
      ValueError: if ``split_name`` is not a key of ``split_to_sizes``.
    """
    if split_name not in split_to_sizes:
        raise ValueError('split name %s was not recognized.' % split_name)

    sources = os.path.join(dataset_dir, file_pattern % split_name)
    # None lets callers (e.g. a dataset factory) fall back to the default reader.
    reader_class = tf.TFRecordReader if reader is None else reader

    # How each serialized feature is parsed out of a tf.train.Example.
    parse_spec = {
        'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'),
        'image/height': tf.FixedLenFeature([1], tf.int64),
        'image/width': tf.FixedLenFeature([1], tf.int64),
        'image/channels': tf.FixedLenFeature([1], tf.int64),
        'image/shape': tf.FixedLenFeature([3], tf.int64),
        'image/object/bbox/xmin': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/ymin': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/xmax': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/ymax': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/label': tf.VarLenFeature(dtype=tf.int64),
        'image/object/bbox/difficult': tf.VarLenFeature(dtype=tf.int64),
        'image/object/bbox/truncated': tf.VarLenFeature(dtype=tf.int64),
    }
    # How the parsed features are assembled into the items the Dataset exposes.
    handlers = {
        'image': slim.tfexample_decoder.Image('image/encoded', 'image/format'),
        'shape': slim.tfexample_decoder.Tensor('image/shape'),
        'object/bbox': slim.tfexample_decoder.BoundingBox(
            ['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/'),
        'object/label': slim.tfexample_decoder.Tensor('image/object/bbox/label'),
        'object/difficult': slim.tfexample_decoder.Tensor('image/object/bbox/difficult'),
        'object/truncated': slim.tfexample_decoder.Tensor('image/object/bbox/truncated'),
    }
    example_decoder = slim.tfexample_decoder.TFExampleDecoder(parse_spec, handlers)

    # The label-map file is optional; without it the Dataset has no label names.
    names = (dataset_utils.read_label_file(dataset_dir)
             if dataset_utils.has_labels(dataset_dir) else None)

    return slim.dataset.Dataset(
        data_sources=sources,
        reader=reader_class,
        decoder=example_decoder,
        num_samples=split_to_sizes[split_name],
        items_to_descriptions=items_to_descriptions,
        num_classes=num_classes,
        labels_to_names=names)

 

dataset_utils.py

"""Contains utilities for downloading and converting datasets."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import tarfile # 管理tar壓縮包

from six.moves import urllib # six是用來兼容python 2 和 3的,我猜名字就是用的2和3的最小公倍數。
                             # six.moves 是用來處理那些在2 和 3裏面函數的位置有變化的,直接用six.moves就可以屏蔽掉這些變化
import tensorflow as tf

# Default name of the label-map file used by write_label_file, has_labels and
# read_label_file inside the dataset directory.
LABELS_FILENAME = 'labels.txt'


def int64_feature(value):
    """Wrap one int, or a list/tuple of ints, as an int64 ``tf.train.Feature``.

    Args:
      value: A single integer, or a list or tuple of integers. Tuples are
        treated as sequences (not wrapped as a single element).

    Returns:
      A ``tf.train.Feature`` whose ``int64_list`` holds the value(s).
    """
    if not isinstance(value, (list, tuple)):
        value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


def float_feature(value):
    """Wrap one float, or a list/tuple of floats, as a float ``tf.train.Feature``.

    Args:
      value: A single number, or a list or tuple of numbers. Tuples are
        treated as sequences (not wrapped as a single element).

    Returns:
      A ``tf.train.Feature`` whose ``float_list`` holds the value(s).
    """
    if not isinstance(value, (list, tuple)):
        value = [value]
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))


def bytes_feature(value):
    """Wrap one bytes value, or a list/tuple of them, as a bytes ``tf.train.Feature``.

    Args:
      value: A single bytes object, or a list or tuple of bytes objects.
        Tuples are treated as sequences (not wrapped as a single element).

    Returns:
      A ``tf.train.Feature`` whose ``bytes_list`` holds the value(s).
    """
    if not isinstance(value, (list, tuple)):
        value = [value]
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))


def image_to_tfexample(image_data, image_format, height, width, class_id):
    """Pack an encoded image and its class label into a ``tf.train.Example``.

    Args:
      image_data: Encoded image bytes.
      image_format: Encoding name; passed to ``bytes_feature``, so it should
        be a bytes value (e.g. b'jpeg').
      height: Image height in pixels.
      width: Image width in pixels.
      class_id: Integer class label.

    Returns:
      A ``tf.train.Example`` with the features ``image/encoded``,
      ``image/format``, ``image/class/label``, ``image/height`` and
      ``image/width``.
    """
    feature_map = {
        'image/encoded': bytes_feature(image_data),
        'image/format': bytes_feature(image_format),
        'image/class/label': int64_feature(class_id),
        'image/height': int64_feature(height),
        'image/width': int64_feature(width),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature_map))


def download_and_uncompress_tarball(tarball_url, dataset_dir):
    """Download ``tarball_url`` into ``dataset_dir`` and extract it there.

    The archive is saved as ``dataset_dir/<last URL path component>`` and is
    then opened as a gzip-compressed tar and extracted into ``dataset_dir``.

    Args:
      tarball_url: URL of a ``.tar.gz`` archive.
      dataset_dir: Existing directory that receives both the downloaded
        archive and its extracted contents.
    """
    filename = tarball_url.split('/')[-1]
    filepath = os.path.join(dataset_dir, filename)

    def _progress(count, block_size, total_size):
        # urlretrieve reports a non-positive total_size (-1) when the size is
        # unknown; do not divide by it in that case.
        if total_size > 0:
            # The last block may overshoot total_size; cap the display at 100%.
            percent = min(100.0, float(count * block_size) / float(total_size) * 100.0)
            sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename, percent))
        else:
            sys.stdout.write('\r>> Downloading %s %d bytes' % (filename, count * block_size))
        sys.stdout.flush()

    filepath, _ = urllib.request.urlretrieve(tarball_url, filepath, _progress)
    print()
    statinfo = os.stat(filepath)
    print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
    # Use a context manager so the archive handle is always closed.
    # NOTE(review): extractall trusts member paths; only use this with
    # trusted archives (a crafted tarball can write outside dataset_dir).
    with tarfile.open(filepath, 'r:gz') as tar:
        tar.extractall(dataset_dir)


def write_label_file(labels_to_class_names, dataset_dir,
                     filename=LABELS_FILENAME):
    """Write the label map to ``dataset_dir/filename``, one ``id:name`` per line.

    Args:
      labels_to_class_names: Mapping from integer label to class name.
      dataset_dir: Directory that receives the labels file.
      filename: Name of the labels file inside ``dataset_dir``.
    """
    target = os.path.join(dataset_dir, filename)
    with tf.gfile.Open(target, 'w') as out:
        for label, class_name in labels_to_class_names.items():
            out.write('%d:%s\n' % (label, class_name))


def has_labels(dataset_dir, filename=LABELS_FILENAME):
    """Return whether ``dataset_dir`` contains the label-map file.

    Args:
      dataset_dir: Directory to look in.
      filename: Name of the labels file inside ``dataset_dir``.

    Returns:
      ``True`` if ``dataset_dir/filename`` exists, ``False`` otherwise.
    """
    label_path = os.path.join(dataset_dir, filename)
    return tf.gfile.Exists(label_path)


def read_label_file(dataset_dir, filename=LABELS_FILENAME):
    """Read ``dataset_dir/filename`` and return a map from label id to class name.

    Each non-empty line must have the form ``<int>:<name>``; the name is
    everything after the first colon. Both LF and CRLF line endings are
    accepted, so a labels file edited on Windows does not leave a trailing
    carriage return on every class name.

    Args:
      dataset_dir: Directory containing the labels file.
      filename: Name of the labels file inside ``dataset_dir``.

    Returns:
      A dict mapping integer label to class name. Names are ``bytes`` because
      the file is read in binary mode.

    Raises:
      ValueError: if a non-empty line has no ``:`` or a non-integer id.
    """
    labels_filename = os.path.join(dataset_dir, filename)
    with tf.gfile.Open(labels_filename, 'rb') as f:
        contents = f.read()

    labels_to_class_names = {}
    # splitlines() understands '\n', '\r\n' and '\r'; blank lines are skipped.
    for line in contents.splitlines():
        if not line:
            continue
        index = line.index(b':')
        labels_to_class_names[int(line[:index])] = line[index + 1:]
    return labels_to_class_names

 

transfer_to_tfrecords.py

"""Converts data to TFRecords file format with Example protos.

The raw Pascal VOC data set is expected to reside in JPEG files located in the
directory 'JPEGImages'. Similarly, bounding box annotations are supposed to be
stored in the directory 'Annotations'.

This script converts the annotated images into a sharded data set: each
TFRecord file holds at most SAMPLES_PER_FILES (100) records and is named
'<name>_<NNN>.tfrecord'. Each record within a TFRecord file is a serialized
Example proto. The Example proto contains the following fields:

    image/encoded: string containing the JPEG file bytes (not re-encoded)
    image/height: integer, image height in pixels
    image/width: integer, image width in pixels
    image/channels: integer, number of channels (the XML <depth> value)
    image/shape: list of 3 integers, [height, width, channels]
    image/format: string, specifying the format, always 'jpeg'


    image/object/bbox/xmin: list of float specifying the 0+ human annotated
        bounding boxes (divided by the image width)
    image/object/bbox/xmax: list of float specifying the 0+ human annotated
        bounding boxes (divided by the image width)
    image/object/bbox/ymin: list of float specifying the 0+ human annotated
        bounding boxes (divided by the image height)
    image/object/bbox/ymax: list of float specifying the 0+ human annotated
        bounding boxes (divided by the image height)
    image/object/bbox/label: list of integer specifying the classification index.
    image/object/bbox/label_text: list of string descriptions.
    image/object/bbox/difficult: list of integer 'difficult' flags.
    image/object/bbox/truncated: list of integer 'truncated' flags.

"""

# 1. 按照VOC格式存儲圖片數據和標註數據;
# 2. 在run函數中指定輸入、輸出路徑和輸出tfrecord數據庫的名字;


import os
import sys
import random

import numpy as np
import tensorflow as tf

import xml.etree.ElementTree as ET

from PIL import Image  #注意Image,後面會用到

from dataset_utils import int64_feature, float_feature, bytes_feature
from dataset_common import VOC_LABELS

# Original dataset organisation: sub-directories of the VOC-style dataset
# root, each with a trailing separator. Built with os.sep instead of a
# hard-coded backslash so the paths also work on POSIX; on Windows the values
# are unchanged ('Annotations\\' and 'JPEGImages\\').
DIRECTORY_ANNOTATIONS = 'Annotations' + os.sep
DIRECTORY_IMAGES = 'JPEGImages' + os.sep

# TFRecords conversion parameters.
RANDOM_SEED = 4242       # seed for the optional shuffle in run()
SAMPLES_PER_FILES = 100  # maximum number of examples per TFRecord shard


"""
獲取一張圖片數據及其對應的尺寸,標註信息等
"""
def _process_image(directory, name):
    """Read one image file and parse its Pascal-VOC-style XML annotation.

    Args:
      directory: Dataset root that contains the ``DIRECTORY_IMAGES`` and
        ``DIRECTORY_ANNOTATIONS`` sub-directories.
      name: Image base name without extension. The image is read from
        ``DIRECTORY_IMAGES/<name>.jpg`` and the annotation from
        ``DIRECTORY_ANNOTATIONS/<name>.xml``, both under ``directory``.

    Returns:
      A tuple ``(image_data, shape, bboxes, labels, labels_text, difficult,
      truncated)`` where
        image_data: raw (still encoded) bytes of the JPEG file;
        shape: ``[height, width, depth]`` read from the XML ``<size>``;
        bboxes: list of ``(ymin, xmin, ymax, xmax)`` float tuples, each
          coordinate divided by the image height or width;
        labels: list of integer class ids looked up in ``VOC_LABELS``;
        labels_text: list of class codes encoded as ASCII bytes;
        difficult, truncated: lists of ints read from the optional
          ``<difficult>`` / ``<truncated>`` tags (0 when a tag is missing
          or empty).

    Raises:
      KeyError: if an object's ``<name>`` is not a key of ``VOC_LABELS``.
    """
    def _int_tag(element, tag):
        # Compare against None: an Element's truth value is its number of
        # children, so a leaf such as <difficult>1</difficult> is falsy and a
        # plain `if element.find(tag):` test would always fall back to 0.
        child = element.find(tag)
        if child is None or not (child.text or '').strip():
            return 0
        return int(child.text)

    filename = os.path.join(directory, DIRECTORY_IMAGES, name + '.jpg')
    # Keep the file's encoded bytes; decoding happens when records are read.
    with tf.gfile.GFile(filename, 'rb') as image_file:
        image_data = image_file.read()

    xmlfilename = os.path.join(directory, DIRECTORY_ANNOTATIONS, name + '.xml')
    print(xmlfilename)
    root = ET.parse(xmlfilename).getroot()

    # shape = [height, width, depth]
    size = root.find('size')
    shape = [int(size.find('height').text),
             int(size.find('width').text),
             int(size.find('depth').text)]

    bboxes = []
    labels = []
    labels_text = []
    difficult = []
    truncated = []
    for obj in root.findall('object'):
        label = obj.find('name').text
        labels.append(int(VOC_LABELS[label][0]))
        labels_text.append(label.encode('ascii'))
        difficult.append(_int_tag(obj, 'difficult'))
        truncated.append(_int_tag(obj, 'truncated'))

        # Box corners divided by the image height/width (normalised coords).
        bbox = obj.find('bndbox')
        bboxes.append((float(bbox.find('ymin').text) / shape[0],
                       float(bbox.find('xmin').text) / shape[1],
                       float(bbox.find('ymax').text) / shape[0],
                       float(bbox.find('xmax').text) / shape[1]))
    return image_data, shape, bboxes, labels, labels_text, difficult, truncated


"""
    將一張圖片的相關信息轉換爲example
"""
def _convert_to_example(image_data, labels, labels_text, bboxes, shape,
                        difficult, truncated):
    """Convert one image and its annotations into a ``tf.train.Example``.

    Args:
      image_data: Encoded JPEG bytes.
      labels: List of integer class ids, one per object.
      labels_text: List of class names as bytes, one per object.
      bboxes: List of boxes, each a 4-sequence ``(ymin, xmin, ymax, xmax)``
        (``_process_image`` produces them divided by the image size).
      shape: ``[height, width, channels]`` of the image.
      difficult: List of int flags, one per object.
      truncated: List of int flags, one per object.

    Returns:
      A ``tf.train.Example`` holding the image bytes, its shape and the
      per-object box coordinates, labels and flags.

    Raises:
      ValueError: if a box does not have exactly 4 coordinates.
    """
    # Split the per-object boxes into one list per coordinate.
    ymin, xmin, ymax, xmax = [], [], [], []
    for box in bboxes:
        if len(box) != 4:
            raise ValueError('Expected 4 box coordinates, got %d' % len(box))
        y0, x0, y1, x1 = box
        ymin.append(y0)
        xmin.append(x0)
        ymax.append(y1)
        xmax.append(x1)

    # Same value as the reader-side default for 'image/format'.
    image_format = b'jpeg'
    example = tf.train.Example(features=tf.train.Features(feature={
            'image/height': int64_feature(shape[0]),
            'image/width': int64_feature(shape[1]),
            'image/channels': int64_feature(shape[2]),
            'image/shape': int64_feature(shape),
            'image/object/bbox/xmin': float_feature(xmin),
            'image/object/bbox/xmax': float_feature(xmax),
            'image/object/bbox/ymin': float_feature(ymin),
            'image/object/bbox/ymax': float_feature(ymax),
            'image/object/bbox/label': int64_feature(labels),
            'image/object/bbox/label_text': bytes_feature(labels_text),
            'image/object/bbox/difficult': int64_feature(difficult),
            'image/object/bbox/truncated': int64_feature(truncated),
            'image/format': bytes_feature(image_format),
            'image/encoded': bytes_feature(image_data)}))
    return example


"""
    將一張圖片對應的example保存到tfrecord
    一個example是
"""
def _add_to_tfrecord(dataset_dir, name, tfrecord_writer):
    """Serialize one image plus its annotation and append it to a TFRecord.

    Args:
      dataset_dir: Dataset root directory.
      name: Image base name (no extension).
      tfrecord_writer: Open TFRecord writer that receives the record.
    """
    (image_data, shape, bboxes, labels,
     labels_text, difficult, truncated) = _process_image(dataset_dir, name)
    serialized = _convert_to_example(
        image_data, labels, labels_text, bboxes, shape, difficult, truncated
    ).SerializeToString()
    tfrecord_writer.write(serialized)


def _get_output_filename(output_dir, name, idx):
    return '%s%s_%03d.tfrecord' % (output_dir, name, idx)


def run(dataset_dir, output_dir, name='train', shuffling=False):
    """Convert a VOC-style dataset into sharded TFRecord files.

    Every ``.xml`` file in ``dataset_dir/DIRECTORY_ANNOTATIONS`` is paired
    with its JPEG image and written as one ``tf.train.Example``. At most
    ``SAMPLES_PER_FILES`` examples go into each shard, named
    ``<output_dir>/<name>_<NNN>.tfrecord``.

    Args:
      dataset_dir: Root of the dataset (contains the annotation and image
        sub-directories).
      output_dir: Directory for the TFRecord shards; created if missing.
      name: Prefix of the shard file names (e.g. the split name).
      shuffling: If True, shuffle the sample order with the fixed
        ``RANDOM_SEED`` so the result is reproducible.
    """
    # The TFRecord writers below need the output directory to exist.
    if not tf.gfile.Exists(output_dir):
        tf.gfile.MakeDirs(output_dir)

    path = os.path.join(dataset_dir, DIRECTORY_ANNOTATIONS)  # XML directory
    print("xml path: ", path)
    # Keep only annotation files (skips e.g. Thumbs.db / desktop.ini).
    filenames = sorted(f for f in os.listdir(path)
                       if os.path.splitext(f)[1].lower() == '.xml')
    if shuffling:
        random.seed(RANDOM_SEED)
        random.shuffle(filenames)

    total = len(filenames)
    # One shard per SAMPLES_PER_FILES consecutive samples.
    for fidx, start in enumerate(range(0, total, SAMPLES_PER_FILES)):
        tf_filename = _get_output_filename(output_dir, name, fidx)
        print("tf_filename: ", tf_filename)
        with tf.python_io.TFRecordWriter(tf_filename) as tfrecord_writer:
            for i in range(start, min(start + SAMPLES_PER_FILES, total)):
                sys.stdout.write('\r>> Converting image %d/%d \n' % (i + 1, total))
                sys.stdout.flush()
                # Base name without extension, whatever the extension length.
                img_name = os.path.splitext(filenames[i])[0]
                _add_to_tfrecord(dataset_dir, img_name, tfrecord_writer)

    print('\nFinished converting the Pascal VOC dataset!')



if __name__ == "__main__":
    # Script entry point: the dataset root doubles as the TFRecord output dir.
    _dataset_root = "H:\\11_DataSet\\QD\\"
    run(_dataset_root, _dataset_root)

 

read_from_tfrecords.py

# coding=utf-8
import tensorflow as tf
from PIL import Image

slim = tf.contrib.slim

# Path of the TFRecord file to read (hard-coded Windows path; adjust locally).
tfrecords_filename = 'H:\\11_DataSet\\QD\\train_000.tfrecord'


"""
    使用Slim的方法從TFrecord文件中讀取
"""
def read_record_file():
    """Decode the first 10 examples of ``tfrecords_filename`` and save them as JPEGs.

    Uses TF-Slim's ``TFExampleDecoder`` and ``DatasetDataProvider`` (fed by
    queue runners) to read the records. For each example it prints the
    image array and its size, then writes the image to ``./<i>.jpg``.
    """
    keys_to_features = {  # how each serialized feature is parsed
        'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'),
        'image/height': tf.FixedLenFeature([], tf.int64),
        'image/width': tf.FixedLenFeature([], tf.int64),
        'image/channels': tf.FixedLenFeature([], tf.int64),
        'image/shape': tf.FixedLenFeature([3], tf.int64),
        'image/object/bbox/xmin': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/ymin': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/xmax': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/ymax': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/label': tf.VarLenFeature(dtype=tf.int64),
        'image/object/bbox/difficult': tf.VarLenFeature(dtype=tf.int64),
        'image/object/bbox/truncated': tf.VarLenFeature(dtype=tf.int64),
    }
    items_to_handlers = {  # how parsed features become decoded items
        'image': slim.tfexample_decoder.Image('image/encoded', 'image/format', channels=3),
        'height': slim.tfexample_decoder.Tensor('image/height'),
        'width': slim.tfexample_decoder.Tensor('image/width'),
        'shape': slim.tfexample_decoder.Tensor('image/shape'),
        'object/bbox': slim.tfexample_decoder.BoundingBox(
                ['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/'),
        'object/label': slim.tfexample_decoder.Tensor('image/object/bbox/label'),
        'object/difficult': slim.tfexample_decoder.Tensor('image/object/bbox/difficult'),
        'object/truncated': slim.tfexample_decoder.Tensor('image/object/bbox/truncated'),
    }

    # Decoder: turns a serialized Example into the items listed above.
    decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)
    # Dataset: metadata only (file location, reader type, decoder, sizes).
    dataset = slim.dataset.Dataset(
        data_sources=tfrecords_filename,
        reader=tf.TFRecordReader,
        decoder=decoder,
        num_samples=100,  # total number of examples in the data source
        items_to_descriptions=None,
        num_classes=20,
    )
    # Provider: builds the input queues that actually read and decode records.
    provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset,
        num_readers=1,
        common_queue_capacity=20,
        common_queue_min=10)

    [image, shape, h, w] = provider.get(['image', 'shape', 'height', 'width'])
    with tf.Session() as sess:
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            for i in range(10):
                img, h_, w_, _ = sess.run([image, h, w, shape])
                # sess.run already returns a NumPy array: reshape it in NumPy
                # rather than adding a new tf.reshape op to the graph on every
                # iteration (which also required an extra eval()).
                img = img.reshape((h_, w_, 3))
                print(img.shape)
                print(img)
                print("h = %d" % h_)
                print("w = %d" % w_)

                # ndarray -> PIL Image (the reverse is np.array(pil_image)).
                Image.fromarray(img, 'RGB').save('./' + str(i) + '.jpg')
        finally:
            # Always stop and join the queue-runner threads, even on error.
            coord.request_stop()
            coord.join(threads)


if __name__ == '__main__':
    # Dump a few decoded samples from the TFRecord file as JPEG images.
    read_record_file()

 

發佈了33 篇原創文章 · 獲贊 8 · 訪問量 1萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章