神經網絡的MLPC(多層感知器分類器)

pom.xml

<dependencies>
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-core_2.12</artifactId>
        <version>2.4.0</version>
    </dependency>

    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-streaming_2.12</artifactId>
        <version>2.4.0</version>
    </dependency>

    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-mllib_2.12</artifactId>
        <version>2.4.0</version>
    </dependency>

    <dependency>
        <groupId>com.thoughtworks.paranamer</groupId>
        <artifactId>paranamer</artifactId>
        <version>2.8</version>
    </dependency>
</dependencies>
<!--編譯與資源配置(注意:此處只配置了編譯插件,未配置 shade/assembly 插件,並不會打出含依賴的可執行 jar 包)-->
<build>
    <plugins>
        <plugin>
            <groupId>org.apache.maven.plugins</groupId>
            <artifactId>maven-compiler-plugin</artifactId>
            <version>3.3</version>
            <configuration>
                <source>1.8</source>
                <target>1.8</target>
                <encoding>UTF-8</encoding>
            </configuration>
        </plugin>
    </plugins>
    <resources>
        <resource>
            <directory>src/main/resources</directory>
            <includes>
                <include>**/*.*</include>
            </includes>
        </resource>
    </resources>
</build>

 

 

代碼

 

import org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel;
import org.apache.spark.ml.classification.MultilayerPerceptronClassifier;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

import java.io.*;
import java.util.Arrays;

public class Internet {

    /**
     * Trains a multilayer perceptron classifier (MLPC) on a LIBSVM-format dataset and prints the
     * accuracy on a held-out split. It then round-trips the trained model through Java
     * serialization and uses the deserialized copy to predict one new sample.
     *
     * <p>Expected data format (LIBSVM): {@code label index:value index:value}, for example
     * <pre>
     * 0 1:1 2:1
     * 1 1:1 2:3
     * </pre>
     * In this example, the label is "go" / "don't go", feature 1 is the weather
     * (1 = sunny, 2 = cloudy, 3 = rainy), and feature 2 is the temperature.
     *
     * @param args optional; {@code args[0]} overrides the default input data path
     * @throws Exception if the model cannot be serialized or deserialized
     */
    public static void main(String[] args) throws Exception {
        // The previously hard-coded location stays the default, so existing runs behave the same.
        String path = (args != null && args.length > 0)
                ? args[0]
                : "D:\\IdeaProjects\\SparkMLlib\\src\\test\\java\\data4";

        MultilayerPerceptronClassificationModel model;
        // try-with-resources stops the session even when loading, training or evaluation throws.
        try (SparkSession spark = SparkSession
                .builder()
                .appName("internet")
                .master("local[1]")
                .getOrCreate()) {
            // Only print errors so the program's own output stays readable.
            spark.sparkContext().setLogLevel("ERROR");
            // Load the data and split it 90% / 10%. The fixed seed (100) makes the split
            // reproducible for debugging. Drop or vary it for real experiments.
            Dataset<Row>[] split = spark.read().format("libsvm").load(path)
                    .randomSplit(new double[]{0.9, 0.1}, 100);
            Dataset<Row> train = split[0];  // training set
            Dataset<Row> test = split[1];   // test set
            train.show(10, false);          // inspect a sample of the training data

            // Layer sizes: input = number of features (2), four hidden layers of 8,
            // output = number of label classes (2).
            int[] layer = new int[]{2, 8, 8, 8, 8, 2};
            MultilayerPerceptronClassifier trainer = new MultilayerPerceptronClassifier()
                    .setLayers(layer)
                    // NOTE: per the Spark docs, stepSize only applies to the "gd" solver. The
                    // default solver is "l-bfgs", so this value is unused unless setSolver("gd").
                    .setStepSize(10)
                    .setMaxIter(300);   // maximum number of iterations
            model = trainer.fit(train);

            // Predict the held-out set and measure how often "prediction" equals "label".
            Dataset<Row> result = model.transform(test);
            Dataset<Row> predictionAndLabels = result.select("prediction", "label");
            MulticlassClassificationEvaluator evaluator =
                    new MulticlassClassificationEvaluator().setMetricName("accuracy");
            System.out.println("測試正確率 = " + evaluator.evaluate(predictionAndLabels));
            result.show(5, false);  // show a few of the predictions
        }

        // Round-trip the trained model through Java serialization. As before, this happens
        // after the Spark session has been stopped.
        byte[] bytes = toBytes(model);
        if (bytes == null) {
            throw new IllegalStateException("model serialization failed");
        }
        System.out.println("算法長度爲:" + bytes.length);
        Object restored = toClass(bytes);
        if (!(restored instanceof MultilayerPerceptronClassificationModel)) {
            throw new IllegalStateException(
                    "deserialized object is not a MultilayerPerceptronClassificationModel");
        }
        model = (MultilayerPerceptronClassificationModel) restored;

        // Predict a new sample (feature 1 = 1, feature 2 = 88) with the spark.ml (not spark.mllib) API.
        Vector denseVector = new DenseVector(new double[]{1, 88});
        double predict = model.predict(denseVector);
        System.out.println("反序列化預測:" + predict);
    }

    /**
     * Serializes an object with standard Java serialization.
     *
     * <p>The resulting bytes are also printed to stdout for inspection.
     *
     * @param out the object to serialize; it and everything it references must be
     *            {@link java.io.Serializable} ({@code null} is allowed)
     * @return the serialized bytes; never {@code null}
     * @throws java.io.NotSerializableException if {@code out} or an object it references
     *         is not serializable
     * @throws IOException if writing the serialization stream fails
     * @throws Exception declared for compatibility with existing callers
     * @author zhengzx
     */
    public static byte[] toBytes(Object out) throws Exception {
        // Holds the serialized bytes in memory.
        ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
        // try-with-resources closes the stream on every path, and only if it was created.
        // I/O errors propagate to the caller instead of being swallowed into a null return.
        try (ObjectOutputStream objectOutputStream = new ObjectOutputStream(byteArrayOutputStream)) {
            objectOutputStream.writeObject(out);
            // Make sure any buffered data reaches the byte array before it is read.
            objectOutputStream.flush();
        }
        byte[] bs = byteArrayOutputStream.toByteArray();
        // Debug output: dumps every byte, which can be large for a trained model.
        System.out.println(Arrays.toString(bs));
        // TODO: persist the array if the serialized object must outlive the process.
        return bs;
    }

    /**
     * Deserializes an object from bytes produced by standard Java serialization
     * (for example, the output of {@link #toBytes(Object)}).
     *
     * <p>SECURITY: Java native deserialization of crafted input can run arbitrary code.
     * Only pass bytes from a trusted source.
     *
     * @param bs the serialized bytes; must not be {@code null}
     * @return the deserialized object ({@code null} if {@code null} was serialized)
     * @throws NullPointerException if {@code bs} is {@code null}
     * @throws IOException if {@code bs} is not a valid or complete serialization stream
     * @throws ClassNotFoundException if a serialized class is not on the classpath
     * @throws Exception declared for compatibility with existing callers
     */
    public static Object toClass(byte[] bs) throws Exception {
        if (bs == null) {
            throw new NullPointerException("bs");
        }
        // try-with-resources closes the stream only if it was actually opened. A corrupt header
        // therefore surfaces as its real IOException rather than an NPE from close(), and errors
        // are no longer swallowed into a null return.
        try (ObjectInputStream objectInputStream = new ObjectInputStream(new ByteArrayInputStream(bs))) {
            return objectInputStream.readObject();
        }
    }

 

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章