【深度学习-微调模型】使用Tensorflow Slim fine-tune(微调)模型

本文主要讲解在现有常用模型基础上,如何微调模型,减少训练时间,同时保持模型检测精度。

首先介绍下Slim这个Google公布的图像分类工具包,可在github链接:modules and examples built with tensorflow 中找到slim包。

上面这个链接目录下主要包含:

official models(这个是用Tensorflow高层API做的例子模型集,建议初学者可尝试);

research models(这个是很多研究者利用tensorflow做的模型集,这个不是官方提供的,是研究者个人在维护的);

samples folder (包含代码片段和小的模型用以表述tensorflow特性,包含以博客形式存在的代码呈现);

而我说的slim工具包就在research文件夹下。


Slim库结构

不仅定义了很多接口,还提供了很多ImageNet数据集上常用的网络结构和预训练模型(包括Alexnet,CycleGAN,DCGAN,VGG16,VGG19,Inception V1~V4,ResNet 50, ResNet 101,MobileNet V1等)。

 


下面用slim工具包中的文件来对自己的数据集做训练,训练可分为利用已有的模型架构(如常见的VGG,Inception等的卷积,池化这些结构)来全新训练权重文件或是微调权重文件。由于很多已有的imagenet图像数据覆盖面已经很广,基于此训练的网络权重已经能提取大致的目标特征(从低微像素到高维的结构特征),所以可使用fine-tune只训练框架中某些层的权重,当然根据自己数据集做全部权重重新训练的检测效果理论会更好些,需要权衡时间成本和检测精度的需求了;

下面会依据成熟网络结构Incvption V3分别做权重文件的全部重新训练部分重新训练(即fine-tune)来介绍;

(前提是你将slim工具库下载下来,安装了必要的tensorflow等框架;并且根据训练图像制作完成了tfrecord文件)

有关tfrecord训练文件的制作请参考:将图像制作成tfrecord

step1:定义新的datasets数据集文件

在slim/datasets/文件夹下 添加一个python文件,直接复制一份flowers.py,重命名为“satellite.py”(这个名字可根据你实际的数据集名字来更改,我用的是何大神的航拍图数据集)

需要对赋值生成后的satellite.py内容做如下修改:

_FILE_PATTERN = 'flowers_%s_*.tfrecord' 

更改为

_FILE_PATTERN = 'satellite_%s_*.tfrecord'      #这个主要是根据你之前制作的tfrecord文件名来改的,我制作的训练文件为satellite_train_00000-of-00002.tfrecord和satellite_train_00001-of-00002.tfrecord,验证文件为satellite_validation_00000-of-00002.tfrecord,satellite_validation_00001-of-00002.tfrecord

SPLITS_TO_SIZES = {'train': 3320, 'validation': 350}

更改为

SPLITS_TO_SIZES = {'train': 4800, 'validation': 1200}  #这个根据自己训练和验证样本数量来改,我的训练数据是800张图/类,共6类,验证集时200张/类,共6类;

_NUM_CLASSES = 5

更改为

_NUM_CLASSES = 6       #实际训练类别为6类;

 

还需要对satellite.py文件中的'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),这行代码做更改,由于用的数据集源文件都是XXXX.jpg格式,因此将默认的图像格式转为jpg,更改后为'image/format': tf.FixedLenFeature((), tf.string, default_value='jpg'), 至此,对satellite.py文件完成制作与更改(其源码如下):

satellite.py

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides data for the flowers dataset.

The dataset scripts used to create the dataset can be found at:
tensorflow/models/slim/datasets/download_and_convert_flowers.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import tensorflow as tf

from datasets import dataset_utils

slim = tf.contrib.slim

_FILE_PATTERN = 'satellite_%s_*.tfrecord'

SPLITS_TO_SIZES = {'train': 4800, 'validation': 1200}

_NUM_CLASSES = 6

_ITEMS_TO_DESCRIPTIONS = {
    'image': 'A color image of varying size.',
    'label': 'A single integer between 0 and 4',
}


def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
  """Gets a dataset tuple with instructions for reading flowers.

  Args:
    split_name: A train/validation split name.
    dataset_dir: The base directory of the dataset sources.
    file_pattern: The file pattern to use when matching the dataset sources.
      It is assumed that the pattern contains a '%s' string so that the split
      name can be inserted.
    reader: The TensorFlow reader type.

  Returns:
    A `Dataset` namedtuple.

  Raises:
    ValueError: if `split_name` is not a valid train/validation split.
  """
  if split_name not in SPLITS_TO_SIZES:
    raise ValueError('split name %s was not recognized.' % split_name)

  if not file_pattern:
    file_pattern = _FILE_PATTERN
  file_pattern = os.path.join(dataset_dir, file_pattern % split_name)

  # Allowing None in the signature so that dataset_factory can use the default.
  if reader is None:
    reader = tf.TFRecordReader

  keys_to_features = {
      'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
      'image/format': tf.FixedLenFeature((), tf.string, default_value='jpg'),
      'image/class/label': tf.FixedLenFeature(
          [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
  }

  items_to_handlers = {
      'image': slim.tfexample_decoder.Image(),
      'label': slim.tfexample_decoder.Tensor('image/class/label'),
  }

  decoder = slim.tfexample_decoder.TFExampleDecoder(
      keys_to_features, items_to_handlers)

  labels_to_names = None
  if dataset_utils.has_labels(dataset_dir):
    labels_to_names = dataset_utils.read_label_file(dataset_dir)

  return slim.dataset.Dataset(
      data_sources=file_pattern,
      reader=reader,
      decoder=decoder,
      num_samples=SPLITS_TO_SIZES[split_name],
      items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
      num_classes=_NUM_CLASSES,
      labels_to_names=labels_to_names)

step2:注册数据库

接下来对slim/datasets/dataset_factory.py文件做更改,注册下satellite数据库;修改之处如下(添加了两行红色字体代码):

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from datasets import cifar10
from datasets import flowers
from datasets import imagenet
from datasets import mnist
from datasets import satellite

datasets_map = {
    'cifar10': cifar10,
    'flowers': flowers,
    'imagenet': imagenet,
    'mnist': mnist,
    'satellite': satellite,
    
}

step3:准备训练文件夹

在slim文件夹下新建如下目录文件夹,并将对应的文件放在相应目录下

slim/
    satellite/
              data/
                   satellite_train_00000-of-00002.tfrecord
                   satellite_train_00001-of-00002.tfrecord
                   satellite_validation_00000-of-00002.tfrecord
                   satellite_validation_00001-of-00002.tfrecord
                   label.txt
              pretrained/
                   inception_v3.ckpt
              train_dir/

data文件夹下存放你制作的tfrecord训练测试文件和标签名;

pretrained文件夹下存放官网训练的权重文件;下载地址:http:/!download. tensorflow .org/models/inception _ v3_2016 _ 08 _ 28.tar.gz      

train_dir文件夹下存放你训练得到的模型和日志;

step4-1:在现有模型结构上fine-tune

开始训练,在slim文件夹下,运行如下指令可开始训练(主要是训练逻辑层):

python train_image_classifier.py \
  --train_dir=satellite/train_dir \
  --dataset_name=satellite \
  --dataset_split_name=train \
  --dataset_dir=satellite/data \
  --model_name=inception_v3 \
  --checkpoint_path=satellite/pretrained/inception_v3.ckpt \
  --checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
  --trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
  --max_number_of_steps=100000 \
  --batch_size=32 \
  --learning_rate=0.001 \
  --learning_rate_decay_type=fixed \
  --save_interval_secs=300 \
  --save_summaries_secs=2 \
  --log_every_n_steps=10 \
  --optimizer=rmsprop \
  --weight_decay=0.00004

命令参数解析如下:

--trainable_ scopes=Inception V3/Logits,InceptionV3/ AuxLogits :首先来解 释参数trainable_scopes 的作用,因为非常重要。 trainable_scopes 规定了在模型中fine-tune变量的范围 。 这里的设定表示只对 InceptionV3/Logits, Inception V3/ AuxLogits 两个变量进行微调,其他变量都保持不动 。 Inception V3/Logits,Inception V3/ AuxLogits 就相当于在网络中的 fc8 ,它们是 Inception V3的“末端层” 。 如果不设定 trainable_scopes , 就会对模型中所有的参数进行训练。

• --train_dir=satellite/train_dir:表明会在 satellite/train_dir目录下保存日志和checkpoint。

--dataset_name=satellite、 --dataset_split_ name=train: 指定训练的数据集。

--dataset_dit=satellite/data:指定训练数据集保存的位置。 

--model_ name=inception _ v3 :使用的模型名称。 

--checkpoint_path=satellite/pretrained/inception_v3.ckpt:预训练模型的保存位置。

--checkpoint_exclude_scopes=Inception V3/Logits,InceptionV3/ AuxLogits : 在恢复预训练模型时,不恢复这两层。正如之前所说,这两层是 Inception V3 模型的末端层,对应着 ImageNet 数据集的 1000 类,和相当前的数据集不符,因此不要去恢复它。

--max_number_of_steps 100000:最大的执行步数。

--batch_size=32:每步使用的 batch 数量。

--learning_rate=0.001 : 学习率。

• --learning_rate_decay_type=fixed:学习率是否自动下降,此处使用固定的学习率。

• --save_interval_secs=300:每隔 300s,程序会把当前模型保存到train_dir中。 此处就是目录 satellite/train_dir。

• --save_summaries_secs=2:每隔 2s,就会将日志写入到 train_dir 中。可以用 TensorBoard 查看该日志。此处为了方便观察,设定的时间间隔较多,实际训练时,为了性能考虑,可以设定较长的时间间隔。

• --log_every_n_steps=10:每隔10步,就会在屏上打出训练信息。

--optimizer=msprop:表示选定的优化器。

• --weight_decay=0.00004:选定的 weight_decay 值。 即模型中所高参数的 二次正则化超参数。


以上命令是只训练末端层 InceptionV3/Logits,Inception V3/ AuxLogits ,还 可以使用以下命令对所高层进行训练:

step4-2:训练整个模型权重数据

使用以下命令对所有层进行训练:
去掉 了--trainable_scopes 参数

python train_image_classifier.py \
  --train_dir=satellite/train_dir \
  --dataset_name=satellite \
  --dataset_split_name=train \
  --dataset_dir=satellite/data \
  --model_name=inception_v3 \
  --checkpoint_path=satellite/pretrained/inception_v3.ckpt \
  --checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
  --max_number_of_steps=100000 \
  --batch_size=32 \
  --learning_rate=0.001 \
  --learning_rate_decay_type=fixed \
  --save_interval_secs=300 \
  --save_summaries_secs=2 \
  --log_every_n_steps=10 \
  --optimizer=rmsprop \
  --weight_decay=0.00004

当train_image_classifier.py程序启动后,如果训练文件夹(即satellite/train_dir)里没再已经保存的模型,就会加载 checkpoint_path 中的预训练模型,紧接着,程序会把初始模型保存到 train_dir中 ,命名为 model.ckpt-0, 0 表示第 0 步。 这之后,每隔 5min (参数一save_interval_secs=300 指定了每隔 300s 保存一次,即 5min )。 程序还会把当前模型保存到同样的文件夹中 , 命名恪式和第一次保存的格式一样。 因为模型比较大,程序只会保留最新的 5 个模型。
此外,如果中断了程序并再次运行,程序会首先检查 train_dir 中有无已经保存的模型,如果有,就不会去加载 checkpoint_path 中的预训练模型, 而是直接加载 train_dir 中已经训练好的模型,并以此为起点进行训练。 Slim 之所以这样设计,是为了在微调网络的时候,可以方便地按阶段手动调整学习率等参数。
 

至此用slim工具包做fine-tune或重新训练的步骤就完成了。


相似文章参考:https://blog.csdn.net/chaipp0607/article/details/74139895

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章