將TensorFlow訓練好的模型遷移到Android APP上(TensorFlowLite)

將TensorFlow訓練好的模型遷移到Android APP上(TensorFlowLite)

1. 寫在前面


  最近在做一個數字手勢識別的APP(關於這個項目,我會再寫一篇博客仔細介紹,博客地址:一步步做一個數字手勢識別APP,源代碼已經開源在github上,地址:Chinese-number-gestures-recognition),要把在PC端訓練好的模型放到Android APP上,調研了下,谷歌發佈了TensorFlow Lite可以把TensorFlow訓練好的模型遷移到Android APP上,百度也發佈了移動端深度學習框架mobile-deep-learning(MDL),這個框架應該是paddlepaddle的手機版,具體的細節沒有了解過。因爲對TensorFlow稍微熟悉些,因此就決定用TensorFlow來做。
  關於在PC端如何處理數據及訓練模型,請參見博客:一步步做一個數字手勢識別APP,代碼已經開源在github上,上面有代碼的說明和APP演示。這篇博客只介紹如何把TensorFlow訓練好的模型遷移到Android Studio上進行APP的開發。

2. 模型訓練注意事項


  第一步,首先在pc端訓練模型的時候要模型保存爲.pb模型,在保存的時候有一點非常非常重要,就是你待會再Android studio是使用這個模型用到哪個參數,那麼你在保存pb模型的時候就把給哪個參數一個名字,再保存。否則,你在Android studio中很難拿出這個參數,因爲TensorFlow Lite的fetch()函數是根據保存在pb模型中的名字去尋找這個參數的。(如果你已經訓練好了模型,並且沒有給參數名字,且你不想再訓練模型了,那麼你可以嘗試下面的方法去找到你需要使用的變量的默認名字,見下面的代碼):

#輸出保存的模型中參數名字及對應的值
with tf.gfile.GFile('model_50_200_c3//./digital_gesture.pb', "rb") as f:  #讀取模型數據
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) #得到模型中的計算圖和數據
with tf.Graph().as_default() as graph:  # 這裏的Graph()要有括號,不然會報TypeError
    tf.import_graph_def(graph_def, name="")  #導入模型中的圖到現在這個新的計算圖中,不指定名字的話默認是 import
    for op in graph.get_operations():  # 打印出圖中的節點信息
        print(op.name, op.values())

這段代碼打出的變量的名字以及對應的值。

言歸正傳,通常情況該你應該保存參數的時候都給參數一個指定的名字,如下面這樣(通過name參數給變量指定名字),關於訓練CNN的完整代碼請參見下一篇博客或者github:

X = tf.placeholder(tf.float32, [None, 64, 64, 3], name="input_x")
y = tf.placeholder(tf.float32, [None, 11], name="input_y")
kp = tf.placeholder_with_default(1.0, shape=(), name="keep_prob")
lam = tf.placeholder(tf.float32, name="lamda")
#中間略過若干代碼
z_fc2 = tf.add(tf.matmul(z_fc1_drop, W_fc2),b_fc2, name="outlayer")
prob = tf.nn.softmax(z_fc2, name="probability")
pred = tf.argmax(prob, 1, output_type="int32", name="predict")

3. 在Android Studio中配置


  第二步,開始把pb模型移植到Android Studio上,網上絕大部分資料都是說用bazel重新編譯模型生成依賴,這種方法難度太大。其實沒必須這樣做,TensorFlow Lite官方的例子中已經給我們展示了,我們其實只需要兩個文件:libandroid_tensorflow_inference_java.jar 和 libtensorflow_inference.so。這兩個文件我已經放到github上了,大家可以自行下載使用,下載地址:libandroid_tensorflow_inference_java.jarlibtensorflow_inference.so

注:檢神說,直接用aar依賴也可以,這個我沒試過。。有興趣的可以試一下。

準備工作已經完畢,下面正式開始Android Studio中的配置。

  首先把訓練好的pb模型放到Android項目中app/src/main/assets下,若不存在assets目錄,則自己新建一個。如圖所示:

pb模型目錄

 

  其次,把剛剛下載的 libandroid_tensorflow_inference_java.jar 文件放到 app/libs 目下,把libtensorflow_inference.so 放到 app/libs/armeabi-v7a 目錄下,如下圖所示:

TensorFlow依賴目錄

 

然後在app/build.gradle裏進行如下配置:
  在defaultConfig裏添加

multiDexEnabled true
        ndk {
            abiFilters "armeabi-v7a"
        }

  在android裏添加

 sourceSets {
        main {
            jni.srcDirs = []
            jniLibs.srcDirs = ['libs']
        }
    }

如圖所示:

配置文件

 

  在dependencies中添加libandroid_tensorflow_inference_java.jar,即:

implementation files('libs/libandroid_tensorflow_inference_java.jar')
  • 1

如圖所示:

dependencies

 

至此,所有配置已經完成,下面是模型調用。

4. 在Android Studio中調用模型


在要用到模型的地方,首先要加載libtensorflow_inference.so庫和初始化TensorFlowInferenceInterface對象,代碼爲:

TensorFlowInferenceInterface inferenceInterface;

    static {
        //加載libtensorflow_inference.so庫文件
        System.loadLibrary("tensorflow_inference");
        Log.e("tensorflow","libtensorflow_inference.so庫加載成功");
    }
    Classifier(AssetManager assetManager, String modePath) {
        //初始化TensorFlowInferenceInterface對象
        inferenceInterface = new TensorFlowInferenceInterface(assetManager,modePath);
        Log.e("tf","TensoFlow模型文件加載成功");
    }

如圖所示:

加載TensorFlow依賴庫

 

下面來多看一點東西,看看TensorFlow Lite裏提供了哪幾個接口,官網地址:Here’s what a typical Inference Library sequence looks like on Android.

// Load the model from disk.
TensorFlowInferenceInterface inferenceInterface =
new TensorFlowInferenceInterface(assetManager, modelFilename);

// Copy the input data into TensorFlow.
inferenceInterface.feed(inputName, floatValues, 1, inputSize, inputSize, 3);

// Run the inference call.
inferenceInterface.run(outputNames, logStats);

// Copy the output Tensor back into the output array.
inferenceInterface.fetch(outputName, outputs);

下面就可以愉快地使用模型了。放一段我調用模型的代碼,以供大家參考:

public ArrayList predict(Bitmap bitmap)
    {
        ArrayList<String> list = new ArrayList<>();
        float[] inputdata = getPixels(bitmap);
        for(int i = 0; i <30; ++i)
        {
            Log.d("matrix",inputdata[i] + "");
        }
        inferenceInterface.feed(inputName, inputdata, 1, IMAGE_SIZE, IMAGE_SIZE, 3);
        //運行模型,run的參數必須是String[]類型
        String[] outputNames = new String[]{outputName,probabilityName,outlayerName};
        inferenceInterface.run(outputNames);
        //獲取結果
        int[] labels = new int[1];
        inferenceInterface.fetch(outputName,labels);
        int label = labels[0];
        float[] prob = new float[11];
        inferenceInterface.fetch(probabilityName, prob);
//        float[] outlayer = new float[11];
//        inferenceInterface.fetch(outlayerName, outlayer);

//        for(int i = 0; i <11; ++i)
//        {
//            Log.d("matrix",outlayer[i] + "");
//        }
        for(int i = 0; i <11; ++i)
        {
            Log.d("matrix",prob[i] + "");
        }
        DecimalFormat df = new DecimalFormat("0.000000");
        float label_prob = prob[label];
        //返回值
        list.add(Integer.toString(label));
        list.add(df.format(label_prob));

        return list;
    }

最後放一張做的數字手勢識別APP的效果,全部代碼,將會開源在github上,歡迎star。

識別結果

 

再放一張碰運氣的識別結果:

碰巧識別

--------------------- 本文來自 天澤28 的CSDN 博客 ,全文地址請點擊:https://blog.csdn.net/u012328159/article/details/81101074?utm_source=copy

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章