Tensorflow手寫數字識別在android中的實現

說明

下載TensorFlow Android Demo

git clone --recurse-submodules https://github.com/tensorflow/tensorflow.git

生成模型

運行附件壓縮包裏的python腳本convnet.py生成mnist_model_graph_convnet.pb文件和graph_label_strings.txt文件:
文件

編譯jar包和so庫

1. 下載TensorFlow Android Demo
git clone --recurse-submodules https://github.com/tensorflow/tensorflow.git

備註:

--recurse-submodules 是爲了避免一些protobuf 編譯問題.

2. 修改WORKSPACE文件,指定SDK、NDK的版本和路徑,請務必使用NDK r12b,下載路徑爲:
https://developer.android.com/ndk/downloads/older_releases.html  #ndk-12b-downloads

例如,我是這樣配置的:

android_sdk_repository(
    name = "androidsdk",
    api_level = 25,
    # Ensure that you have the build_tools_version below installed in the
    # SDK manager as it updates periodically.
    build_tools_version = "25.0.3",
    # Replace with path to Android SDK on your system
    path = "/home/ckt/work/Android/Sdk",
)
#
# Android NDK r12b is recommended (higher may cause issues with Bazel)
android_ndk_repository(
    name="androidndk",
    path="/home/ckt/work/Android/ndk-r12b/",
    # This needs to be 14 or higher to compile TensorFlow.
    # Note that the NDK version is not the API level.
    api_level=14)

3. 編譯jar包和so庫
編譯jar包和so庫需要構建工具Bazel,Ubuntu環境下如何安裝Bazel請參考網頁:

https://bazel.build/versions/master/docs/install-ubuntu.html

編譯jar包命令:

bazel build //tensorflow/contrib/android:android_tensorflow_inference_java


編譯完成後,可以在以下路徑找到libandroid_tensorflow_inference_java.jar文件:
bazel-bin/tensorflow/contrib/android/libandroid_tensorflow_inference_java.jar

編譯so庫命令:

bazel build -c opt //tensorflow/contrib/android:libtensorflow_inference.so \
  --crosstool_top=//external:android/crosstool \
  --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
  --cpu=armeabi-v7a 

###cpu一定要適配自己的手機,否則找不到so文件###


編譯完成後,可以在以下路徑找到libtensorflow_inference.so文件:
bazel-bin/tensorflow/contrib/android/libtensorflow_inference.so

編寫應用

1. 打開Android Studio,新建一個android工程
將jar包放入libs目錄,將so庫放入src/main/jniLibs/armeabi-v7a目錄,將之前生成的pb文件和text文件放入src/main/assets目錄

2. 將TensorFlow Android Demo中的Classifier.java和TensorFlowImageClassifier.java複製到工程,這2個文件在TensorFlow Android Demo中的的路徑爲:

/tensorflow/examples/android/src/org/tensorflow/demo

注意:
需要將這2個類的包名修改爲自己工程的包名。

3.爲了簡便操作,我們將下面的mnist_test.png(一張灰度圖,28×28像素,白字黑底)放到src/main/assets目錄下


備註:
IMAGE_MEAN和IMAGE_STD的值在本項目沒有實際意義,可以隨便設置。

4.在activity中調用TensorFlowImageClassifier.create()方法創建分類器:


5. 將mnist_test.png圖片轉換成相應的bitmap(28x28),通過classifier.recognizeImage(bitmap)來取得預測結果


注意:
因爲我們的輸入數據是28x28的灰度圖,原始代碼用到了rgb三個通道,我們只需要一個通道,所以需要修改TensorFlowImageClassifier類的recognizeImage方法來適應模型,代碼如下:



 bitmapToFloatArray()方法如下:
 /**
  * 將bitmap轉爲(按行優先)一個float數組。其中的每個像素點都歸一化到0~1之間。
  * @param bitmap 灰度圖,r,g,b分量都相等。
  * @return
  */
 public static float[] bitmapToFloatArray(Bitmap bitmap){
   int height = bitmap.getHeight();
   int width = bitmap.getWidth();
   float[] result = new float[height * width];

   int k = 0;
   for (int j = 0; j < height; j++) {
     for (int i = 0; i < width; i++) {
       int argb = bitmap.getPixel(i, j);
       // 由於是灰度圖,所以r,g,b分量是相等的。
       int r = Color.red(argb);
       result[k++] = r / 255.0f;
     }
   }
   return result;
 }

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章