瀏覽器端的機器學習 tensorflowjs(6) 訓練模型

現在模型已經定義好了,數據也下載並進行了處理,一切準備就緒準備開始訓練。

async function trainModel(model, inputs, labels) {
  // 準備要訓練的模型
  model.compile({
    optimizer: tf.train.adam(),
    loss: tf.losses.meanSquaredError,
    metrics: ['mse'],
  });

  const batchSize = 32;
  const epochs = 50;

  return await model.fit(inputs, labels, {
    batchSize,
    epochs,
    shuffle: true,
    callbacks: tfvis.show.fitCallbacks(
      { name: 'Training Performance' },
      ['loss', 'mse'],
      { height: 200, callbacks: ['onEpochEnd'] }
    )
  });
}

訓練前的一些準備

model.compile({
  optimizer: tf.train.adam(),
  loss: tf.losses.meanSquaredError,
  metrics: ['mse'],
});

在訓練模型之前,需要 "編譯 "該模型,那麼具體應該如何做呢? 我們需要一個優化和一個損失函數,損失函數也可以理解目標函數,主要是指定訓練,讓我們訓練一個目標,優化器這是給出一個策略如何在訓練過程更新參數。

  • 優化器。這是一種算法,是更新參數的算法。在 TensorFlow.js 中有許多優化器可用。這裏選擇了 adam 優化器,也可以嘗試用其他優化器
  • 損失函數:其實就是一個函數,告訴模型在學習過程中,在每個批次(數據子集)時的表現如何。這裏選擇 meanSquaredError 來比較模型的預測和真實值
const batchSize = 32;
const epochs = 50;

設置超參數 batchSize 和一個 epochs 的數量。

  • batchSize 指的是模型在每次迭代訓練中看到的數據子集的大小。常見的批次大小往往在 32-512 之間取值。批次大小對於訓練速度是有所影響的

  • epochs 完成整個數據集進行訓練的次數

開始訓練

return await model.fit(inputs, labels, {
  batchSize,
  epochs,
  callbacks: tfvis.show.fitCallbacks(
    { name: 'Training Performance' },
    ['loss', 'mse'],
    { height: 200, callbacks: ['onEpochEnd'] }
  )
});

model.fit 是來啓動訓練的函數。這是一個異步函數,所以返回會是一個 promise。

爲了監控訓練進度,回調傳函數作爲 model.fit 來獲取訓練過程中信息。然後回調函數使用 tfvis.show.fitCallbacks 來定義,然後可以繪製損失值對於迭代的圖標

const tensorData = convertToTensor(data);
const {inputs, labels} = tensorData;

// Train the model
await trainModel(model, inputs, labels);
console.log('Done Training');

這的注意的這部分代碼要寫在 run 函數中,具體如下

async function run() {
    // 加載數據
    const data = await getData();
    // 處理原始數據,將數據 horsepower 映射爲 x 而 mpg 則映射爲 y
    const values = data.map(d => ({
      x: d.horsepower,
      y: d.mpg,
    }));
    // 將數據以散點圖形式顯示在開發者調試工具
    
  
    tfvis.render.scatterplot(
      {name: 'Horsepower v MPG'},
      {values},
      {
        xLabel: 'Horsepower',
        yLabel: 'MPG',
        height: 300
      }
    );

    const model = createModel();
    const tensorData = convertToTensor(data);
    const {inputs, labels} = tensorData;

    // Train the model
    await trainModel(model, inputs, labels);
    console.log('Done Training');
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章