如何合併兩個TensorFlow模型

這是Tensorflow SavedModel模型系列文章的第三篇,也是終章。在《Tensorflow SavedModel模型的保存與加載》中,我們談到了Tensorflow模型如何保存爲SavedModel格式,以及如何加載之。在《如何查看tensorflow SavedModel格式模型的信息》中,我們演示瞭如何查看模型的signature和計算圖結構。在本文中,我們將探討如何合併兩個模型,簡單的說,就是將第一個模型的輸出,作爲第二個模型的輸入,串聯起來形成一個新模型。

背景

爲什麼需要合併兩個模型?

我們還是以《Tensorflow SavedModel模型的保存與加載》中的代碼爲例,這個手寫數字識別模型接收的輸入是shape爲[?, 784],這裏?代表可以批量接收輸入,可以先忽略,就把它固定爲1吧。784是28 x 28進行展開的結果,也就是28 x 28灰度圖像展開的結果。

問題是,我們送給模型的通常是圖片,可能來自文件、可能來自攝像頭。讓問題變得複雜的是,如果我們通過HTTP來調用部署到服務器端的模型,二進制數據實際上是不方便HTTP傳輸的,這時我們通常需要對圖像數據進行base64編碼。這樣服務器端接收到的數據是一個base64字符串,可模型接受的是二進制向量。

很自然的,我們可以想到兩種解決方法:

  1. 重新訓練模型一個接收base64字符串的模型。

    這種解決方法的問題在於:重新訓練模型很費時,甚至不可行。本文示例因爲比較簡單,重新訓練也沒啥。如果是那種很深的卷積神經網絡,訓練一次可能需要好幾天,重新訓練代價很大。更普遍的情況是,我們使用的是別人訓練好的模型,比如圖像識別中普遍使用的Mobilenet、InceptionV3等等,都是Google、微軟這樣的公司,耗費大量的資源訓練出來的,我們沒有那個條件重新訓練。

  2. 在服務器端增加base64到二進制數據的轉換

    這種解決方法實現起來不復雜,但如果我們使用的是Tensorflow model server之類的方案部署的呢?當然我們也可以再開啓一個server,來接受客戶端的base64圖像數據,處理完畢之後再轉發給Tensorflow model server,但這無疑增加了服務端的工作量,增加了服務端的複雜性。

在本文,我們將給出第三種方案:編寫一個Tensorflow模型,接收base64的圖像數據,輸出二進制向量,然後將第一個模型的輸出作爲第二個模型的輸入,串接起來,保存爲一個新的模型,最後部署新的模型。

base64解碼Tensorflow模型

Tensorflow包含了大量圖像處理和數組處理的方法,所以實現這個模型比較簡單,模型包含了base64解碼、解碼PNG圖像、縮放到28 * 28、最後展開爲(1, 784)的數組輸出,符合手寫數字識別模型的輸入,代碼如下:

with tf.Graph().as_default() as g1:
  base64_str = tf.placeholder(tf.string, name='input_string')
  input_str = tf.decode_base64(base64_str)
  decoded_image = tf.image.decode_png(input_str, channels=1)
  # Convert from full range of uint8 to range [0,1] of float32.
  decoded_image_as_float = tf.image.convert_image_dtype(decoded_image,
                                                        tf.float32)
  decoded_image_4d = tf.expand_dims(decoded_image_as_float, 0)
  resize_shape = tf.stack([28, 28])
  resize_shape_as_int = tf.cast(resize_shape, dtype=tf.int32)
  resized_image = tf.image.resize_bilinear(decoded_image_4d,
                                           resize_shape_as_int)
  # 展開爲1維數組
  resized_image_1d = tf.reshape(resized_image, (-1, 28 * 28))
  print(resized_image_1d.shape)
  tf.identity(resized_image_1d, name="DecodeJPGOutput")

g1def = g1.as_graph_def()

在該模型中,並不存在變量,都是一些固定的操作,所以無需進行訓練。

加載手寫識別模型

手寫識別模型參考《Tensorflow SavedModel模型的保存與加載》一文,模型保存在 “./model” 下,加載代碼如下:

with tf.Graph().as_default() as g2:
  with tf.Session(graph=g2) as sess:
    input_graph_def = saved_model_utils.get_meta_graph_def(
        "./model", tag_constants.SERVING).graph_def

    tf.saved_model.loader.load(sess, ["serve"], "./model")

    g2def = graph_util.convert_variables_to_constants(
        sess,
        input_graph_def,
        ["myOutput"],
        variable_names_whitelist=None,
        variable_names_blacklist=None)

這裏使用了g2定義了另外一個graph,和前面的模型的graph區分開來。注意這裏調用了graph_util.convert_variables_to_constants將模型中的變量轉化爲常量,也就是所謂的凍結圖(freeze graph)操作。

在研究如何連接兩個模型時,我在這個問題上卡了很久。先的想法是合併模型之後,再加載變量值進來,但是嘗試之後,怎麼也不成功。後來的想法是遍歷手寫識別模型的變量,獲取其變量值,將變量值複製到合併的模型的變量,但這樣操作,使用模型時,總是提示有變量未初始化。

最後從Tensorflow模型到Tensorflow lite模型轉換中獲得了靈感,將模型中的變量固定下來,這樣就不存在變量的加載問題,也不會出現模型變量未初始化的問題。

執行convert_variables_to_constants後,可以看到有兩個變量轉化爲了常量操作,也就是手寫數字識別模型中的w和b:

Converted 2 variables to const ops.

連接兩個模型

利用tf.import_graph_def方法,我們可以導入圖到現有圖中,注意第二個import_graph_def,其input是第一個graph_def的輸出,通過這樣的操作,就將兩個計算圖連接起來,最後保存起來。代碼如下:

with tf.Graph().as_default() as g_combined:
  with tf.Session(graph=g_combined) as sess:

    x = tf.placeholder(tf.string, name="base64_input")

    y, = tf.import_graph_def(g1def, input_map={"input_string:0": x}, return_elements=["DecodeJPGOutput:0"])

    z, = tf.import_graph_def(g2def, input_map={"myInput:0": y}, return_elements=["myOutput:0"])
    tf.identity(z, "myOutput")

    tf.saved_model.simple_save(sess,
              "./modelbase64",
              inputs={"base64_input": x},
              outputs={"myOutput": z})

因爲第一個模型不包含變量,第二個模型的變量轉化爲了常量操作,所以最後保存的模型文件並不包含變量:

modelbase64/
├── saved_model.pb
└── variables

1 directory, 1 file

測試

我們寫一段測試代碼,測試一下合併之後模型是否管用,代碼如下:

with tf.Session(graph=tf.Graph()) as sess:
  sess.run(tf.global_variables_initializer())

  tf.saved_model.loader.load(sess, ["serve"], "./modelbase64")
  graph = tf.get_default_graph()

  with open("./5.png", "rb") as image_file:
    encoded_string = str(base64.urlsafe_b64encode(image_file.read()), "utf-8")

  x = sess.graph.get_tensor_by_name('base64_input:0')
  y = sess.graph.get_tensor_by_name('myOutput:0')

  scores = sess.run(y,
           feed_dict={x: encoded_string})
  print("predict: %d, actual: %d" % (np.argmax(scores, 1), 5))

這裏模型的輸入爲base64_input,輸出仍然是myOutput,使用兩個圖片測試,均工作正常。

小結

最近三篇文章其實都是在研究我的微信小程序時總結的,爲了更好的說明問題,我使用了一個非常簡單的模型來說明問題,但同樣適用於複雜的模型。

本文的完整代碼請參考:https://github.com/mogoweb/aiexamples/tree/master/tensorflow/saved_model

希望這篇文章對您有幫助,感謝閱讀!同時敬請關注我的微信公衆號:雲水木石。

image

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章