將TensorFlow訓練好的模型遷移到Android APP上(TensorFlowLite)
1. 寫在前面
最近在做一個數字手勢識別的APP(關於這個項目,我會再寫一篇博客仔細介紹,博客地址:一步步做一個數字手勢識別APP,源代碼已經開源在github上,地址:Chinese-number-gestures-recognition),要把在PC端訓練好的模型放到Android APP上,調研了下,谷歌發佈了TensorFlow Lite可以把TensorFlow訓練好的模型遷移到Android APP上,百度也發佈了移動端深度學習框架mobile-deep-learning(MDL),這個框架應該是paddlepaddle的手機版,具體的細節沒有了解過。因爲對TensorFlow稍微熟悉些,因此就決定用TensorFlow來做。
關於在PC端如何處理數據及訓練模型,請參見博客:一步步做一個數字手勢識別APP,代碼已經開源在github上,上面有代碼的說明和APP演示。這篇博客只介紹如何把TensorFlow訓練好的模型遷移到Android Studio上進行APP的開發。
2. 模型訓練注意事項
第一步,首先在pc端訓練模型的時候要模型保存爲.pb模型,在保存的時候有一點非常非常重要,就是你待會再Android studio是使用這個模型用到哪個參數,那麼你在保存pb模型的時候就把給哪個參數一個名字,再保存。否則,你在Android studio中很難拿出這個參數,因爲TensorFlow Lite的fetch()函數是根據保存在pb模型中的名字去尋找這個參數的。(如果你已經訓練好了模型,並且沒有給參數名字,且你不想再訓練模型了,那麼你可以嘗試下面的方法去找到你需要使用的變量的默認名字,見下面的代碼):
#輸出保存的模型中參數名字及對應的值
with tf.gfile.GFile('model_50_200_c3//./digital_gesture.pb', "rb") as f: #讀取模型數據
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read()) #得到模型中的計算圖和數據
with tf.Graph().as_default() as graph: # 這裏的Graph()要有括號,不然會報TypeError
tf.import_graph_def(graph_def, name="") #導入模型中的圖到現在這個新的計算圖中,不指定名字的話默認是 import
for op in graph.get_operations(): # 打印出圖中的節點信息
print(op.name, op.values())
這段代碼打出的變量的名字以及對應的值。
言歸正傳,通常情況該你應該保存參數的時候都給參數一個指定的名字,如下面這樣(通過name參數給變量指定名字),關於訓練CNN的完整代碼請參見下一篇博客或者github:
X = tf.placeholder(tf.float32, [None, 64, 64, 3], name="input_x")
y = tf.placeholder(tf.float32, [None, 11], name="input_y")
kp = tf.placeholder_with_default(1.0, shape=(), name="keep_prob")
lam = tf.placeholder(tf.float32, name="lamda")
#中間略過若干代碼
z_fc2 = tf.add(tf.matmul(z_fc1_drop, W_fc2),b_fc2, name="outlayer")
prob = tf.nn.softmax(z_fc2, name="probability")
pred = tf.argmax(prob, 1, output_type="int32", name="predict")
3. 在Android Studio中配置
第二步,開始把pb模型移植到Android Studio上,網上絕大部分資料都是說用bazel重新編譯模型生成依賴,這種方法難度太大。其實沒必須這樣做,TensorFlow Lite官方的例子中已經給我們展示了,我們其實只需要兩個文件:libandroid_tensorflow_inference_java.jar 和 libtensorflow_inference.so。這兩個文件我已經放到github上了,大家可以自行下載使用,下載地址:libandroid_tensorflow_inference_java.jar、libtensorflow_inference.so。
注:檢神說,直接用aar依賴也可以,這個我沒試過。。有興趣的可以試一下。
準備工作已經完畢,下面正式開始Android Studio中的配置。
首先把訓練好的pb模型放到Android項目中app/src/main/assets下,若不存在assets目錄,則自己新建一個。如圖所示:
其次,把剛剛下載的 libandroid_tensorflow_inference_java.jar 文件放到 app/libs 目下,把libtensorflow_inference.so 放到 app/libs/armeabi-v7a 目錄下,如下圖所示:
然後在app/build.gradle裏進行如下配置:
在defaultConfig裏添加
multiDexEnabled true
ndk {
abiFilters "armeabi-v7a"
}
在android裏添加
sourceSets {
main {
jni.srcDirs = []
jniLibs.srcDirs = ['libs']
}
}
如圖所示:
在dependencies中添加libandroid_tensorflow_inference_java.jar,即:
implementation files('libs/libandroid_tensorflow_inference_java.jar')
- 1
如圖所示:
至此,所有配置已經完成,下面是模型調用。
4. 在Android Studio中調用模型
在要用到模型的地方,首先要加載libtensorflow_inference.so庫和初始化TensorFlowInferenceInterface對象,代碼爲:
TensorFlowInferenceInterface inferenceInterface;
static {
//加載libtensorflow_inference.so庫文件
System.loadLibrary("tensorflow_inference");
Log.e("tensorflow","libtensorflow_inference.so庫加載成功");
}
Classifier(AssetManager assetManager, String modePath) {
//初始化TensorFlowInferenceInterface對象
inferenceInterface = new TensorFlowInferenceInterface(assetManager,modePath);
Log.e("tf","TensoFlow模型文件加載成功");
}
如圖所示:
下面來多看一點東西,看看TensorFlow Lite裏提供了哪幾個接口,官網地址:Here’s what a typical Inference Library sequence looks like on Android.
// Load the model from disk.
TensorFlowInferenceInterface inferenceInterface =
new TensorFlowInferenceInterface(assetManager, modelFilename);
// Copy the input data into TensorFlow.
inferenceInterface.feed(inputName, floatValues, 1, inputSize, inputSize, 3);
// Run the inference call.
inferenceInterface.run(outputNames, logStats);
// Copy the output Tensor back into the output array.
inferenceInterface.fetch(outputName, outputs);
下面就可以愉快地使用模型了。放一段我調用模型的代碼,以供大家參考:
public ArrayList predict(Bitmap bitmap)
{
ArrayList<String> list = new ArrayList<>();
float[] inputdata = getPixels(bitmap);
for(int i = 0; i <30; ++i)
{
Log.d("matrix",inputdata[i] + "");
}
inferenceInterface.feed(inputName, inputdata, 1, IMAGE_SIZE, IMAGE_SIZE, 3);
//運行模型,run的參數必須是String[]類型
String[] outputNames = new String[]{outputName,probabilityName,outlayerName};
inferenceInterface.run(outputNames);
//獲取結果
int[] labels = new int[1];
inferenceInterface.fetch(outputName,labels);
int label = labels[0];
float[] prob = new float[11];
inferenceInterface.fetch(probabilityName, prob);
// float[] outlayer = new float[11];
// inferenceInterface.fetch(outlayerName, outlayer);
// for(int i = 0; i <11; ++i)
// {
// Log.d("matrix",outlayer[i] + "");
// }
for(int i = 0; i <11; ++i)
{
Log.d("matrix",prob[i] + "");
}
DecimalFormat df = new DecimalFormat("0.000000");
float label_prob = prob[label];
//返回值
list.add(Integer.toString(label));
list.add(df.format(label_prob));
return list;
}
最後放一張做的數字手勢識別APP的效果,全部代碼,將會開源在github上,歡迎star。
再放一張碰運氣的識別結果:
--------------------- 本文來自 天澤28 的CSDN 博客 ,全文地址請點擊:https://blog.csdn.net/u012328159/article/details/81101074?utm_source=copy