TensorFlow Object Detection API source code analysis: how the data is read

 

 

Core code:

object_detection/train.py

The location of the TFRecord file is usually configured in samples/configs/*.config.

Example: samples/configs/ssd_mobilenet_v1_coco.config

train_input_reader: {
  tf_record_input_reader {
    input_path: "PATH_TO_BE_CONFIGURED/mscoco_train.record"
  }
  label_map_path: "PATH_TO_BE_CONFIGURED/mscoco_label_map.pbtxt"
}

utils/config_util.py reads the individual configs from the config file

def create_configs_from_pipeline_proto(pipeline_config):
  """Creates a configs dictionary from pipeline_pb2.TrainEvalPipelineConfig.

  Args:
    pipeline_config: pipeline_pb2.TrainEvalPipelineConfig proto object.

  Returns:
    Dictionary of configuration objects. Keys are `model`, `train_config`,
      `train_input_config`, `eval_config`, `eval_input_configs`. Values are
      the corresponding config objects or list of config objects (only for
      eval_input_configs).
  """
  configs = {}
  configs["model"] = pipeline_config.model
  configs["train_config"] = pipeline_config.train_config
  configs["train_input_config"] = pipeline_config.train_input_reader
  configs["eval_config"] = pipeline_config.eval_config
  configs["eval_input_configs"] = pipeline_config.eval_input_reader
  # Keeps eval_input_config only for backwards compatibility. All clients should
  # read eval_input_configs instead.
  if configs["eval_input_configs"]:
    configs["eval_input_config"] = configs["eval_input_configs"][0]
  if pipeline_config.HasField("graph_rewriter"):
    configs["graph_rewriter_config"] = pipeline_config.graph_rewriter

  return configs
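
As a rough illustration of the dictionary this function returns, the sketch below applies the same logic to a stand-in object. `FakePipelineConfig`, `create_configs` and the string values are hypothetical placeholders, not the real `pipeline_pb2.TrainEvalPipelineConfig` proto; only the key names come from the code above.

```python
class FakePipelineConfig(object):
    """Hypothetical stand-in for pipeline_pb2.TrainEvalPipelineConfig."""

    def __init__(self, eval_input_readers, graph_rewriter=None):
        # Plain strings stand in for the real sub-messages.
        self.model = 'model'
        self.train_config = 'train_config'
        self.train_input_reader = 'train_input_reader'
        self.eval_config = 'eval_config'
        self.eval_input_reader = list(eval_input_readers)
        self.graph_rewriter = graph_rewriter

    def HasField(self, name):
        # Mimics the proto method: true only when the optional field is set.
        return getattr(self, name) is not None


def create_configs(pipeline_config):
    """Builds the same key layout as create_configs_from_pipeline_proto."""
    configs = {
        'model': pipeline_config.model,
        'train_config': pipeline_config.train_config,
        'train_input_config': pipeline_config.train_input_reader,
        'eval_config': pipeline_config.eval_config,
        'eval_input_configs': pipeline_config.eval_input_reader,
    }
    if configs['eval_input_configs']:
        # Backwards-compatible alias for the first eval input reader.
        configs['eval_input_config'] = configs['eval_input_configs'][0]
    if pipeline_config.HasField('graph_rewriter'):
        configs['graph_rewriter_config'] = pipeline_config.graph_rewriter
    return configs


configs = create_configs(FakePipelineConfig(['eval_a', 'eval_b']))
no_eval = create_configs(FakePipelineConfig([]))
```

Note that `eval_input_config` only appears when at least one eval input reader is configured, and `graph_rewriter_config` only when that optional field is set.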

train.py reads the data

  input_config = configs['train_input_config']
  def get_next(config):
    return dataset_util.make_initializable_iterator(
        dataset_builder.build(config)).get_next()
  # get_next() is called on a tf.data.Iterator and returns its next element
  create_input_dict_fn = functools.partial(get_next, input_config)
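
The role of `functools.partial` here can be shown without TensorFlow. In the sketch below, `get_next` is a stand-in that only records that it was called, and the config dictionary with its `train.record` path is a made-up placeholder. The point is that binding the argument does not build anything; the work happens when the resulting function is called.

```python
import functools

build_log = []  # records each time the stand-in input pipeline is built


def get_next(config):
    # Stand-in for building the dataset and fetching the iterator's next element.
    build_log.append(config)
    return {'image': 'tensor for %s' % config['input_path']}


# partial() only stores the argument; get_next has not run yet.
create_input_dict_fn = functools.partial(get_next, {'input_path': 'train.record'})
calls_before = len(build_log)

# The pipeline is built only when the zero-argument function is invoked.
tensor_dict = create_input_dict_fn()
calls_after = len(build_log)
```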

builders/dataset_builder.py

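Overall flow

  1. Initialize the fields to decode and the handler that decodes each field.
  2. Call tf.data.TFRecordDataset to read data from config.input_path, call process_fn to decode each record, and prefetch input_reader_config.prefetch_size records.
  3. If batch_size is set, apply tf.contrib.data.padded_batch_and_drop_remainder to the dataset; a trailing group of fewer than batch_size records is dropped.
  4. Return the dataset; train.py wraps it in an iterator with dataset_util.make_initializable_iterator.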
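
The steps above can be sketched in plain Python, with generators standing in for tf.data datasets. Everything here is a simplified assumption: `read_records`, `batch_and_drop_remainder` and the file names are hypothetical, and prefetching and padding are left out.

```python
import itertools


def read_records(paths):
    # Stand-in for tf.data.TFRecordDataset: yields three fake records per file.
    for path in paths:
        for i in range(3):
            yield '%s#%d' % (path, i)


def process_fn(record):
    # Stand-in for decoder.decode(): turns a record into a dict of fields.
    return {'source_id': record}


def batch_and_drop_remainder(examples, batch_size):
    # Groups examples into full batches; a trailing partial batch is dropped.
    iterator = iter(examples)
    while True:
        batch = list(itertools.islice(iterator, batch_size))
        if len(batch) < batch_size:
            return
        yield batch


def build(paths, batch_size=None):
    dataset = map(process_fn, read_records(paths))
    if batch_size:
        dataset = batch_and_drop_remainder(dataset, batch_size)
    return dataset


# Six records with batch_size=4 give one full batch; the last two are dropped.
batches = list(build(['a.record', 'b.record'], batch_size=4))
unbatched = list(build(['a.record']))
```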

core/standard_fields.py

class InputDataFields(object):
  """Names for the input tensors.

  Holds the standard data field names to use for identifying input tensors. This
  should be used by the decoder to identify keys for the returned tensor_dict
  containing input tensors. And it should be used by the model to identify the
  tensors it needs.

  Attributes:
    image: image.
    image_additional_channels: additional channels.
    original_image: image in the original input size.
    key: unique key corresponding to image.
    source_id: source of the original image.
    filename: original filename of the dataset (without common path).
    groundtruth_image_classes: image-level class labels.
    groundtruth_boxes: coordinates of the ground truth boxes in the image.
    groundtruth_classes: box-level class labels.
    groundtruth_label_types: box-level label types (e.g. explicit negative).
    groundtruth_is_crowd: [DEPRECATED, use groundtruth_group_of instead]
      is the groundtruth a single object or a crowd.
    groundtruth_area: area of a groundtruth segment.
    groundtruth_difficult: is a `difficult` object
    groundtruth_group_of: is a `group_of` objects, e.g. multiple objects of the
      same class, forming a connected group, where instances are heavily
      occluding each other.
    proposal_boxes: coordinates of object proposal boxes.
    proposal_objectness: objectness score of each proposal.
    groundtruth_instance_masks: ground truth instance masks.
    groundtruth_instance_boundaries: ground truth instance boundaries.
    groundtruth_instance_classes: instance mask-level class labels.
    groundtruth_keypoints: ground truth keypoints.
    groundtruth_keypoint_visibilities: ground truth keypoint visibilities.
    groundtruth_label_scores: groundtruth label scores.
    groundtruth_weights: groundtruth weight factor for bounding boxes.
    num_groundtruth_boxes: number of groundtruth boxes.
    true_image_shape: true shapes of images in the resized images, as resized
      images can be padded with zeros.
    verified_labels: list of human-verified image-level labels (note, that a
      label can be verified both as positive and negative).
    multiclass_scores: the label score per class for each box.
  """
  image = 'image'
  image_additional_channels = 'image_additional_channels'
  original_image = 'original_image'
  key = 'key'
  source_id = 'source_id'
  filename = 'filename'
  groundtruth_image_classes = 'groundtruth_image_classes'
  groundtruth_boxes = 'groundtruth_boxes'
  groundtruth_classes = 'groundtruth_classes'
  groundtruth_label_types = 'groundtruth_label_types'
  groundtruth_is_crowd = 'groundtruth_is_crowd'
  groundtruth_area = 'groundtruth_area'
  groundtruth_difficult = 'groundtruth_difficult'
  groundtruth_group_of = 'groundtruth_group_of'
  proposal_boxes = 'proposal_boxes'
  proposal_objectness = 'proposal_objectness'
  groundtruth_instance_masks = 'groundtruth_instance_masks'
  groundtruth_instance_boundaries = 'groundtruth_instance_boundaries'
  groundtruth_instance_classes = 'groundtruth_instance_classes'
  groundtruth_keypoints = 'groundtruth_keypoints'
  groundtruth_keypoint_visibilities = 'groundtruth_keypoint_visibilities'
  groundtruth_label_scores = 'groundtruth_label_scores'
  groundtruth_weights = 'groundtruth_weights'
  num_groundtruth_boxes = 'num_groundtruth_boxes'
  true_image_shape = 'true_image_shape'
  verified_labels = 'verified_labels'
  multiclass_scores = 'multiclass_scores'
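
A small sketch of why these names are centralized: the producer (decoder) and the consumer (model) both index the tensor dictionary through the same constants, so neither side hard-codes a string. The class below repeats only three of the names, and `fake_decode` and `count_boxes` are hypothetical helpers that use plain lists in place of tensors.

```python
class InputDataFields(object):
    # Only three of the names from the class above, to keep the sketch short.
    image = 'image'
    groundtruth_boxes = 'groundtruth_boxes'
    groundtruth_classes = 'groundtruth_classes'


def fake_decode():
    # Hypothetical decoder output: plain lists stand in for tensors.
    return {
        InputDataFields.image: [[0, 0, 0]],
        InputDataFields.groundtruth_boxes: [[0.1, 0.2, 0.5, 0.6]],
        InputDataFields.groundtruth_classes: [1],
    }


def count_boxes(tensor_dict):
    # Consumer side: looks the field up through the same constant.
    return len(tensor_dict[InputDataFields.groundtruth_boxes])


num_boxes = count_boxes(fake_decode())
```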

object_detection/builders/dataset_builder.py 

# Returns the padding shape for each key in dataset.output_shapes
def _get_padding_shapes(dataset, max_num_boxes=None, num_classes=None,
                        spatial_image_shape=None):
  """Returns shapes to pad dataset tensors to before batching.

  Args:
    dataset: tf.data.Dataset object.
    max_num_boxes: Max number of groundtruth boxes needed to compute shapes for
      padding.
    num_classes: Number of classes in the dataset needed to compute shapes for
      padding.
    spatial_image_shape: A list of two integers of the form [height, width]
      containing expected spatial shape of the image.

  Returns:
    A dictionary keyed by fields.InputDataFields containing padding shapes for
    tensors in the dataset.

  Raises:
    ValueError: If groundtruth classes is neither rank 1 nor rank 2.
  """

  if not spatial_image_shape or spatial_image_shape == [-1, -1]:
    height, width = None, None
  else:
    height, width = spatial_image_shape  # pylint: disable=unpacking-non-sequence

  padding_shapes = {
      fields.InputDataFields.image: [height, width, 3],
      fields.InputDataFields.source_id: [],
      fields.InputDataFields.filename: [],
      fields.InputDataFields.key: [],
      fields.InputDataFields.groundtruth_difficult: [max_num_boxes],
      fields.InputDataFields.groundtruth_boxes: [max_num_boxes, 4],
      fields.InputDataFields.groundtruth_instance_masks: [max_num_boxes, height,
                                                          width],
      fields.InputDataFields.groundtruth_is_crowd: [max_num_boxes],
      fields.InputDataFields.groundtruth_group_of: [max_num_boxes],
      fields.InputDataFields.groundtruth_area: [max_num_boxes],
      fields.InputDataFields.groundtruth_weights: [max_num_boxes],
      fields.InputDataFields.num_groundtruth_boxes: [],
      fields.InputDataFields.groundtruth_label_types: [max_num_boxes],
      fields.InputDataFields.groundtruth_label_scores: [max_num_boxes],
      fields.InputDataFields.true_image_shape: [3],
      fields.InputDataFields.multiclass_scores: [
          max_num_boxes, num_classes + 1 if num_classes is not None else None],
  }
  # Determine whether groundtruth_classes are integers or one-hot encodings, and
  # apply batching appropriately.
  classes_shape = dataset.output_shapes[
      fields.InputDataFields.groundtruth_classes]
  if len(classes_shape) == 1:  # Class integers.
    padding_shapes[fields.InputDataFields.groundtruth_classes] = [max_num_boxes]
  elif len(classes_shape) == 2:  # One-hot or k-hot encoding.
    padding_shapes[fields.InputDataFields.groundtruth_classes] = [
        max_num_boxes, num_classes]
  else:
    raise ValueError('Groundtruth classes must be a rank 1 tensor (classes) or '
                     'rank 2 tensor (one-hot encodings)')

  if fields.InputDataFields.original_image in dataset.output_shapes:
    padding_shapes[fields.InputDataFields.original_image] = [None, None, 3]
  if fields.InputDataFields.groundtruth_keypoints in dataset.output_shapes:
    tensor_shape = dataset.output_shapes[fields.InputDataFields.
                                         groundtruth_keypoints]
    padding_shape = [max_num_boxes, tensor_shape[1].value,
                     tensor_shape[2].value]
    padding_shapes[fields.InputDataFields.groundtruth_keypoints] = padding_shape
  if (fields.InputDataFields.groundtruth_keypoint_visibilities
      in dataset.output_shapes):
    tensor_shape = dataset.output_shapes[fields.InputDataFields.
                                         groundtruth_keypoint_visibilities]
    padding_shape = [max_num_boxes, tensor_shape[1].value]
    padding_shapes[fields.InputDataFields.
                   groundtruth_keypoint_visibilities] = padding_shape
  return {tensor_key: padding_shapes[tensor_key]
          for tensor_key, _ in dataset.output_shapes.items()}

def build(input_reader_config,
          transform_input_data_fn=None,
          batch_size=None,
          max_num_boxes=None,
          num_classes=None,
          spatial_image_shape=None,
          num_additional_channels=0):
  """Builds a tf.data.Dataset.

  Builds a tf.data.Dataset by applying the `transform_input_data_fn` on all
  records. Applies a padded batch to the resulting dataset.

  Args:
    input_reader_config: A input_reader_pb2.InputReader object.
    transform_input_data_fn: Function to apply to all records, or None if
      no extra decoding is required.
    batch_size: Batch size. If None, batching is not performed.
    max_num_boxes: Max number of groundtruth boxes needed to compute shapes for
      padding. If None, will use a dynamic shape.
    num_classes: Number of classes in the dataset needed to compute shapes for
      padding. If None, will use a dynamic shape.
    spatial_image_shape: A list of two integers of the form [height, width]
      containing expected spatial shape of the image after applying
      transform_input_data_fn. If None, will use dynamic shapes.
    num_additional_channels: Number of additional channels to use in the input.

  Returns:
    A tf.data.Dataset based on the input_reader_config.

  Raises:
    ValueError: On invalid input reader proto.
    ValueError: If no input paths are specified.
  """
  # Example input_reader_config (text format):
  #   train_input_reader: {
  #     tf_record_input_reader {
  #       input_path: "data/wider_udlr.record"
  #     }
  #     label_map_path: "data/udlr_face.pbtxt"
  #   }
  # input_reader_config must be an input_reader_pb2.InputReader object.
  if not isinstance(input_reader_config, input_reader_pb2.InputReader):
    raise ValueError('input_reader_config not of type '
                     'input_reader_pb2.InputReader.')

  if input_reader_config.WhichOneof('input_reader') == 'tf_record_input_reader':
    # input_path: path(s) of the TFRecord file(s) to read
    config = input_reader_config.tf_record_input_reader
    if not config.input_path:
      raise ValueError('At least one input path must be specified in '
                       '`input_reader_config`.')

    label_map_proto_file = None
    if input_reader_config.HasField('label_map_path'):
      # Path of the label map file
      label_map_proto_file = input_reader_config.label_map_path
    # Set up the fields to decode and the handler that decodes each field
    decoder = tf_example_decoder.TfExampleDecoder(
        load_instance_masks=input_reader_config.load_instance_masks,
        instance_mask_type=input_reader_config.mask_type,
        label_map_proto_file=label_map_proto_file,
        use_display_name=input_reader_config.use_display_name,
        num_additional_channels=num_additional_channels)

    # Call tf.data.TFRecordDataset to read data from config.input_path, call
    # process_fn to decode it, and prefetch input_reader_config.prefetch_size records
    def process_fn(value):
      processed = decoder.decode(value)
      if transform_input_data_fn is not None:
        return transform_input_data_fn(processed)
      return processed

    dataset = dataset_util.read_dataset(
        functools.partial(tf.data.TFRecordDataset, buffer_size=8 * 1000 * 1000),
        process_fn, config.input_path[:], input_reader_config)

    if batch_size:
      padding_shapes = _get_padding_shapes(dataset, max_num_boxes, num_classes,
                                           spatial_image_shape)
      dataset = dataset.apply(
          tf.contrib.data.padded_batch_and_drop_remainder(batch_size,
                                                          padding_shapes))
    return dataset

  raise ValueError('Unsupported input_reader_config.')
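
The effect of the padding shapes on the box dimension can be sketched in plain Python. This is a simplified assumption that handles only `groundtruth_boxes`, with nested lists in place of tensors. `pad_to_shape` and `padded_batch` are hypothetical helpers, not part of the API. When `max_num_boxes` is None the dimension is dynamic and each batch is padded to its largest example; otherwise every example is padded to the fixed size.

```python
def pad_to_shape(rows, num_rows, row_width, fill=0.0):
    """Pads a list of rows (e.g. boxes) with `fill` rows up to `num_rows`."""
    if len(rows) > num_rows:
        raise ValueError('Cannot pad %d rows down to %d' % (len(rows), num_rows))
    padding = [[fill] * row_width for _ in range(num_rows - len(rows))]
    return [list(row) for row in rows] + padding


def padded_batch(examples, max_num_boxes=None):
    # With max_num_boxes=None the box dimension is dynamic: pad to the
    # largest example in this batch. Otherwise pad to the fixed size.
    target = max_num_boxes
    if target is None:
        target = max(len(boxes) for boxes in examples)
    return [pad_to_shape(boxes, target, 4) for boxes in examples]


examples = [
    [[0.1, 0.1, 0.5, 0.5]],
    [[0.2, 0.2, 0.6, 0.6], [0.3, 0.3, 0.9, 0.9]],
]
dynamic = padded_batch(examples)                 # padded to 2 boxes each
fixed = padded_batch(examples, max_num_boxes=3)  # padded to 3 boxes each
```

Padding to a fixed `max_num_boxes` gives every batch the same static shape, at the cost of carrying zero rows that later stages have to ignore (which is what `num_groundtruth_boxes` is for).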

utils/dataset_util.py

def make_initializable_iterator(dataset):
  """Creates an iterator, and initializes tables.

  This is useful in cases where make_one_shot_iterator wouldn't work because
  the graph contains a hash table that needs to be initialized.

  Args:
    dataset: A `tf.data.Dataset` object.

  Returns:
    A `tf.data.Iterator`.
  """
  iterator = dataset.make_initializable_iterator()
  tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)
  return iterator

# Calls file_read_func to read data from input_files,
# calls decode_func to decode the records that were read,
# and prefetches config.prefetch_size records.
def read_dataset(file_read_func, decode_func, input_files, config):
  """Reads a dataset, and handles repetition and shuffling.

  Args:
    file_read_func: Function to use in tf.data.Dataset.interleave, to read
      every individual file into a tf.data.Dataset.
    decode_func: Function to apply to all records.
    input_files: A list of file paths to read.
    config: An input_reader_pb2.InputReader object.

  Returns:
    A tf.data.Dataset based on config.
  """
  # Shard, shuffle, and read files.
  filenames = tf.gfile.Glob(input_files)
  num_readers = config.num_readers
  if num_readers > len(filenames):
    num_readers = len(filenames)
    tf.logging.warning('num_readers has been reduced to %d to match input file '
                       'shards.' % num_readers)
  filename_dataset = tf.data.Dataset.from_tensor_slices(tf.unstack(filenames))
  if config.shuffle:
    filename_dataset = filename_dataset.shuffle(
        config.filenames_shuffle_buffer_size)
  elif num_readers > 1:
    tf.logging.warning('`shuffle` is false, but the input data stream is '
                       'still slightly shuffled since `num_readers` > 1.')

  filename_dataset = filename_dataset.repeat(config.num_epochs or None)

  records_dataset = filename_dataset.apply(
      tf.contrib.data.parallel_interleave(
          file_read_func,
          cycle_length=num_readers,
          block_length=config.read_block_length,
          sloppy=config.shuffle))
  if config.shuffle:
    records_dataset = records_dataset.shuffle(config.shuffle_buffer_size)
  tensor_dataset = records_dataset.map(
      decode_func, num_parallel_calls=config.num_parallel_map_calls)
  return tensor_dataset.prefetch(config.prefetch_size)
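
The warning about the stream being "slightly shuffled" when `num_readers` > 1 follows from how interleaving mixes files. Below is a deterministic pure-Python sketch of an interleave with `cycle_length` and `block_length`; the `interleave` function is a hypothetical helper, and the exact ordering of the real `parallel_interleave` op (especially with `sloppy=True`) may differ.

```python
import itertools


def interleave(sources, cycle_length, block_length):
    """Reads block_length records at a time from cycle_length open sources."""
    pending = iter(sources)
    active = [iter(s) for s in itertools.islice(pending, cycle_length)]
    while active:
        still_open = []
        for reader in active:
            block = list(itertools.islice(reader, block_length))
            for record in block:
                yield record
            if len(block) == block_length:
                still_open.append(reader)
            else:
                # This source is exhausted; open the next pending one, if any.
                replacement = next(pending, None)
                if replacement is not None:
                    still_open.append(iter(replacement))
        active = still_open


files = [['a0', 'a1', 'a2', 'a3'], ['b0', 'b1', 'b2', 'b3']]
one_reader = list(interleave(files, cycle_length=1, block_length=2))
two_readers = list(interleave(files, cycle_length=2, block_length=2))
```

With one reader the records come out in file order; with two readers, blocks from both files alternate, so the order changes even though nothing was explicitly shuffled.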

protos/input_reader.proto

syntax = "proto2";

package object_detection.protos;

// Configuration proto for defining input readers that generate Object Detection
// Examples from input sources. Input readers are expected to generate a
// dictionary of tensors, with the following fields populated:
//
// 'image': an [image_height, image_width, channels] image tensor that detection
//    will be run on.
// 'groundtruth_classes': a [num_boxes] int32 tensor storing the class
//    labels of detected boxes in the image.
// 'groundtruth_boxes': a [num_boxes, 4] float tensor storing the coordinates of
//    detected boxes in the image.
// 'groundtruth_instance_masks': (Optional), a [num_boxes, image_height,
//    image_width] float tensor storing binary mask of the objects in boxes.

// Instance mask format. Note that PNG masks are much more space efficient.
enum InstanceMaskType {
  DEFAULT = 0;          // Default implementation, currently NUMERICAL_MASKS
  NUMERICAL_MASKS = 1;  // [num_masks, H, W] float32 binary masks.
  PNG_MASKS = 2;        // Encoded PNG masks.
}

message InputReader {
  // Path to StringIntLabelMap pbtxt file specifying the mapping from string
  // labels to integer ids.
  optional string label_map_path = 1 [default=""];

  // Whether data should be processed in the order they are read in, or
  // shuffled randomly.
  optional bool shuffle = 2 [default=true];

  // Buffer size to be used when shuffling.
  optional uint32 shuffle_buffer_size = 11 [default = 2048];

  // Buffer size to be used when shuffling file names.
  optional uint32 filenames_shuffle_buffer_size = 12 [default = 100];

  // Maximum number of records to keep in reader queue.
  optional uint32 queue_capacity = 3 [default=2000];

  // Minimum number of records to keep in reader queue. A large value is needed
  // to generate a good random shuffle.
  optional uint32 min_after_dequeue = 4 [default=1000];

  // The number of times a data source is read. If set to zero, the data source
  // will be reused indefinitely.
  optional uint32 num_epochs = 5 [default=0];

  // Number of reader instances to create.
  optional uint32 num_readers = 6 [default=32];

  // Number of records to read from each reader at once.
  optional uint32 read_block_length = 15 [default=32];

  // Number of decoded records to prefetch before batching.
  optional uint32 prefetch_size = 13 [default = 512];

  // Number of parallel decode ops to apply.
  optional uint32 num_parallel_map_calls = 14 [default = 64];

  // Number of groundtruth keypoints per object.
  optional uint32 num_keypoints = 16 [default = 0];

  // Whether to load groundtruth instance masks.
  optional bool load_instance_masks = 7 [default = false];

  // Type of instance mask.
  optional InstanceMaskType mask_type = 10 [default = NUMERICAL_MASKS];

  // Whether to use the display name when decoding examples. This is only used
  // when mapping class text strings to integers.
  optional bool use_display_name = 17 [default = false];

  oneof input_reader {
    TFRecordInputReader tf_record_input_reader = 8;
    ExternalInputReader external_input_reader = 9;
  }
}

// An input reader that reads TF Example protos from local TFRecord files.
message TFRecordInputReader {
  // Path(s) to `TFRecordFile`s.
  repeated string input_path = 1;
}

// An externally defined input reader. Users may define an extension to this
// proto to interface their own input readers.
message ExternalInputReader {
  extensions 1 to 999;
}
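
One detail worth spelling out is the `num_epochs` default of 0. `read_dataset` calls `filename_dataset.repeat(config.num_epochs or None)`, so 0 becomes None and the file names repeat indefinitely. The sketch below mimics that with itertools; `repeat_epochs` is a hypothetical helper and expects a list, not a one-shot iterator.

```python
import itertools


def repeat_epochs(records, num_epochs):
    """Repeats `records` num_epochs times; 0 means repeat without end."""
    count = num_epochs or None  # 0 is falsy, so it becomes None
    if count is None:
        return itertools.cycle(records)
    return itertools.chain.from_iterable(itertools.repeat(records, count))


two_epochs = list(repeat_epochs(['r0', 'r1'], num_epochs=2))
# The default num_epochs=0 never ends, so only take a finite slice of it.
first_five = list(itertools.islice(repeat_epochs(['r0', 'r1'], 0), 5))
```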

data_decoders/tf_example_decoder.py

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tensorflow Example proto decoder for object detection.

A decoder to decode string tensors containing serialized tensorflow.Example
protos for object detection.
"""
import tensorflow as tf

from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from object_detection.core import data_decoder
from object_detection.core import standard_fields as fields
from object_detection.protos import input_reader_pb2
from object_detection.utils import label_map_util

slim_example_decoder = tf.contrib.slim.tfexample_decoder


# TODO(lzc): keep LookupTensor and BackupHandler in sync with
# tf.contrib.slim.tfexample_decoder version.
class LookupTensor(slim_example_decoder.Tensor):
  """An ItemHandler that returns a parsed Tensor, the result of a lookup."""

  def __init__(self,
               tensor_key,
               table,
               shape_keys=None,
               shape=None,
               default_value=''):
    """Initializes the LookupTensor handler.

    Simply calls a vocabulary (most often, a label mapping) lookup.

    Args:
      tensor_key: the name of the `TFExample` feature to read the tensor from.
      table: A tf.lookup table.
      shape_keys: Optional name or list of names of the TF-Example feature in
        which the tensor shape is stored. If a list, then each corresponds to
        one dimension of the shape.
      shape: Optional output shape of the `Tensor`. If provided, the `Tensor` is
        reshaped accordingly.
      default_value: The value used when the `tensor_key` is not found in a
        particular `TFExample`.

    Raises:
      ValueError: if both `shape_keys` and `shape` are specified.
    """
    self._table = table
    super(LookupTensor, self).__init__(tensor_key, shape_keys, shape,
                                       default_value)

  def tensors_to_item(self, keys_to_tensors):
    unmapped_tensor = super(LookupTensor, self).tensors_to_item(keys_to_tensors)
    return self._table.lookup(unmapped_tensor)


class BackupHandler(slim_example_decoder.ItemHandler):
  """An ItemHandler that tries two ItemHandlers in order."""

  def __init__(self, handler, backup):
    """Initializes the BackupHandler handler.

    If the first Handler's tensors_to_item returns a Tensor with no elements,
    the second Handler is used.

    Args:
      handler: The primary ItemHandler.
      backup: The backup ItemHandler.

    Raises:
      ValueError: if either is not an ItemHandler.
    """
    if not isinstance(handler, slim_example_decoder.ItemHandler):
      raise ValueError('Primary handler is of type %s instead of ItemHandler' %
                       type(handler))
    if not isinstance(backup, slim_example_decoder.ItemHandler):
      raise ValueError(
          'Backup handler is of type %s instead of ItemHandler' % type(backup))
    self._handler = handler
    self._backup = backup
    super(BackupHandler, self).__init__(handler.keys + backup.keys)

  def tensors_to_item(self, keys_to_tensors):
    item = self._handler.tensors_to_item(keys_to_tensors)
    return control_flow_ops.cond(
        pred=math_ops.equal(math_ops.reduce_prod(array_ops.shape(item)), 0),
        true_fn=lambda: self._backup.tensors_to_item(keys_to_tensors),
        false_fn=lambda: item)


class TfExampleDecoder(data_decoder.DataDecoder):
  """Tensorflow Example proto decoder."""
  # __init__ specifies which fields to decode and the handler for each
  def __init__(self,
               load_instance_masks=False,
               instance_mask_type=input_reader_pb2.NUMERICAL_MASKS,
               label_map_proto_file=None,
               use_display_name=False,
               dct_method='',
               num_keypoints=0,
               num_additional_channels=0):
    """Constructor sets keys_to_features and items_to_handlers.

    Args:
      load_instance_masks: whether or not to load and handle instance masks.
      instance_mask_type: type of instance masks. Options are provided in
        input_reader.proto. This is only used if `load_instance_masks` is True.
      label_map_proto_file: a file path to a
        object_detection.protos.StringIntLabelMap proto. If provided, then the
        mapped IDs of 'image/object/class/text' will take precedence over the
        existing 'image/object/class/label' ID.  Also, if provided, it is
        assumed that 'image/object/class/text' will be in the data.
      use_display_name: whether or not to use the `display_name` for label
        mapping (instead of `name`).  Only used if label_map_proto_file is
        provided.
      dct_method: An optional string. Defaults to ''. It only takes
        effect when image format is jpeg, used to specify a hint about the
        algorithm used for jpeg decompression. Currently valid values
        are ['INTEGER_FAST', 'INTEGER_ACCURATE']. The hint may be ignored, for
        example, the jpeg library does not have that specific option.
      num_keypoints: the number of keypoints per object.
      num_additional_channels: how many additional channels to use.

    Raises:
      ValueError: If `instance_mask_type` option is not one of
        input_reader_pb2.DEFAULT, input_reader_pb2.NUMERICAL_MASKS, or
        input_reader_pb2.PNG_MASKS.
    """
    self.keys_to_features = {
        'image/encoded':
            tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format':
            tf.FixedLenFeature((), tf.string, default_value='jpeg'),
        'image/filename':
            tf.FixedLenFeature((), tf.string, default_value=''),
        'image/key/sha256':
            tf.FixedLenFeature((), tf.string, default_value=''),
        'image/source_id':
            tf.FixedLenFeature((), tf.string, default_value=''),
        'image/height':
            tf.FixedLenFeature((), tf.int64, default_value=1),
        'image/width':
            tf.FixedLenFeature((), tf.int64, default_value=1),
        # Object boxes and classes.
        'image/object/bbox/xmin':
            tf.VarLenFeature(tf.float32),
        'image/object/bbox/xmax':
            tf.VarLenFeature(tf.float32),
        'image/object/bbox/ymin':
            tf.VarLenFeature(tf.float32),
        'image/object/bbox/ymax':
            tf.VarLenFeature(tf.float32),
        'image/object/class/label':
            tf.VarLenFeature(tf.int64),
        'image/object/class/text':
            tf.VarLenFeature(tf.string),
        'image/object/area':
            tf.VarLenFeature(tf.float32),
        'image/object/is_crowd':
            tf.VarLenFeature(tf.int64),
        'image/object/difficult':
            tf.VarLenFeature(tf.int64),
        'image/object/group_of':
            tf.VarLenFeature(tf.int64),
        'image/object/weight':
            tf.VarLenFeature(tf.float32),
    }
    # We are checking `dct_method` instead of passing it directly in order to
    # ensure TF version 1.6 compatibility.
    # For JPEG images, a dct_method hint can speed up decoding.
    if dct_method:
      image = slim_example_decoder.Image(
          image_key='image/encoded',
          format_key='image/format',
          channels=3,
          dct_method=dct_method)
      additional_channel_image = slim_example_decoder.Image(
          image_key='image/additional_channels/encoded',
          format_key='image/format',
          channels=1,
          repeated=True,
          dct_method=dct_method)
    else:
      # The image handler decodes the image:
      # if image/format is raw or Raw, it calls parsing_ops.decode_raw;
      # if image/format is jpeg, it calls image_ops.decode_jpeg;
      # otherwise it calls image_ops.decode_image.
      image = slim_example_decoder.Image(
          image_key='image/encoded', format_key='image/format', channels=3)
      additional_channel_image = slim_example_decoder.Image(
          image_key='image/additional_channels/encoded',
          format_key='image/format',
          channels=1,
          repeated=True)
    self.items_to_handlers = {
        fields.InputDataFields.image:
            image,
        fields.InputDataFields.source_id: (
            slim_example_decoder.Tensor('image/source_id')),
        fields.InputDataFields.key: (
            slim_example_decoder.Tensor('image/key/sha256')),
        fields.InputDataFields.filename: (
            slim_example_decoder.Tensor('image/filename')),
        # Object boxes and classes.
        fields.InputDataFields.groundtruth_boxes: (
            slim_example_decoder.BoundingBox(['ymin', 'xmin', 'ymax', 'xmax'],
                                             'image/object/bbox/')),
        fields.InputDataFields.groundtruth_area:
            slim_example_decoder.Tensor('image/object/area'),
        fields.InputDataFields.groundtruth_is_crowd: (
            slim_example_decoder.Tensor('image/object/is_crowd')),
        fields.InputDataFields.groundtruth_difficult: (
            slim_example_decoder.Tensor('image/object/difficult')),
        fields.InputDataFields.groundtruth_group_of: (
            slim_example_decoder.Tensor('image/object/group_of')),
        fields.InputDataFields.groundtruth_weights: (
            slim_example_decoder.Tensor('image/object/weight')),
    }
    if num_additional_channels > 0:
      self.keys_to_features[
          'image/additional_channels/encoded'] = tf.FixedLenFeature(
              (num_additional_channels,), tf.string)
      self.items_to_handlers[
          fields.InputDataFields.
          image_additional_channels] = additional_channel_image
    self._num_keypoints = num_keypoints
    if num_keypoints > 0:
      self.keys_to_features['image/object/keypoint/x'] = (
          tf.VarLenFeature(tf.float32))
      self.keys_to_features['image/object/keypoint/y'] = (
          tf.VarLenFeature(tf.float32))
      self.items_to_handlers[fields.InputDataFields.groundtruth_keypoints] = (
          slim_example_decoder.ItemHandlerCallback(
              ['image/object/keypoint/y', 'image/object/keypoint/x'],
              self._reshape_keypoints))
    if load_instance_masks:
      if instance_mask_type in (input_reader_pb2.DEFAULT,
                                input_reader_pb2.NUMERICAL_MASKS):
        self.keys_to_features['image/object/mask'] = (
            tf.VarLenFeature(tf.float32))
        self.items_to_handlers[
            fields.InputDataFields.groundtruth_instance_masks] = (
                slim_example_decoder.ItemHandlerCallback(
                    ['image/object/mask', 'image/height', 'image/width'],
                    self._reshape_instance_masks))
      elif instance_mask_type == input_reader_pb2.PNG_MASKS:
        self.keys_to_features['image/object/mask'] = tf.VarLenFeature(tf.string)
        self.items_to_handlers[
            fields.InputDataFields.groundtruth_instance_masks] = (
                slim_example_decoder.ItemHandlerCallback(
                    ['image/object/mask', 'image/height', 'image/width'],
                    self._decode_png_instance_masks))
      else:
        raise ValueError('Did not recognize the `instance_mask_type` option.')
    if label_map_proto_file:
      label_map = label_map_util.get_label_map_dict(label_map_proto_file,
                                                    use_display_name)
      # We use a default_value of -1, but we expect all labels to be contained
      # in the label map.
      table = tf.contrib.lookup.HashTable(
          initializer=tf.contrib.lookup.KeyValueTensorInitializer(
              keys=tf.constant(list(label_map.keys())),
              values=tf.constant(list(label_map.values()), dtype=tf.int64)),
          default_value=-1)
      # If the label_map_proto is provided, try to use it in conjunction with
      # the class text, and fall back to a materialized ID.
      # TODO(lzc): note that here we are using BackupHandler defined in this
      # file(which is branching slim_example_decoder.BackupHandler). Need to
      # switch back to slim_example_decoder.BackupHandler once tf 1.5 becomes
      # more popular.
      label_handler = BackupHandler(
          LookupTensor('image/object/class/text', table, default_value=''),
          slim_example_decoder.Tensor('image/object/class/label'))
    else:
      label_handler = slim_example_decoder.Tensor('image/object/class/label')
    self.items_to_handlers[
        fields.InputDataFields.groundtruth_classes] = label_handler

  # decode calls parsing_ops.parse_single_example to parse the listed features,
  # then calls each field's handler to decode it
  def decode(self, tf_example_string_tensor):
    """Decodes serialized tensorflow example and returns a tensor dictionary.

    Args:
      tf_example_string_tensor: a string tensor holding a serialized tensorflow
        example proto.

    Returns:
      A dictionary of the following tensors.
      fields.InputDataFields.image - 3D uint8 tensor of shape [None, None, 3]
        containing image.
      fields.InputDataFields.source_id - string tensor containing original
        image id.
      fields.InputDataFields.key - string tensor with unique sha256 hash key.
      fields.InputDataFields.filename - string tensor with original dataset
        filename.
      fields.InputDataFields.groundtruth_boxes - 2D float32 tensor of shape
        [None, 4] containing box corners.
      fields.InputDataFields.groundtruth_classes - 1D int64 tensor of shape
        [None] containing classes for the boxes.
      fields.InputDataFields.groundtruth_weights - 1D float32 tensor of
        shape [None] indicating the weights of groundtruth boxes.
      fields.InputDataFields.num_groundtruth_boxes - int32 scalar indicating
        the number of groundtruth_boxes.
      fields.InputDataFields.groundtruth_area - 1D float32 tensor of shape
        [None] containing object mask area in pixel squared.
      fields.InputDataFields.groundtruth_is_crowd - 1D bool tensor of shape
        [None] indicating if the boxes enclose a crowd.

    Optional:
      fields.InputDataFields.image_additional_channels - 3D uint8 tensor of
        shape [None, None, num_additional_channels]. 1st dim is height; 2nd dim
        is width; 3rd dim is the number of additional channels.
      fields.InputDataFields.groundtruth_difficult - 1D bool tensor of shape
        [None] indicating if the boxes represent `difficult` instances.
      fields.InputDataFields.groundtruth_group_of - 1D bool tensor of shape
        [None] indicating if the boxes represent `group_of` instances.
      fields.InputDataFields.groundtruth_keypoints - 3D float32 tensor of
        shape [None, None, 2] containing keypoints, where the coordinates of
        the keypoints are ordered (y, x).
      fields.InputDataFields.groundtruth_instance_masks - 3D float32 tensor of
        shape [None, None, None] containing instance masks.
    """
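    # Reshape the input to a scalar string tensor, which is what
    # parse_single_example expects.
    serialized_example = tf.reshape(tf_example_string_tensor, shape=[])
    # Create a TFExampleDecoder, initialized with the fields to decode and the
    # handler for each field.
    decoder = slim_example_decoder.TFExampleDecoder(self.keys_to_features,
                                                    self.items_to_handlers)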
    keys = decoder.list_items()
    # Calls parsing_ops.parse_single_example to decode the serialized string
    # tensor. The items to decode are specified by `keys`.
    tensors = decoder.decode(serialized_example, items=keys)
    tensor_dict = dict(zip(keys, tensors))
    is_crowd = fields.InputDataFields.groundtruth_is_crowd
    tensor_dict[is_crowd] = tf.cast(tensor_dict[is_crowd], dtype=tf.bool)
    tensor_dict[fields.InputDataFields.image].set_shape([None, None, 3])
    tensor_dict[fields.InputDataFields.num_groundtruth_boxes] = tf.shape(
        tensor_dict[fields.InputDataFields.groundtruth_boxes])[0]

    if fields.InputDataFields.image_additional_channels in tensor_dict:
      channels = tensor_dict[fields.InputDataFields.image_additional_channels]
      channels = tf.squeeze(channels, axis=3)
      channels = tf.transpose(channels, perm=[1, 2, 0])
      tensor_dict[fields.InputDataFields.image_additional_channels] = channels

    def default_groundtruth_weights():
      return tf.ones(
          [tf.shape(tensor_dict[fields.InputDataFields.groundtruth_boxes])[0]],
          dtype=tf.float32)

    tensor_dict[fields.InputDataFields.groundtruth_weights] = tf.cond(
        tf.greater(
            tf.shape(
                tensor_dict[fields.InputDataFields.groundtruth_weights])[0],
            0), lambda: tensor_dict[fields.InputDataFields.groundtruth_weights],
        default_groundtruth_weights)
    return tensor_dict

  def _reshape_keypoints(self, keys_to_tensors):
    """Reshape keypoints.

    The keypoints are reshaped to [num_instances, num_keypoints, 2].

    Args:
      keys_to_tensors: a dictionary from keys to tensors.

    Returns:
      A 3-D float tensor of shape [num_instances, num_keypoints, 2] holding
        the (y, x) coordinates of each keypoint.
    """
    y = keys_to_tensors['image/object/keypoint/y']
    if isinstance(y, tf.SparseTensor):
      y = tf.sparse_tensor_to_dense(y)
    y = tf.expand_dims(y, 1)
    x = keys_to_tensors['image/object/keypoint/x']
    if isinstance(x, tf.SparseTensor):
      x = tf.sparse_tensor_to_dense(x)
    x = tf.expand_dims(x, 1)
    keypoints = tf.concat([y, x], 1)
    keypoints = tf.reshape(keypoints, [-1, self._num_keypoints, 2])
    return keypoints

  def _reshape_instance_masks(self, keys_to_tensors):
    """Reshape instance segmentation masks.

    The instance segmentation masks are reshaped to [num_instances, height,
    width].

    Args:
      keys_to_tensors: a dictionary from keys to tensors.

    Returns:
      A 3-D float tensor of shape [num_instances, height, width] with values
        in {0, 1}.
    """
    height = keys_to_tensors['image/height']
    width = keys_to_tensors['image/width']
    to_shape = tf.cast(tf.stack([-1, height, width]), tf.int32)
    masks = keys_to_tensors['image/object/mask']
    if isinstance(masks, tf.SparseTensor):
      masks = tf.sparse_tensor_to_dense(masks)
    masks = tf.reshape(tf.to_float(tf.greater(masks, 0.0)), to_shape)
    return tf.cast(masks, tf.float32)

  def _decode_png_instance_masks(self, keys_to_tensors):
    """Decode PNG instance segmentation masks and stack into dense tensor.

    The instance segmentation masks are reshaped to [num_instances, height,
    width].

    Args:
      keys_to_tensors: a dictionary from keys to tensors.

    Returns:
      A 3-D float tensor of shape [num_instances, height, width] with values
        in {0, 1}.
    """

    def decode_png_mask(image_buffer):
      image = tf.squeeze(
          tf.image.decode_image(image_buffer, channels=1), axis=2)
      image.set_shape([None, None])
      image = tf.to_float(tf.greater(image, 0))
      return image

    png_masks = keys_to_tensors['image/object/mask']
    height = keys_to_tensors['image/height']
    width = keys_to_tensors['image/width']
    if isinstance(png_masks, tf.SparseTensor):
      png_masks = tf.sparse_tensor_to_dense(png_masks, default_value='')
    return tf.cond(
        tf.greater(tf.size(png_masks), 0),
        lambda: tf.map_fn(decode_png_mask, png_masks, dtype=tf.float32),
        lambda: tf.zeros(tf.to_int32(tf.stack([0, height, width]))))
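
The two fallback rules in the decoder above are easier to follow without the TensorFlow graph ops: class labels fall back to the stored integer IDs, and box weights default to 1.0. Below is a minimal pure-Python sketch of both. The function names (`lookup_labels`, `backup_labels`, `weights_or_default`) are hypothetical and are not part of the Object Detection API. The sketch also assumes that `BackupHandler` switches to its backup only when the primary handler returns an empty result, which is how slim's `BackupHandler` behaves as far as I know.

```python
# Pure-Python stand-ins for the graph ops; all names here are hypothetical.

def lookup_labels(class_texts, label_map, default_value=-1):
    """Map class-name strings to integer IDs, like the HashTable lookup.

    Names missing from the label map get default_value (-1 in the decoder).
    """
    return [label_map.get(text, default_value) for text in class_texts]


def backup_labels(class_texts, class_ids, label_map):
    """Prefer the text lookup; fall back to the stored IDs if it is empty.

    Assumption: the fallback is taken only when the record has no
    'image/object/class/text' entries, not when a single lookup misses.
    """
    primary = lookup_labels(class_texts, label_map)
    return primary if len(primary) > 0 else list(class_ids)


def weights_or_default(weights, boxes):
    """Return the stored weights, or one 1.0 per box if none were stored."""
    return list(weights) if len(weights) > 0 else [1.0] * len(boxes)


if __name__ == "__main__":
    label_map = {"cat": 1, "dog": 2}
    print(backup_labels(["dog", "cat"], [9, 9], label_map))
    print(backup_labels([], [2, 1], label_map))
    print(weights_or_default([], [[0, 0, 1, 1], [0, 0, 1, 1]]))
```

Note that under this assumption an unknown class name does not trigger the fallback. It is simply mapped to -1, which is why the comment in the source says every label is expected to be present in the label map.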

Reference: https://blog.csdn.net/wenxueliu/article/details/80727563
