java調用文本分類textrnn模型,勿踩坑

直接貼代碼
# 將模型保存爲可用於線上服務的文件(一個.pb文件,一個variables文件夾)
# print('Exporting trained model to', save_dir)
builder = tf.saved_model.builder.SavedModelBuilder(save_dir)

# 服務器專用代碼

classification_signature = (
    tf.saved_model.signature_def_utils.build_signature_def(
        inputs={
            # "image"
            "input_x":
                tf.saved_model.utils.build_tensor_info(rnn.input_x),
            "dropout_keep_prob":
                tf.saved_model.utils.build_tensor_info(rnn.dropout_keep_prob)
        },
        outputs={
            # "classify"
            "output":
                tf.saved_model.utils.build_tensor_info(rnn.predictions)
            # "classification_outputs_scores":
            #     tf.saved_model.utils.build_tensor_info(model.logits)
        },
        # Prediction method name used in a SignatureDef.
        # PREDICT_METHOD_NAME = "tensorflow/serving/predict"
        method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))

legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')

builder.add_meta_graph_and_variables(
    # saved_model.tag_constants.SERVING = "saved_model.tag_constants.SERVING"
    sess, [tf.saved_model.tag_constants.SERVING],
    # 保存模型的方法名,與客戶端的request.model_spec.signature_name對應
    signature_def_map={
        # "predict_image"
        "classification":
            classification_signature},
    legacy_init_op=legacy_init_op)

builder.save()

1、"input_x","output"這個千萬不要亂寫,因爲你java調用的時候必須前後這個命名一致,否則會導致java調用模型預測結果與 python模型結果存在很大的差別

2、rnn是你的模型,rnn = TextRnn(config)

3、rnn.dropout_keep_prob是你的drop的命名方式,一定得和後續的一致

4、rnn.dropout_keep_prob與3一樣

5、python跑模型的tensorflow的版本必須和java調用的版本一樣!!!

模型格式如下:

現在模型準備好了就開始java調用了:

SavedModelBundle modelBundle = SavedModelBundle.load(path,"serve");

Session tfSession = modelBundle.session();

Operation operationPredict = modelBundle.graph().operation("output/predictions");

Output output = new Output(operationPredict,0);

Tensor keep_prob = Tensor.create(Float.parseFloat("1.0"));

“path”是你模型保存的路徑

"output/predictions"和python中的命令相對應,一定得一樣,不要亂命名,例如output_y,絕對結果出錯

下一步對於輸入“很幸運遇見你”,python獲得的word_to_index文件把輸入轉換成相對應的位置標籤a,

Tensor input_x = Tensor.create(a);
Tensor out = tfSession.runner().feed("input_x", input_x).feed("dropout_keep_prob",keep_prob).fetch(output).run().get(0);

轉成輸入的tensor,a是二維向量;"dropout_keep_prob"與之前的相對應,不要亂寫!!!,keep_prob預測的時候就設置成1吧,訓練的時候可以隨機關閉一半左右,但測試的時候你需要全用的。

long[] temp = new long[1];
out.copyTo(temp);
short reskey = (short) temp[0];

獲取對應的分類座標,你訓練的時候會獲得每個類別對應的座標,然後根據上面獲得的reskey去獲得相應的類別就ok了!

 

 

總結以上主要有幾個點:

   1、 不要python生成pb時的參數命名和java調用的時候不一致

   2、python和java的tensorflow版本必須一致,1.10和1.12都會報錯

   3、輸入轉換成座標向量和類別座標這兩個map中的對應順序不要錯了

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章