pom
<!-- Third-party libraries required by the Spark ML example below. -->
<dependencies>
<!-- Spark core, Scala 2.12 build, version 2.4.0. -->
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_2.12</artifactId>
<version>2.4.0</version>
</dependency>
<!-- Spark Streaming, same Scala build and version as spark-core. -->
<!-- NOTE(review): the Java code visible in this file imports nothing from spark-streaming; confirm it is needed. -->
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-streaming_2.12</artifactId>
<version>2.4.0</version>
</dependency>
<!-- Spark machine-learning library, same Scala build and version as spark-core. -->
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-mllib_2.12</artifactId>
<version>2.4.0</version>
</dependency>
<!-- Explicitly declared paranamer 2.8. -->
<!-- NOTE(review): presumably pinned to override a version pulled in transitively; confirm before removing. -->
<dependency>
<groupId>com.thoughtworks.paranamer</groupId>
<artifactId>paranamer</artifactId>
<version>2.8</version>
</dependency>
</dependencies>
<!-- Build configuration. The original note here read "package an executable jar". -->
<!-- NOTE(review): no jar-packaging plugin (jar/shade/assembly) is configured in this section; confirm where the executable jar is produced. -->
<build>
<plugins>
<!-- Compile with Java 1.8 source/target levels and UTF-8 source encoding. -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.3</version>
<configuration>
<source>1.8</source>
<target>1.8</target>
<encoding>UTF-8</encoding>
</configuration>
</plugin>
</plugins>
<!-- Include every file under src/main/resources whose name contains a dot. -->
<resources>
<resource>
<directory>src/main/resources</directory>
<includes>
<include>**/*.*</include>
</includes>
</resource>
</resources>
</build>
代碼
import org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel;
import org.apache.spark.ml.classification.MultilayerPerceptronClassifier;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import java.io.*;
import java.util.Arrays;
public class Internet {
/**
 * Trains a multilayer-perceptron classifier on a libsvm-formatted data set,
 * reports its accuracy on a held-out split, then round-trips the trained model
 * through Java serialization and uses the restored copy for a single prediction.
 *
 * @param args optional; {@code args[0]} overrides the location of the libsvm data.
 *             When absent, the original hard-coded development path is used.
 * @throws Exception if loading, training, or the serialization round trip fails
 */
public static void main(String[] args) throws Exception {
    // Allow the data location to be supplied on the command line; fall back to
    // the original hard-coded path so existing invocations behave the same.
    String path = args.length > 0
            ? args[0]
            : "D:\\IdeaProjects\\SparkMLlib\\src\\test\\java\\data4";

    SparkSession spark = SparkSession
            .builder()
            .appName("internet")
            .master("local[1]")
            .getOrCreate();

    MultilayerPerceptronClassificationModel model;
    try {
        /*
         * The data is in libsvm format, for example:
         *   0 1:1 2:1
         *   1 1:1 2:3
         * i.e.  label  featureIndex:value  featureIndex:value
         *       go / don't go   1:weather   2:temperature
         *       weather: 1 = sunny, 2 = cloudy, 3 = rainy
         */
        // Silence Spark's informational logging.
        spark.sparkContext().setLogLevel("ERROR");

        // Load the data. randomSplit uses a fixed seed (100) so that results are
        // reproducible while debugging; do not hard-code the seed in real workloads.
        Dataset<Row>[] split = spark.read().format("libsvm").load(path)
                .randomSplit(new double[]{0.9, 0.1}, 100);
        Dataset<Row> train = split[0]; // training set
        Dataset<Row> test = split[1];  // test set
        train.show(10, false);         // inspect a few rows

        // Input layer (number of features) / hidden layers / output layer (number of labels).
        int[] layer = new int[]{2, 8, 8, 8, 8, 2};
        // NOTE(review): per the Spark docs, stepSize is only used by the "gd" solver;
        // with the default solver it likely has no effect - confirm.
        MultilayerPerceptronClassifier trainer = new MultilayerPerceptronClassifier()
                .setLayers(layer)   // network topology
                .setStepSize(10)    // learning rate
                .setMaxIter(300);   // maximum training iterations

        model = trainer.fit(train);                        // train the model
        Dataset<Row> result = model.transform(test);       // predict on the test set
        Dataset<Row> predictionAndLabels = result.select("prediction", "label");

        // Fraction of test rows whose prediction equals the label.
        MulticlassClassificationEvaluator evaluator =
                new MulticlassClassificationEvaluator().setMetricName("accuracy");
        System.out.println("測試正確率 = " + evaluator.evaluate(predictionAndLabels));
        result.show(5, false); // show a few predictions
    } finally {
        // Always release the session, even when loading/training/evaluation throws.
        spark.close();
    }

    // Round-trip the trained model through Java serialization.
    byte[] bytes = toBytes(model);
    System.out.println("算法長度爲:" + bytes.length);
    model = (MultilayerPerceptronClassificationModel) toClass(bytes); // cast the deserialized object

    // Predict a new sample with the restored model.
    // Vector/DenseVector here come from org.apache.spark.ml.linalg, not org.apache.spark.mllib.
    Vector denseVector = new DenseVector(new double[]{1, 88});
    double predict = model.predict(denseVector);
    System.out.println("反序列化預測:" + predict);
}
/**
 * Serializes an object into a byte array using Java's built-in object serialization.
 *
 * <p>The serialized bytes are also printed to standard output via
 * {@link Arrays#toString(byte[])} as a debugging aid.
 *
 * <p>Failures are propagated to the caller instead of being logged and turned
 * into a {@code null} return value, so the result is never {@code null}.
 *
 * @param out the object to serialize; it (and everything it references) must be serializable
 * @return the serialized form of {@code out}; never {@code null}
 * @throws Exception if serialization fails, e.g. a {@link NotSerializableException}
 *                   when {@code out} cannot be serialized
 * @author zhengzx
 */
public static byte[] toBytes(Object out) throws Exception {
    // In-memory sink that receives the serialized form.
    ByteArrayOutputStream buffer = new ByteArrayOutputStream();
    // try-with-resources closes the object stream on every path and cannot hit a
    // null reference if the stream constructor itself fails.
    try (ObjectOutputStream objectStream = new ObjectOutputStream(buffer)) {
        objectStream.writeObject(out);
        // Make sure everything buffered by the object stream has reached the sink
        // before the bytes are read back.
        objectStream.flush();
    }
    byte[] bs = buffer.toByteArray();
    // Debug output: dump the serialized bytes.
    System.out.println(Arrays.toString(bs));
    // TODO: persist the array
    return bs;
}
/**
 * Restores an object from bytes produced by Java's built-in object serialization.
 *
 * <p>Failures are propagated to the caller instead of being logged and turned
 * into a {@code null} return value, so a {@code null} result means the stream
 * really contained a serialized {@code null}.
 *
 * <p>SECURITY NOTE(review): Java-native deserialization must not be applied to
 * untrusted input. Only pass bytes that come from a trusted source, or install
 * an input filter before exposing this to external data.
 *
 * @param bs the serialized bytes
 * @return the deserialized object
 * @throws Exception if the bytes are not a valid serialization stream
 *                   (an {@link IOException} such as {@code StreamCorruptedException})
 *                   or a required class cannot be found
 */
public static Object toClass(byte[] bs) throws Exception {
    // try-with-resources closes the stream on every path and cannot hit a null
    // reference when the ObjectInputStream constructor rejects the stream header.
    try (ObjectInputStream objectStream =
                 new ObjectInputStream(new ByteArrayInputStream(bs))) {
        return objectStream.readObject();
    }
}