JavaScript之机器学习4:Tensorflow.js 多层神经网络

推荐学习网站:http://playground.tensorflow.org/,这个网站就是用Tensorflow.js写出来的;

多层神经网络:XOR逻辑回归

同为0,异为1
在这里插入图片描述
操作步骤:

  1. 加载XOR数据集
  2. 定义模型结构:多层神经网络
    • 初始化一个神经网络模型
    • 为神经网络模型添加两个层
    • 设计层的神经元个数,inputShape,激活函数
  3. 训练模型并预测
    在这里插入图片描述
    在这里插入图片描述
    演示代码:
<!-- index.html  -->
<form action="" onsubmit="predict(this);return false;">
    x: <input type="text" name="x">
    y: <input type="text" name="y">
    <button type="submit">预测</button>
</form>
// index.js
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { getData } from './data.js';

window.onload = async() => {
    const data = getData(400); // 获取400个点

    tfvis.render.scatterplot(
        {name:'XOR 训练数据'},
        {
            values: [
                data.filter(p => p.label === 1),
                data.filter(p => p.label === 0)
            ]
        }
    );

    const model = tf.sequential();
    // 设置隐藏层
    model.add(tf.layers.dense({
        units: 4, // 神经元个数为4
        inputShape: [2], // 长度为2的一维数组,数据特征为2:x,y
        activation: 'relu' // 激活函数 非线性
    }));
    // 设置输出层
    model.add(tf.layers.dense({
        units:1, // 只需要输出一个概率
        activation: 'sigmoid' // 输出0-1之间的概率
    }));
    
    // 设置损失函数和优化器
    model.compile({
        loss:tf.losses.logLoss,
        optimizer: tf.train.adam(0.1)
    });

    const inputs = tf.tensor(data.map(p => [p.x,p.y]));
    const labels = tf.tensor(data.map(p => p.label));

    await model.fit(inputs,labels,{
        epochs:10,
        callbacks:tfvis.show.fitCallbacks(
            {name:'训练过程'},
            ['loss']
        )
    });

    window.predict = async (form) => {
        const pred = await model.predict(tf.tensor([[form.x.value*1,form.y.value*1]]))
        alert(`预测结果:${pred.dataSync()[0]}`);
    }

}
// data.js
export function getData(numSamples) {
    let points = [];
  
    function genGauss(cx, cy, label) {
      for (let i = 0; i < numSamples / 2; i++) {
        let x = normalRandom(cx);
        let y = normalRandom(cy);
        points.push({ x, y, label });
      }
    }
  
    genGauss(2, 2, 0);
    genGauss(-2, -2, 0);
    genGauss(-2, 2, 1);
    genGauss(2, -2, 1);
    return points;
  }
  
  /**
   * Samples from a normal distribution. Uses the seedrandom library as the
   * random generator.
   *
   * @param mean The mean. Default is 0.
   * @param variance The variance. Default is 1.
   */
  function normalRandom(mean = 0, variance = 1) {
    let v1, v2, s;
    do {
      v1 = 2 * Math.random() - 1;
      v2 = 2 * Math.random() - 1;
      s = v1 * v1 + v2 * v2;
    } while (s > 1);
  
    let result = Math.sqrt(-2 * Math.log(s) / s) * v1;
    return mean + Math.sqrt(variance) * result;
  }
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章