前段時間訓練了mnist手寫數字識別的模型,學習後將其移植到Android端
我是參考的大佬https://puke3615.github.io/2017/08/02/Run-Mnist-On-Android/,https://github.com/wangtianrui/TFonAndroid的源碼,有需要的的朋友可以去下載,這裏是對他寫的代碼的分析和我自己的理解
註解ButterKnife學習:https://www.jianshu.com/p/952c6f5e8157
implementation 'com.jakewharton:butterknife:8.8.1'
手機上效果爲:
移植到Android時要添加依賴文件:libandroid_tensorflow_inference_java.jar,和編譯後的TensoFlow的so庫,libtensorflow_inference.so,將其添加在lib文件夾中:
接下來將訓練好的pb模型放入assets文件夾中
在build.gradle文件中添加:這個可以支持在手機中調試
testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner"
這裏分享一下運行程序時遇到的坑:
出現了問題“Android-Device supports x86,but APK only supports armeabi-v7a,armeabi,x86_64”,使用模擬器不能運行,因爲之前添加了支持tensorflow的so庫和jar包,後來我在build文件中添加
multiDexEnabled true
ndk {
abiFilters "armeabi-v7a"
}
也還是沒有用,後來看到大佬的代碼彷彿才明白了一些東西。。。
在build中增加:
sourceSets {
main {
jniLibs.srcDirs = ['libs']
}
}
下面是實現過程:
定義mnist分類器:
private final String MODEL_PATH = "file:///android_asset/mnist.pb";//加載模型
public static final String INPUT_NAME = "input";//對應訓練模型佔位符x-input
public static final String KEEP_PROB_NAME = "keep_prob";
public static final String OUTPUT_NAME = "output";//訓練模型時的佔位符y_
//注意訓練模型時一定要對值和標籤設置name值,導入Android後要喂數據
//tensorflow依賴文件的類
private TensorFlowInferenceInterface inference;
//圖片像素28*28
private final int width = 28;
private final int heifht = 28;
private float[] inputs = new float[width * heifht];
private int[] INPUT_SHAPE = new int[]{1, width * heifht};
//AssetManager :提供低級別的訪問應用資源的API
//不同模型框架,訓練模型輸入的佔位符不同,一定要一一對應;
//訓練的數據集一定要對齊,resize與採用模型框架圖像的大小一致,Android端調用接口,輸入參數一定要和訓練圖像一致,否則會出現分類錯誤。
public MnistClassifier(AssetManager assetManager) {
this.inference = new TensorFlowInferenceInterface(assetManager, MODEL_PATH);//傳入模型的路徑
//模型使用階段, 不需要進行dropout處理, 所以keep_prob直接爲1.0
//dropout層:keep_prob訓練時爲0.5,測試時爲1
inference.feed(KEEP_PROB_NAME, new float[]{1.0f}, 1);
}
public float[] getResult(float[] inputs) {
try {
this.inputs = inputs;
} catch (Exception e) {
e.printStackTrace();
}
//輸出結果是十個數字的概率
float[] output = new float[10];
//填入Input數據
inference.feed(INPUT_NAME, inputs, 1, width * heifht);
//運行結果, 類似Python中的sess.run([outputs])
inference.run(new String[]{OUTPUT_NAME}, false);
inference.fetch(OUTPUT_NAME, output);
return output;
}
定義畫板:
public class PrinterView extends View {
//畫筆
private Paint paint;
//用來存儲“路徑”
private Path path;
//屏幕寬
private int width;
public PrinterView(Context context) {
super(context);
}
public PrinterView(Context context, @Nullable AttributeSet attrs) {
super(context, attrs);
setBackgroundColor(Color.WHITE);
paint = new Paint();
paint.setColor(Color.RED);
paint.setStrokeWidth(TypedValue.applyDimension(TypedValue.COMPLEX_UNIT_DIP, 20, getResources().getDisplayMetrics()));
paint.setStyle(Paint.Style.STROKE);
path = new Path();
int screenWidth = getResources().getDisplayMetrics().widthPixels;
width = MeasureSpec.makeMeasureSpec(screenWidth, MeasureSpec.EXACTLY);
}
/**
1.精確模式(MeasureSpec.EXACTLY)
在這種模式下,尺寸的值是多少,那麼這個組件的長或寬就是多少。
2.最大模式(MeasureSpec.AT_MOST)
這個也就是父組件,能夠給出的最大的空間,當前組件的長或寬最大隻能爲這麼大,當然也可以比這個小。
3.未指定模式(MeasureSpec.UNSPECIFIED)
這個就是說,當前組件,可以隨便用空間,不受限制。
*/
//畫出手指滑動的軌跡
@Override
public boolean onTouchEvent(MotionEvent event) {
float x = event.getX();
float y = event.getY();
switch (event.getAction()) {
case MotionEvent.ACTION_DOWN:
//按下
path.moveTo(x, y);
break;
case MotionEvent.ACTION_MOVE:
path.lineTo(x, y);
break;
}
//刷新view
invalidate();
return true;
}
@Override
protected void onDraw(Canvas canvas) {
super.onDraw(canvas);
canvas.drawPath(path, paint);
}
@Override
protected void onMeasure(int widthMeasureSpec, int heightMeasureSpec) {
//定製畫板的寬和高
super.onMeasure(width, width);
}
public void clean() {
path.reset();
invalidate();
}
public boolean isEmpty() {
return path.isEmpty();
}
//向外部提供讀取畫布數據的方法
/**
View組件顯示的內容可以通過cache機制保存爲bitmap, 使用到的api有
void setDrawingCacheEnabled(boolean flag),
Bitmap getDrawingCache(boolean autoScale),
void buildDrawingCache(boolean autoScale),
void destroyDrawingCache()
我們要獲取它的cache先要通過setDrawingCacheEnable方法把cache開啓,然後再調用getDrawingCache方法就可 以獲得view的cache圖片了。buildDrawingCache方法可以不用調用,因爲調用getDrawingCache方法時,若果 cache沒有建立,系統會自動調用buildDrawingCache方法生成cache。若果要更新cache, 必須要調用destoryDrawingCache方法把舊的cache銷燬,才能建立新的。
當調用setDrawingCacheEnabled方法設置爲false, 系統也會自動把原來的cache銷燬。
ViewGroup在繪製子view時,而外提供了兩個方法
void setChildrenDrawingCacheEnabled(boolean enabled)
setChildrenDrawnWithCacheEnabled(boolean enabled)
setChildrenDrawingCacheEnabled方法可以使viewgroup裏所有的子view開啓cache, setChildrenDrawnWithCacheEnabled使在繪製子view時,若該子view開啓了cache, 則使用它的cache進行繪製,從而節省繪製時間。
獲取cache通常會佔用一定的內存,所以通常不需要的時候有必要對其進行清理,通過destroyDrawingCache或setDrawingCacheEnabled(false)實現。
*/
public float[] getData(int width, int height) {
float[] data = new float[height * width];
try {
//先讓cache可以被讀取(將View轉化爲圖片都會使用cache)
setDrawingCacheEnabled(true);
setDrawingCacheQuality(View.DRAWING_CACHE_QUALITY_LOW);
Bitmap cache = getDrawingCache();
dealData(cache, data, width, height);
} finally {
setDrawingCacheEnabled(false);
}
return data;
}
private void dealData(Bitmap bm, float[] data, int newWidth, int newHeight) {
//獲得bitmap的寬和高
int width = bm.getWidth();
int height = bm.getHeight();
//計算縮放比例
float scaleWidth = ((float) newWidth) / width;
float scaleHeight = ((float) newHeight) / height;
//取得想要縮放的matrix參數
Matrix matrix = new Matrix();
matrix.postScale(scaleWidth, scaleHeight);
//獲得目標大小的圖
Bitmap newBm = Bitmap.createBitmap(bm, 0, 0, width, height, matrix, true);
for (int y = 0; y < newHeight; y++) {
for (int x = 0; x < newWidth; x++) {
//獲得每個點的像素值
int pixel = newBm.getPixel(x, y);
data[newWidth * y + x] = pixel == 0xffffffff ? 0 : 1;
}
}
}
}
識別邏輯:
@OnClick({R.id.printer_view, R.id.result_text_view, R.id.clean_button, R.id.detect_button})
public void onViewClicked(View view) {
switch (view.getId()) {
case R.id.clean_button:
printerView.clean();
resultTextView.setText(null);
break;
case R.id.detect_button:
if (printerView.isEmpty()) {
resultTextView.setText("畫板爲空");
break;
}
MnistClassifier mnistClassifier = new MnistClassifier(getAssets());
float[] result = mnistClassifier.getResult(printerView.getData(28, 28));
List<MnistItem> items = new ArrayList<>(10);
for (int i = 0; i < result.length; i++) {
items.add(new MnistItem(result[i], i));
}
Collections.sort(items);//選擇概率最大的對應的值
StringBuilder builder = new StringBuilder();
for (int i = 0; i < 1 ; i++) {
MnistItem item = items.get(i);
builder.append((int)item.getIndex());
}
resultTextView.setText(builder.toString());
break;
}