Tensorflow圖像生成文本(3)圖像特徵的提取

前言

對於圖像生成文本來說,除了預處理文本信息,還要對圖像特徵進行提前提取。這裏單獨編寫一個腳本,對所有圖像特徵進行提取,然後將提取出來的特徵保存在一個目錄中。

代碼實現

這裏使用到了預訓練好的 inception_v3 模型,並且是一個帶有默認圖的 inception_v3 模型。inception_v3_graph_def.pb 該模型不光帶有模型參數,同時還帶有計算圖,將這個計算圖載入到自己的計算圖中,連編寫計算圖的過程也給省去了,就是方便。然後就可以直接使用這個載入進來的計算圖進行圖像特徵的提取了。

import os
import sys
import tensorflow as tf
from tensorflow import gfile
from tensorflow import logging
import pprint
import pickle
import numpy as np

# 卷積神經網絡 路徑
model_file = "./data/checkpoint_inception_v3/inception_v3_graph_def.pb"
# 圖像描述文件
input_description_file = "./data/results_20130124.token"
# 圖像樣本保存目錄
input_img_dir = "./data/flickr30k_images/"
# 圖像特徵提取 保存目錄
output_folder = "./data/download_inception_v3_features"

batch_size = 1000 # 

if not gfile.Exists(output_folder):
    gfile.MakeDirs(output_folder)


def parse_token_file(token_file):
    '''
    做一個 從圖像到描述的字典 {圖像:[描述1, 描述2, ……]}
    :param token_file: token文件
    :return: 字典
    '''
    img_name_to_tokens = {}
    with gfile.GFile(token_file, 'r') as f:
        lines = f.readlines()

    for line in lines:
        img_id, description = line.strip('\r\n').split('\t')
        img_name, _ = img_id.split('#')
        img_name_to_tokens.setdefault(img_name, [])
        img_name_to_tokens[img_name].append(description)

    return img_name_to_tokens


img_name_to_tokens = parse_token_file(input_description_file)
all_img_names = list(img_name_to_tokens.keys()) # 獲得圖像名稱

logging.info('num of all images: %d' % len(all_img_names)) # 圖像總數

def load_pretrained_inception_v3(model_file):
    '''
    導入預訓練好的計算圖(這裏需要重點關注一下)
    :param model_file: 計算圖路徑
    :return: none
    '''
    with  gfile.FastGFile(model_file, 'rb') as f:
        graph_def = tf.GraphDef() # 建立一個空的計算圖
        graph_def.ParseFromString(f.read()) # 將文件內容解析到這個空的計算圖中去

        _ = tf.import_graph_def(graph_def, name='') # 將計算圖導入到默認的計算圖中去

load_pretrained_inception_v3(model_file)

# 確定 batch_size 的 個數
num_batches = int(len(all_img_names) / batch_size)
if len(all_img_names) % batch_size != 0: # 也就是說,無法整除的話
    num_batches += 1

with tf.Session() as sess:
    # 通過名稱 拿出某一層的 特徵圖
    second_to_last_tensor = sess.graph.get_tensor_by_name("pool_3:0")
    for i in range(num_batches):
        batch_img_names = all_img_names[i*batch_size: (i+1)*batch_size]
        batch_features = []
        for img_name in batch_img_names:
            img_path = os.path.join(input_img_dir, img_name)

            if not gfile.Exists(img_path):
                raise Exception("%s doesn't exists" % img_path)

            logging.info('processing img %s' % img_name)

            # tf.gfile.FastGFile(path, decodestyle)
            # 函數功能:實現對圖片的讀取。
            if not gfile.Exists(img_path):
                raise Exception("%s doesn't exists" % img_path)
            img_data = gfile.FastGFile(img_path, 'rb').read()
            feature_vector = sess.run(second_to_last_tensor,
                                      feed_dict={
                                          'DecodeJpeg/contents:0': img_data
                                      }
                                      )

            # 此刻的 feature_vector : Tensor("pool_3:0", shape=(1, 1, 1, 2048), dtype=float32)
            batch_features.append(feature_vector)

        batch_features = np.vstack((batch_features))
        output_filename = os.path.join(
            output_folder,
            'image_features-%d.pickle' % i
        )
        logging.info('writing to file %s' % output_filename)
        with gfile.GFile(output_filename, 'w') as f:
            pickle.dump((batch_img_names, batch_features), f)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章