Tensorflow图像生成文本(3)图像特征的提取

前言

对于图像生成文本来说,除了预处理文本信息,还要对图像特征进行提前提取。这里单独编写一个脚本,对所有图像特征进行提取,然后将提取出来的特征保存在一个目录中。

代码实现

这里使用到了预训练好的 inception_v3 模型,并且是一个带有默认图的 inception_v3 模型。inception_v3_graph_def.pb 该模型不光带有模型参数,同时还带有计算图,将这个计算图载入到自己的计算图中,连编写计算图的过程也给省去了,就是方便。然后就可以直接使用这个载入进来的计算图进行图像特征的提取了。

import os
import sys
import tensorflow as tf
from tensorflow import gfile
from tensorflow import logging
import pprint
import pickle
import numpy as np

# 卷积神经网络 路径
model_file = "./data/checkpoint_inception_v3/inception_v3_graph_def.pb"
# 图像描述文件
input_description_file = "./data/results_20130124.token"
# 图像样本保存目录
input_img_dir = "./data/flickr30k_images/"
# 图像特征提取 保存目录
output_folder = "./data/download_inception_v3_features"

batch_size = 1000 # 

if not gfile.Exists(output_folder):
    gfile.MakeDirs(output_folder)


def parse_token_file(token_file):
    '''
    做一个 从图像到描述的字典 {图像:[描述1, 描述2, ……]}
    :param token_file: token文件
    :return: 字典
    '''
    img_name_to_tokens = {}
    with gfile.GFile(token_file, 'r') as f:
        lines = f.readlines()

    for line in lines:
        img_id, description = line.strip('\r\n').split('\t')
        img_name, _ = img_id.split('#')
        img_name_to_tokens.setdefault(img_name, [])
        img_name_to_tokens[img_name].append(description)

    return img_name_to_tokens


img_name_to_tokens = parse_token_file(input_description_file)
all_img_names = list(img_name_to_tokens.keys()) # 获得图像名称

logging.info('num of all images: %d' % len(all_img_names)) # 图像总数

def load_pretrained_inception_v3(model_file):
    '''
    导入预训练好的计算图(这里需要重点关注一下)
    :param model_file: 计算图路径
    :return: none
    '''
    with  gfile.FastGFile(model_file, 'rb') as f:
        graph_def = tf.GraphDef() # 建立一个空的计算图
        graph_def.ParseFromString(f.read()) # 将文件内容解析到这个空的计算图中去

        _ = tf.import_graph_def(graph_def, name='') # 将计算图导入到默认的计算图中去

load_pretrained_inception_v3(model_file)

# 确定 batch_size 的 个数
num_batches = int(len(all_img_names) / batch_size)
if len(all_img_names) % batch_size != 0: # 也就是说,无法整除的话
    num_batches += 1

with tf.Session() as sess:
    # 通过名称 拿出某一层的 特征图
    second_to_last_tensor = sess.graph.get_tensor_by_name("pool_3:0")
    for i in range(num_batches):
        batch_img_names = all_img_names[i*batch_size: (i+1)*batch_size]
        batch_features = []
        for img_name in batch_img_names:
            img_path = os.path.join(input_img_dir, img_name)

            if not gfile.Exists(img_path):
                raise Exception("%s doesn't exists" % img_path)

            logging.info('processing img %s' % img_name)

            # tf.gfile.FastGFile(path, decodestyle)
            # 函数功能:实现对图片的读取。
            if not gfile.Exists(img_path):
                raise Exception("%s doesn't exists" % img_path)
            img_data = gfile.FastGFile(img_path, 'rb').read()
            feature_vector = sess.run(second_to_last_tensor,
                                      feed_dict={
                                          'DecodeJpeg/contents:0': img_data
                                      }
                                      )

            # 此刻的 feature_vector : Tensor("pool_3:0", shape=(1, 1, 1, 2048), dtype=float32)
            batch_features.append(feature_vector)

        batch_features = np.vstack((batch_features))
        output_filename = os.path.join(
            output_folder,
            'image_features-%d.pickle' % i
        )
        logging.info('writing to file %s' % output_filename)
        with gfile.GFile(output_filename, 'w') as f:
            pickle.dump((batch_img_names, batch_features), f)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章