TensorFlow for Java
WARNING: The TensorFlow Java API is not currently covered by the TensorFlow API stability guarantees.
目前,TensorFlow Java API 不在 TensorFlow API 穩定性保證的範圍內。
For using TensorFlow on Android refer instead to TensorFlow Lite.
關於在 Android 上使用 TensorFlow,請改為參考 TensorFlow Lite。
使用 Java API 進行預測
/**
 * Loads the frozen graph stored in {@code car_model.pb} and opens a session on it.
 *
 * <p>The model stream is obtained from {@code getStreamFromPb}, read fully into memory,
 * imported into a new {@link Graph}, and a {@link Session} bound to that graph is returned.
 * The caller owns the returned session and is responsible for closing it.
 *
 * @return a new session bound to the imported graph
 * @throws IllegalStateException if the model stream is {@code null} or cannot be read
 * @throws IllegalArgumentException if the bytes are rejected by {@code importGraphDef}
 */
private static Session loadSession() {
    final byte[] graphBytes;
    // try-with-resources closes the stream even when reading fails.
    try (InputStream is = getStreamFromPb("car_model.pb")) {
        if (is == null) {
            throw new IllegalStateException("Model stream for car_model.pb is null");
        }
        graphBytes = IOUtils.toByteArray(is); // load the local pb file into memory
    } catch (java.io.IOException e) {
        // Fail fast: importing an empty byte array would only defer the failure
        // to the first time a node is looked up by name.
        throw new IllegalStateException("Failed to read car_model.pb", e);
    }

    Graph graph = new Graph(); // create the graph structure
    try {
        graph.importGraphDef(graphBytes); // in-memory bytes -> graph structure
        return new Session(graph); // initialise the session from the graph
    } catch (RuntimeException e) {
        // Release the native graph if it never makes it into a usable session.
        graph.close();
        throw e;
    }
}
/**
 * Runs the graph on one image and returns the resulting embedding as a JSON string.
 *
 * <p>The image is read with {@code readImage}, passed through {@code whiten}, and placed in a
 * {@code [1][160][160][3]} float batch fed to the {@code "input"} node, with {@code false} fed
 * to {@code "phase_train"}. The {@code "embeddings"} output is copied into a
 * {@code [1][128]} float array.
 *
 * <p>Failures are handled best-effort: the exception is printed and the embedding stays at its
 * initial all-zero value, so the method still returns a JSON string.
 *
 * @param session   session on which the graph is executed
 * @param imagePath path of the image; also used as the key in the returned JSON object
 * @return JSON object string mapping {@code imagePath} to the first row of the embedding array
 */
private static String faceEmbedding(Session session, String imagePath) {
    float[][] embeddingsRes = new float[1][128];
    try {
        float[][][][] rgbFloat = new float[1][160][160][3];
        rgbFloat[0] = whiten(readImage(imagePath));

        // Tensors hold native memory; try-with-resources releases all three
        // (inputs and output) whether or not the run succeeds.
        try (Tensor<Float> imageTensor = Tensors.create(rgbFloat); // input tensor
             Tensor<Boolean> phaseTensor = Tensors.create(false); // input tensor
             Tensor<?> embeddings = session.runner()
                     .feed("input", imageTensor)
                     .feed("phase_train", phaseTensor)
                     .fetch("embeddings")
                     .run()
                     .get(0)) { // execute the graph, output tensor
            System.out.println("embeddings.toString(): " + embeddings.toString());
            embeddings.copyTo(embeddingsRes);
        }
    } catch (Exception e) {
        // Deliberately best-effort: report and fall through with the zero embedding.
        e.printStackTrace();
    }
    JSONObject json = new JSONObject();
    // Key by the imagePath parameter (the original referenced an undefined image_path).
    json.put(imagePath, embeddingsRes[0]);
    return json.toString();
}
參考:
https://github.com/tensorflow/tensorflow/tree/master/tensorflow
https://github.com/tensorflow/tensorflow/tree/master/tensorflow/java
https://www.jianshu.com/p/e11891418bc1