Defining a New Feature Extractor in the TensorFlow Object Detection API

Environment: TensorFlow 1.12, Windows 10.

Goal: in some situations ResNet-50 is too deep, so add a ResNet-18 feature extractor to the TensorFlow Object Detection API.

Installing the Object Detection API

For the installation steps, follow the installation guide in TensorFlow Models.

Use protoc version 3.4.0.
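The installation guide compiles the protos from models/research with protoc object_detection/protos/*.proto --python_out=. On Windows, newer protoc releases reportedly stopped expanding the *.proto wildcard themselves, which is a common reason for pinning 3.4.0. If a different protoc is already installed, a minimal Python 3.5+ sketch can expand the wildcard itself and call protoc once per file. It assumes protoc is on PATH and that it is run from models/research; build_protoc_commands and compile_protos are hypothetical helpers, not part of the API.

```python
import glob
import os
import subprocess


def build_protoc_commands(proto_dir="object_detection/protos", protoc="protoc"):
    """Build one protoc argument list per .proto file in proto_dir.

    The *.proto wildcard is expanded here in Python, so protoc only ever
    receives concrete file names.
    """
    paths = sorted(glob.glob(os.path.join(proto_dir, "*.proto")))
    # Use forward slashes so the paths look the same on Windows and Linux.
    return [[protoc, p.replace(os.sep, "/"), "--python_out=."] for p in paths]


def compile_protos(proto_dir="object_detection/protos", protoc="protoc"):
    """Run protoc once per file; raise if nothing is found or protoc fails."""
    commands = build_protoc_commands(proto_dir, protoc)
    if not commands:
        raise FileNotFoundError(
            "no .proto files in %r; run this from models/research" % proto_dir)
    for cmd in commands:
        subprocess.run(cmd, check=True)  # CalledProcessError on a non-zero exit
    return len(commands)
```

Calling compile_protos() from models/research should write one *_pb2.py module next to each .proto file.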

Building ResNet-18

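Starting from the nets in slim (research/slim/nets/resnet_v1.py), derive resnet_v1_18 from resnet_v1_50 by cutting every block down to two units. Note that resnet_v1_block builds bottleneck units (1x1, 3x3 and 1x1 convolutions), so the result is a shallow bottleneck ResNet rather than the basic-block ResNet-18 of the original paper; it is nonetheless much shallower than ResNet-50.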

def resnet_v1_18(inputs,
                 num_classes=None,
                 is_training=True,
                 global_pool=True,
                 output_stride=None,
                 spatial_squeeze=True,
                 store_non_strided_activations=False,
                 min_base_depth=8,
                 depth_multiplier=1,
                 reuse=None,
                 scope='resnet_v1_18'):
  """ResNet-18 model of [1]. See resnet_v1() for arg and return description."""
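  depth_func = lambda d: max(int(d * depth_multiplier), min_base_depth)
  # Same layout as resnet_v1_50 but with 2 units per block instead of
  # [3, 4, 6, 3]. resnet_v1_block builds bottleneck units with output depth
  # 4 * base_depth and applies the block stride in its last unit.
  blocks = [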
      resnet_v1_block('block1', base_depth=depth_func(64), num_units=2,
                      stride=2),
      resnet_v1_block('block2', base_depth=depth_func(128), num_units=2,
                      stride=2),
      resnet_v1_block('block3', base_depth=depth_func(256), num_units=2,
                      stride=2),
      resnet_v1_block('block4', base_depth=depth_func(512), num_units=2,
                      stride=1),
  ]
  return resnet_v1(inputs, blocks, num_classes, is_training,
                   global_pool=global_pool, output_stride=output_stride,
                   include_root_block=True, spatial_squeeze=spatial_squeeze,
                   store_non_strided_activations=store_non_strided_activations,
                   reuse=reuse, scope=scope)
resnet_v1_18.default_image_size = resnet_v1.default_image_size
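Since resnet_v1_block builds bottleneck units, the depth of the network above can be checked by hand: one stem convolution, the convolutions in every residual unit, and one final fully connected layer. The sketch below does that arithmetic with the standard unit counts from the ResNet paper, and also reproduces the depth_func rule that shrinks channel counts via depth_multiplier. count_weight_layers and the standalone depth_func are hypothetical helpers for illustration.

```python
def count_weight_layers(units_per_block, convs_per_unit):
    """Count weight layers: 7x7 stem conv + residual-unit convs + final FC.

    Shortcut projection convolutions are not counted, following the usual
    ResNet naming convention.
    """
    return 1 + sum(units_per_block) * convs_per_unit + 1


BOTTLENECK = 3  # 1x1 -> 3x3 -> 1x1, as built by slim's resnet_v1_block
BASIC = 2       # 3x3 -> 3x3, as in the paper's ResNet-18/34

print(count_weight_layers([3, 4, 6, 3], BOTTLENECK))  # resnet_v1_50 -> 50
print(count_weight_layers([2, 2, 2, 2], BOTTLENECK))  # resnet_v1_18 above -> 26
print(count_weight_layers([2, 2, 2, 2], BASIC))       # paper's ResNet-18 -> 18


def depth_func(d, depth_multiplier=1.0, min_base_depth=8):
    # Same rule as the lambda inside resnet_v1_18.
    return max(int(d * depth_multiplier), min_base_depth)


print([depth_func(d, 0.25) for d in (64, 128, 256, 512)])  # -> [16, 32, 64, 128]
```

The difference between 26 and 18 comes only from using three convolutions per unit instead of two; the number of units per block is the same.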

Defining the new Feature Extractor

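Following the official guide "define your own model", subclass FasterRCNNResnetV1FeatureExtractor in object_detection/models/faster_rcnn_resnet_v1_feature_extractor.py and pass the new resnet_v1_18 function to the base class: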

class FasterRCNNResnet18FeatureExtractor(FasterRCNNResnetV1FeatureExtractor):
  """Faster R-CNN Resnet 18 feature extractor implementation."""

  def __init__(self,
               is_training,
               first_stage_features_stride,
               batch_norm_trainable=False,
               reuse_weights=None,
               weight_decay=0.0):
    """Constructor.

    Args:
      is_training: See base class.
      first_stage_features_stride: See base class.
      batch_norm_trainable: See base class.
      reuse_weights: See base class.
      weight_decay: See base class.

    Raises:
      ValueError: If `first_stage_features_stride` is not 8 or 16,
        or if `architecture` is not supported.
    """
    super(FasterRCNNResnet18FeatureExtractor, self).__init__(
        'resnet_v1_18', resnet_v1.resnet_v1_18, is_training,
        first_stage_features_stride, batch_norm_trainable,
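        reuse_weights, weight_decay)

Then register the new class in FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP in object_detection/builders/model_builder.py. The key is the string used as the feature extractor type in the pipeline config: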
FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP = {
    'faster_rcnn_nas':
    frcnn_nas.FasterRCNNNASFeatureExtractor,
    'faster_rcnn_pnas':
    frcnn_pnas.FasterRCNNPNASFeatureExtractor,
    'faster_rcnn_inception_resnet_v2':
    frcnn_inc_res.FasterRCNNInceptionResnetV2FeatureExtractor,
    'faster_rcnn_inception_v2':
    frcnn_inc_v2.FasterRCNNInceptionV2FeatureExtractor,
    # New entry for the ResNet-18 extractor.
    'faster_rcnn_resnet18':
    frcnn_resnet_v1.FasterRCNNResnet18FeatureExtractor,
    'faster_rcnn_resnet50':
    frcnn_resnet_v1.FasterRCNNResnet50FeatureExtractor,
    'faster_rcnn_resnet101':
    frcnn_resnet_v1.FasterRCNNResnet101FeatureExtractor,
    'faster_rcnn_resnet152':
    frcnn_resnet_v1.FasterRCNNResnet152FeatureExtractor,
}
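The type string in the pipeline config is looked up in this dictionary, so it must match the new key exactly. A minimal sketch of that lookup pattern, using stand-in classes instead of the real extractors; the stub classes, CLASS_MAP and build_feature_extractor are hypothetical, and the real builder also passes further arguments such as batch_norm_trainable.

```python
class StubFeatureExtractor(object):
    """Stand-in for a Faster R-CNN feature extractor; stores its arguments."""

    def __init__(self, is_training, first_stage_features_stride):
        # Mirrors the 8-or-16 check documented in the extractor's docstring.
        if first_stage_features_stride not in (8, 16):
            raise ValueError('first_stage_features_stride must be 8 or 16.')
        self.is_training = is_training
        self.first_stage_features_stride = first_stage_features_stride


class StubResnet18Extractor(StubFeatureExtractor):
    pass


class StubResnet50Extractor(StubFeatureExtractor):
    pass


# Hypothetical registry mirroring FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP.
CLASS_MAP = {
    'faster_rcnn_resnet18': StubResnet18Extractor,
    'faster_rcnn_resnet50': StubResnet50Extractor,
}


def build_feature_extractor(feature_type, is_training, first_stage_features_stride):
    """Instantiate the class registered under feature_type (the config's type)."""
    if feature_type not in CLASS_MAP:
        raise ValueError('Unknown Faster R-CNN feature_extractor: %s' % feature_type)
    return CLASS_MAP[feature_type](is_training, first_stage_features_stride)
```

For example, build_feature_extractor('faster_rcnn_resnet18', True, 16) returns a StubResnet18Extractor, while a misspelled type such as 'faster_rcnn_resnet_18' fails immediately with ValueError.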
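
Finally, select the new extractor in the faster_rcnn section of the pipeline config:

    feature_extractor {
      type: 'faster_rcnn_resnet18'
      first_stage_features_stride: 16    # output stride of the RPN feature map (8 or 16)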
    }
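In the resnet_v1 Faster R-CNN extractor, first_stage_features_stride is passed to the backbone as output_stride, and the first-stage (RPN) features are read from the output of block3. The sketch below is a simplified pure-Python model of how slim tracks stride and atrous rate, assuming the stride-4 root block (conv1 plus max pool) and the block strides (2, 2, 2, 1) of resnet_v1_18 above; block_output_strides is a hypothetical helper, not slim's implementation.

```python
def block_output_strides(block_strides, units_per_block, output_stride=None,
                         root_stride=4):
    """Return (stride, atrous_rate) at the output of each block.

    Simplified model of the bookkeeping in slim's resnet_v1 and
    resnet_utils.stack_blocks_dense: the root block downsamples by
    root_stride, each block applies its stride in its last unit, and once the
    target output_stride is reached any further stride is folded into the
    atrous rate instead of reducing the resolution.
    """
    target = None
    if output_stride is not None:
        if output_stride % root_stride != 0:
            raise ValueError('output_stride must be a multiple of %d' % root_stride)
        target = output_stride // root_stride
    current_stride, rate = 1, 1
    results = []
    for block_stride, num_units in zip(block_strides, units_per_block):
        for unit_stride in [1] * (num_units - 1) + [block_stride]:
            if target is not None and current_stride == target:
                rate *= unit_stride  # keep the resolution, dilate later units instead
            else:
                current_stride *= unit_stride
        results.append((root_stride * current_stride, rate))
    return results


# resnet_v1_18 as defined above: block strides (2, 2, 2, 1), 2 units each.
for stride in (None, 16, 8):
    print(stride, block_output_strides([2, 2, 2, 1], [2, 2, 2, 2], stride))
# None [(8, 1), (16, 1), (32, 1), (32, 1)]
# 16 [(8, 1), (16, 1), (16, 2), (16, 2)]
# 8 [(8, 1), (8, 2), (8, 4), (8, 4)]
```

With either allowed value, block3 ends up at exactly first_stage_features_stride; the remaining downsampling is replaced by dilation.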

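Because no pretrained checkpoint is available for this ResNet-18, comment out the following two settings in the train_config section: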

 # fine_tune_checkpoint: "PATH_TO_BE_CONFIGURED/model.ckpt"
 # from_detection_checkpoint: true
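The same edit can be made programmatically. A small text-level sketch that comments out given keys in a pipeline.config file; it works line by line on the protobuf text format and assumes each key sits on its own line. comment_out_keys is a hypothetical helper, not part of the Object Detection API.

```python
import re


def comment_out_keys(config_text, keys):
    """Prefix '# ' to each line whose field name is one of keys.

    Keeps indentation and line endings, and leaves lines that are already
    commented untouched.
    """
    pattern = re.compile(r'^(\s*)(?:%s)\s*:' % '|'.join(re.escape(k) for k in keys))
    out = []
    for line in config_text.splitlines(True):  # keep line endings
        match = pattern.match(line)
        if match:
            indent = match.group(1)
            line = indent + '# ' + line[len(indent):]
        out.append(line)
    return ''.join(out)
```

For example, reading pipeline.config, calling comment_out_keys(text, ['fine_tune_checkpoint', 'from_detection_checkpoint']) and writing the result back has the same effect as the manual edit. Keys that merely share a prefix, such as fine_tune_checkpoint_type, are not matched.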