手把手教你開發人工智能微信小程序(4): 訓練手寫數字識別模型

在上篇文章《手把手教你開發人工智能微信小程序(3):加載數據》中,我給大家演示瞭如何通過fetch加載網絡數據並進行數據歸範化,出於演示的目的,例子做了簡化處理,本文中將給大家介紹一個稍微複雜一點的例子:手寫數字識別。很多機器學習的教程都以手寫數字識別作爲上手的示例,我在之前的文章也寫過幾篇:

可供參考。在本文中,我將演示如何訓練卷積神經網絡模型來識別手寫數字。

需要說明的是,不建議在微信小程序中訓練模型,而且通常的流程是模型訓練與模型使用分離,本文的示例在實用性上可能欠缺,僅僅是爲了給大家展示一種可能性,同時讓大家對整個機器學習的過程有所瞭解。閱讀完本文後,你將瞭解到:

  • 如何通過網絡加載圖片類型數據

  • 如何使用tfjs Layers API定義模型結構

  • 如何訓練模型以及評估模型

加載MNIST數據

針對手寫數字識別問題,網絡上已經有公開數據集MNIST。這是一套28x28大小手寫數字的灰度圖像,包含55000個訓練樣本,10000個測試樣本,另外還有5000個交叉驗證數據樣本。該數據集有多種格式,如果使用keras、tensorflow之類的python機器學習框架,通常有內置的API加載和處理MNIST數據集,但tensorflow.js並沒有提供,所以需要自己編寫。

常見的MNIST數據集是以多張通過目錄進行歸類的圖片集,比如手寫數字0的圖片都放到目錄名爲0的目錄下,手寫數字1的圖片都放到目錄名爲1的目錄下,依次類推,如下圖所示:

按目錄歸類的數據集

也有的數據集是將所有圖片放到一個目錄下,然後加上一個文本文件,描述每個文件對應的標籤:

csv文件

這種形式的數據集並不適合tfjs,因爲出於安全的考慮,js無法訪問本地文件,大量小的文件的網絡訪問效率很低。所以有人將65000個圖片合併爲一張圖片,但不是簡單的將65000個圖片拼接起來,而是將每個圖片的二進制像素線性展開,一張手寫數字圖片供784個像素,佔圖片中的一行,最後得到的圖像尺寸爲784 * 65000,最後形成的圖像對我們來說像是一張無意義的圖片:

拼接的MNIST圖片

加載MNIST圖像數據的代碼如下:

  async load(canvasId, imgWidth, imgHeight) {
    const ctx = wx.createCanvasContext(canvasId);
    
    const datasetBytesBuffer =
      new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4);


    const chunkSize = 5000;


    let drawJobs = [];
    for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) {
      const datasetBytesView = new Float32Array(
        datasetBytesBuffer, i * IMAGE_SIZE * chunkSize * 4,
        IMAGE_SIZE * chunkSize);
      ctx.drawImage(
        MNIST_IMAGES_SPRITE_PATH, 0, i * chunkSize, imgWidth, chunkSize, 0, 0, imgWidth,
        chunkSize);


      drawJobs.push(new Promise((resolve, reject) => {
        ctx.draw(false, () => {
          // API 1.9.0 獲取圖像數據
          wx.canvasGetImageData({
            canvasId: canvasId,
            x: 0,
            y: 0,
            width: imgWidth,
            height: chunkSize,
            success(imageData) {
              for (let j = 0; j < imageData.data.length / 4; j++) {
                // All channels hold an equal value since the image is grayscale, so
                // just read the red channel.
                datasetBytesView[j] = imageData.data[j * 4] / 255;
              }
              resolve();
            },
            fail: e => {
              console.error(e);
              resolve();
            },
          });
        });
      }));
    }
    await Promise.all(drawJobs);


    this.datasetImages = new Float32Array(datasetBytesBuffer);


    const fetch = fetchWechat.fetchFunc();
    const labelsResponse = await fetch(MNIST_LABELS_PATH);


    this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer());


    // Create shuffled indices into the train/test set for when we select a
    // random dataset element for training / validation.
    this.trainIndices = tf.util.createShuffledIndices(NUM_TRAIN_ELEMENTS);
    this.testIndices = tf.util.createShuffledIndices(NUM_TEST_ELEMENTS);


    // Slice the the images and labels into train and test sets.
    this.trainImages =
      this.datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
    this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
    this.trainLabels =
      this.datasetLabels.slice(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS);
    this.testLabels =
      this.datasetLabels.slice(NUM_CLASSES * NUM_TRAIN_ELEMENTS);
  }

這段代碼有幾點需要注意:

  1. 因爲送入模型訓練的是像素RGB數據,所以需要先對圖片進行解碼,提取每個手寫數字對應的784個像素值,在微信小程序中是藉助Canvas繪製圖像這種方式獲得,也許有其它更好的直接解碼的方法。

  2. 因爲canvasGetImageData是一個異步方法,所以代碼中使用了Promise異步模式,等待所有圖像數據獲取完畢。而圖像分部分繪製,也是避免大圖片繪製導致內存問題。

  3. 整個數據集拆分爲訓練數據集和測試數據集,訓練數據集包含55000個數據,測試數據集10000個數據。nextTrainBatch(batchSize)方法從訓練集中返回一組隨機圖像及其標籤。nextTestBatch(batchSize)方法從測試集中返回一批圖像及其標籤。

定義模型結構

關於卷積神經網絡,可以參閱《一步步提高手寫數字的識別率(3)》這篇文章,這裏定義的卷積網絡結構爲:

CONV -> MAXPOOlING -> CONV -> MAXPOOLING -> FC -> SOFTMAX

每個卷積層使用RELU激活函數,代碼如下:

function getModel() {
  const model = tf.sequential();


  const IMAGE_WIDTH = 28;
  const IMAGE_HEIGHT = 28;
  const IMAGE_CHANNELS = 1;


  // In the first layer of out convolutional neural network we have
  // to specify the input shape. Then we specify some paramaters for
  // the convolution operation that takes place in this layer.
  model.add(tf.layers.conv2d({
    inputShape: [IMAGE_WIDTH, IMAGE_HEIGHT, IMAGE_CHANNELS],
    kernelSize: 5,
    filters: 8,
    strides: 1,
    activation: 'relu',
    kernelInitializer: 'varianceScaling'
  }));


  // The MaxPooling layer acts as a sort of downsampling using max values
  // in a region instead of averaging.
  model.add(tf.layers.maxPooling2d({poolSize: [2, 2], strides: [2, 2]}));


  // Repeat another conv2d + maxPooling stack.
  // Note that we have more filters in the convolution.
  model.add(tf.layers.conv2d({
    kernelSize: 5,
    filters: 16,
    strides: 1,
    activation: 'relu',
    kernelInitializer: 'varianceScaling'
  }));
  model.add(tf.layers.maxPooling2d({poolSize: [2, 2], strides: [2, 2]}));


  // Now we flatten the output from the 2D filters into a 1D vector to prepare
  // it for input into our last layer. This is common practice when feeding
  // higher dimensional data to a final classification output layer.
  model.add(tf.layers.flatten());


  // Our last layer is a dense layer which has 10 output units, one for each
  // output class (i.e. 0, 1, 2, 3, 4, 5, 6, 7, 8, 9).
  const NUM_OUTPUT_CLASSES = 10;
  model.add(tf.layers.dense({
    units: NUM_OUTPUT_CLASSES,
    kernelInitializer: 'varianceScaling',
    activation: 'softmax'
  }));




  // Choose an optimizer, loss function and accuracy metric,
  // then compile and return the model
  const optimizer = tf.train.adam();
  model.compile({
    optimizer: optimizer,
    loss: 'categoricalCrossentropy',
    metrics: ['accuracy'],
  });


  return model;
}

如果有過tensorflow python代碼編寫經驗,上面的代碼應該很容易理解。

訓練模型

在瀏覽器中訓練,也可以批量輸入圖像數據,可以指定batch size,epoch輪次。

  const metrics = ['loss', 'val_loss', 'acc', 'val_acc'];
  const container = {
    name: 'Model Training', styles: { height: '1000px' }
  };
  // const fitCallbacks = tfvis.show.fitCallbacks(container, metrics);


  const BATCH_SIZE = 512;
  const TRAIN_DATA_SIZE = 5500;
  const TEST_DATA_SIZE = 1000;


  const [trainXs, trainYs] = tf.tidy(() => {
    const d = data.nextTrainBatch(TRAIN_DATA_SIZE);
    return [
      d.xs.reshape([TRAIN_DATA_SIZE, 28, 28, 1]),
      d.labels
    ];
  });


  const [testXs, testYs] = tf.tidy(() => {
    const d = data.nextTestBatch(TEST_DATA_SIZE);
    return [
      d.xs.reshape([TEST_DATA_SIZE, 28, 28, 1]),
      d.labels
    ];
  });


  return model.fit(trainXs, trainYs, {
    batchSize: BATCH_SIZE,
    validationData: [testXs, testYs],
    epochs: 10,
    shuffle: true,
  });

tfvis庫在微信小程序中不能正常工作,所以無法像在瀏覽器中訓練那樣,可視化監控訓練過程。這個訓練過程比較長,我在微信開發者工具中通過模擬器大概需要半個小時,請耐心等待。

評估訓練的模型

評估時喂入測試集:

function doPrediction(model, data, testDataSize = 500) {
  const IMAGE_WIDTH = 28;
  const IMAGE_HEIGHT = 28;
  const testData = data.nextTestBatch(testDataSize);
  const testxs = testData.xs.reshape([testDataSize, IMAGE_WIDTH, IMAGE_HEIGHT, 1]);
  const labels = testData.labels.argMax([-1]);
  const preds = model.predict(testxs).argMax([-1]);


  testxs.dispose();
  return [preds, labels];
}

計算在測試集上的準確率,也就是統計預測值和真實值匹配的個數:

    const predsArray = preds.dataSync();
    const labelsArray = labels.dataSync();
    var n = 0;
    for (var i = 0; i < predsArray.length; i++) {
      console.log(predsArray[i]);
      console.log(labelsArray[i]);
      if (predsArray[i] == labelsArray[i])
        n++;
    }
    const accuracy = n / predsArray.length;
    console.log(accuracy);

小結

本文探討了如何從網絡加載MNIST數據集,定義卷積神經網絡模型,訓練模型及評估模型。這個簡單的例子,包含了機器學習的整個過程,雖然在實際中我們可能不會這樣用。在下篇文章中,我將介紹如何使用現有模型。如果你有什麼建議,歡迎留言。

本系列文章的源碼請訪問:

https://github.com/mogotech/wechat-tfjs-examples

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章