// 下載模型 private static final SavedModelBundle modelBundle = SavedModelBundle.load(模型路徑,"serve"); // session private static final Session tfSession = modelBundle.session(); // 預測 private static final Operation operationPredict = modelBundle.graph().operation("score/my_predict"); // 輸出 private static final Output output = new Output(operationPredict,0); // Tensor private static final Tensor keep_prob = Tensor.create(Float.parseFloat("0.6"));
// 輸入轉換成向量,a爲向量
Tensor input_x = Tensor.create(a); // 預測輸出,類目對應編號 Tensor out = tfSession.runner().feed("input_x", input_x).feed("keep_prob",keep_prob).fetch(output).run().get(0); long[] temp = new long[1]; out.copyTo(temp); short reskey = (short) temp[0];
調用線上模型:
// Input length --- build the request vector: a list of 輸入向量長度 elements, zero-filled here
// (presumably to be replaced with the real feature values — confirm with the caller).
List<Object> requestData = new ArrayList<>(Collections.nCopies(輸入向量長度, 0));
// Prediction result returned by the online service identified by 線上服務編號.
Object responseData = iWpaiDLPredictOnlineService.tensorflowServingPredictOnline(線上服務編號, requestData);