直接貼代碼 # 將模型保存爲可用於線上服務的文件(一個.pb文件,一個variables文件夾) # print('Exporting trained model to', save_dir) builder = tf.saved_model.builder.SavedModelBuilder(save_dir) # 服務器專用代碼 classification_signature = ( tf.saved_model.signature_def_utils.build_signature_def( inputs={ # "image" "input_x": tf.saved_model.utils.build_tensor_info(rnn.input_x), "dropout_keep_prob": tf.saved_model.utils.build_tensor_info(rnn.dropout_keep_prob) }, outputs={ # "classify" "output": tf.saved_model.utils.build_tensor_info(rnn.predictions) # "classification_outputs_scores": # tf.saved_model.utils.build_tensor_info(model.logits) }, # Prediction method name used in a SignatureDef. # PREDICT_METHOD_NAME = "tensorflow/serving/predict" method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)) legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op') builder.add_meta_graph_and_variables( # saved_model.tag_constants.SERVING = "saved_model.tag_constants.SERVING" sess, [tf.saved_model.tag_constants.SERVING], # 保存模型的方法名,與客戶端的request.model_spec.signature_name對應 signature_def_map={ # "predict_image" "classification": classification_signature}, legacy_init_op=legacy_init_op) builder.save()
1、"input_x","output"這個千萬不要亂寫,因爲你java調用的時候必須前後這個命名一致,否則會導致java調用模型預測結果與 python模型結果存在很大的差別
2、rnn是你的模型,rnn = TextRnn(config)
3、rnn.dropout_keep_prob是你的drop的命名方式,一定得和後續的一致
4、rnn.dropout_keep_prob與3一樣
5、python跑模型的tensorflow的版本必須和java調用的版本一樣!!!
模型格式如下:
現在模型準備好了就開始java調用了:
SavedModelBundle modelBundle = SavedModelBundle.load(path,"serve"); Session tfSession = modelBundle.session(); Operation operationPredict = modelBundle.graph().operation("output/predictions"); Output output = new Output(operationPredict,0); Tensor keep_prob = Tensor.create(Float.parseFloat("1.0"));
“path”是你模型保存的路徑
"output/predictions"和python中的命令相對應,一定得一樣,不要亂命名,例如output_y,絕對結果出錯
下一步對於輸入“很幸運遇見你”,python獲得的word_to_index文件把輸入轉換成相對應的位置標籤a,
Tensor input_x = Tensor.create(a); Tensor out = tfSession.runner().feed("input_x", input_x).feed("dropout_keep_prob",keep_prob).fetch(output).run().get(0);
轉成輸入的tensor,a是二維向量;"dropout_keep_prob"與之前的相對應,不要亂寫!!!,keep_prob預測的時候就設置成1吧,訓練的時候可以隨機關閉一半左右,但測試的時候你需要全用的。
long[] temp = new long[1]; out.copyTo(temp); short reskey = (short) temp[0];
獲取對應的分類座標,你訓練的時候會獲得每個類別對應的座標,然後根據上面獲得的reskey去獲得相應的類別就ok了!
總結以上主要有幾個點:
1、 不要python生成pb時的參數命名和java調用的時候不一致
2、python和java的tensorflow版本必須一致,1.10和1.12都會報錯
3、輸入轉換成座標向量和類別座標這兩個map中的對應順序不要錯了