基於tensorflow實現Android手寫數字識別

前段時間訓練了mnist手寫數字識別的模型,學習後將其移植到Android端
我是參考的大佬https://puke3615.github.io/2017/08/02/Run-Mnist-On-Android/https://github.com/wangtianrui/TFonAndroid的源碼,有需要的的朋友可以去下載,這裏是對他寫的代碼的分析和我自己的理解
註解ButterKnife學習:https://www.jianshu.com/p/952c6f5e8157

implementation 'com.jakewharton:butterknife:8.8.1'

手機上效果爲:
在這裏插入圖片描述
在這裏插入圖片描述

移植到Android時要添加依賴文件:libandroid_tensorflow_inference_java.jar,和編譯後的TensoFlow的so庫,libtensorflow_inference.so,將其添加在lib文件夾中:
在這裏插入圖片描述

接下來將訓練好的pb模型放入assets文件夾中
在這裏插入圖片描述

在build.gradle文件中添加:這個可以支持在手機中調試

testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner"

這裏分享一下運行程序時遇到的坑:
出現了問題“Android-Device supports x86,but APK only supports armeabi-v7a,armeabi,x86_64”,使用模擬器不能運行,因爲之前添加了支持tensorflow的so庫和jar包,後來我在build文件中添加

   multiDexEnabled true
        ndk {
            abiFilters "armeabi-v7a"
        }

在這裏插入圖片描述
也還是沒有用,後來看到大佬的代碼彷彿才明白了一些東西。。。
在build中增加:

    sourceSets {
        main {
            jniLibs.srcDirs = ['libs']
        }
    }

下面是實現過程:

定義mnist分類器:

private final String MODEL_PATH = "file:///android_asset/mnist.pb";//加載模型
    public static final String INPUT_NAME = "input";//對應訓練模型佔位符x-input
    public static final String KEEP_PROB_NAME = "keep_prob";
    public static final String OUTPUT_NAME = "output";//訓練模型時的佔位符y_
//注意訓練模型時一定要對值和標籤設置name值,導入Android後要喂數據
//tensorflow依賴文件的類
    private TensorFlowInferenceInterface inference;

//圖片像素28*28
    private final int width = 28;
    private final int heifht = 28;
    private float[] inputs = new float[width * heifht];
    private int[] INPUT_SHAPE = new int[]{1, width * heifht};

//AssetManager :提供低級別的訪問應用資源的API
//不同模型框架,訓練模型輸入的佔位符不同,一定要一一對應;
//訓練的數據集一定要對齊,resize與採用模型框架圖像的大小一致,Android端調用接口,輸入參數一定要和訓練圖像一致,否則會出現分類錯誤。
    public MnistClassifier(AssetManager assetManager) {
        this.inference = new TensorFlowInferenceInterface(assetManager, MODEL_PATH);//傳入模型的路徑
        //模型使用階段, 不需要進行dropout處理, 所以keep_prob直接爲1.0
        //dropout層:keep_prob訓練時爲0.5,測試時爲1
        inference.feed(KEEP_PROB_NAME, new float[]{1.0f}, 1);
    }

    public float[] getResult(float[] inputs) {
        try {
            this.inputs = inputs;
        } catch (Exception e) {
            e.printStackTrace();
        }
        //輸出結果是十個數字的概率
        float[] output = new float[10];
//填入Input數據
        inference.feed(INPUT_NAME, inputs, 1, width * heifht);
        //運行結果, 類似Python中的sess.run([outputs])
        inference.run(new String[]{OUTPUT_NAME}, false);
        inference.fetch(OUTPUT_NAME, output);
        return output;
    }

定義畫板:

public class PrinterView extends View {

    //畫筆
    private Paint paint;

    //用來存儲“路徑”
    private Path path;

    //屏幕寬
    private int width;

    public PrinterView(Context context) {
        super(context);
    }

    public PrinterView(Context context, @Nullable AttributeSet attrs) {
        super(context, attrs);
        setBackgroundColor(Color.WHITE);
        paint = new Paint();
        paint.setColor(Color.RED);
        paint.setStrokeWidth(TypedValue.applyDimension(TypedValue.COMPLEX_UNIT_DIP, 20, getResources().getDisplayMetrics()));
        paint.setStyle(Paint.Style.STROKE);
        path = new Path();

        int screenWidth = getResources().getDisplayMetrics().widthPixels;
        width = MeasureSpec.makeMeasureSpec(screenWidth, MeasureSpec.EXACTLY);
    }
    /**
    1.精確模式(MeasureSpec.EXACTLY)

在這種模式下,尺寸的值是多少,那麼這個組件的長或寬就是多少。

2.最大模式(MeasureSpec.AT_MOST)

這個也就是父組件,能夠給出的最大的空間,當前組件的長或寬最大隻能爲這麼大,當然也可以比這個小。

3.未指定模式(MeasureSpec.UNSPECIFIED)

這個就是說,當前組件,可以隨便用空間,不受限制。
    */

//畫出手指滑動的軌跡
    @Override
    public boolean onTouchEvent(MotionEvent event) {
        float x = event.getX();
        float y = event.getY();
        switch (event.getAction()) {
            case MotionEvent.ACTION_DOWN:
                //按下
                path.moveTo(x, y);
                break;
            case MotionEvent.ACTION_MOVE:
                path.lineTo(x, y);
                break;
        }
        //刷新view
        invalidate();
        return true;
    }
    @Override
    protected void onDraw(Canvas canvas) {
        super.onDraw(canvas);
        canvas.drawPath(path, paint);
    }

    @Override
    protected void onMeasure(int widthMeasureSpec, int heightMeasureSpec) {
        //定製畫板的寬和高
        super.onMeasure(width, width);
    }
    public void clean() {
        path.reset();
        invalidate();
    }
    public boolean isEmpty() {
        return path.isEmpty();
    }
//向外部提供讀取畫布數據的方法
/**
View組件顯示的內容可以通過cache機制保存爲bitmap, 使用到的api有
    void  setDrawingCacheEnabled(boolean flag),
    Bitmap  getDrawingCache(boolean autoScale),
    void  buildDrawingCache(boolean autoScale),
    void  destroyDrawingCache()
    我們要獲取它的cache先要通過setDrawingCacheEnable方法把cache開啓,然後再調用getDrawingCache方法就可 以獲得view的cache圖片了。buildDrawingCache方法可以不用調用,因爲調用getDrawingCache方法時,若果 cache沒有建立,系統會自動調用buildDrawingCache方法生成cache。若果要更新cache, 必須要調用destoryDrawingCache方法把舊的cache銷燬,才能建立新的。
當調用setDrawingCacheEnabled方法設置爲false, 系統也會自動把原來的cache銷燬。
   ViewGroup在繪製子view時,而外提供了兩個方法
   void setChildrenDrawingCacheEnabled(boolean enabled)
   setChildrenDrawnWithCacheEnabled(boolean enabled)
   setChildrenDrawingCacheEnabled方法可以使viewgroup裏所有的子view開啓cache, setChildrenDrawnWithCacheEnabled使在繪製子view時,若該子view開啓了cache, 則使用它的cache進行繪製,從而節省繪製時間。
   獲取cache通常會佔用一定的內存,所以通常不需要的時候有必要對其進行清理,通過destroyDrawingCache或setDrawingCacheEnabled(false)實現。
*/
    public float[] getData(int width, int height) {
        float[] data = new float[height * width];
        try {
            //先讓cache可以被讀取(將View轉化爲圖片都會使用cache)
            setDrawingCacheEnabled(true);
            setDrawingCacheQuality(View.DRAWING_CACHE_QUALITY_LOW);
            Bitmap cache = getDrawingCache();
            dealData(cache, data, width, height);
        } finally {
            setDrawingCacheEnabled(false);
        }
        return data;
    }

    private void dealData(Bitmap bm, float[] data, int newWidth, int newHeight) {
        //獲得bitmap的寬和高
        int width = bm.getWidth();
        int height = bm.getHeight();

        //計算縮放比例
        float scaleWidth = ((float) newWidth) / width;
        float scaleHeight = ((float) newHeight) / height;
//取得想要縮放的matrix參數
        Matrix matrix = new Matrix();
        matrix.postScale(scaleWidth, scaleHeight);
        //獲得目標大小的圖
        Bitmap newBm = Bitmap.createBitmap(bm, 0, 0, width, height, matrix, true);

        for (int y = 0; y < newHeight; y++) {
            for (int x = 0; x < newWidth; x++) {
                //獲得每個點的像素值
                int pixel = newBm.getPixel(x, y);
                data[newWidth * y + x] = pixel == 0xffffffff ? 0 : 1;
            }
        }

    }
}

識別邏輯:


    @OnClick({R.id.printer_view, R.id.result_text_view, R.id.clean_button, R.id.detect_button})
    public void onViewClicked(View view) {
        switch (view.getId()) {
            case R.id.clean_button:
                printerView.clean();
                resultTextView.setText(null);
                break;
            case R.id.detect_button:
                if (printerView.isEmpty()) {
                    resultTextView.setText("畫板爲空");
                    break;
                }
                MnistClassifier mnistClassifier = new MnistClassifier(getAssets());
                float[] result = mnistClassifier.getResult(printerView.getData(28, 28));
                List<MnistItem> items = new ArrayList<>(10);
                for (int i = 0; i < result.length; i++) {
                    items.add(new MnistItem(result[i], i));
                }
                Collections.sort(items);//選擇概率最大的對應的值
                StringBuilder builder = new StringBuilder();
                for (int i = 0; i < 1 ; i++) {
                    MnistItem item = items.get(i);
                    builder.append((int)item.getIndex());
                }
                resultTextView.setText(builder.toString());
                break;
        }
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章