pom
<dependencies>
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-core_2.12</artifactId>
        <version>2.4.0</version>
    </dependency>
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-streaming_2.12</artifactId>
        <version>2.4.0</version>
    </dependency>
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-mllib_2.12</artifactId>
        <version>2.4.0</version>
    </dependency>
    <dependency>
        <groupId>com.thoughtworks.paranamer</groupId>
        <artifactId>paranamer</artifactId>
        <version>2.8</version>
    </dependency>
</dependencies>
<!-- Package an executable jar -->
<build>
    <plugins>
        <plugin>
            <groupId>org.apache.maven.plugins</groupId>
            <artifactId>maven-compiler-plugin</artifactId>
            <version>3.3</version>
            <configuration>
                <source>1.8</source>
                <target>1.8</target>
                <encoding>UTF-8</encoding>
            </configuration>
        </plugin>
    </plugins>
    <resources>
        <resource>
            <directory>src/main/resources</directory>
            <includes>
                <include>**/*.*</include>
            </includes>
        </resource>
    </resources>
</build>
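Note that the build section above only configures the compiler plugin and resources; by itself it does not produce a runnable jar. If an executable jar is actually wanted, a maven-shade-plugin entry along these lines could be added inside <plugins>. This is only a minimal sketch: the plugin version and the mainClass value are assumptions, not part of the original project.

<plugin>
    <groupId>org.apache.maven.plugins</groupId>
    <artifactId>maven-shade-plugin</artifactId>
    <version>3.2.1</version>
    <executions>
        <execution>
            <phase>package</phase>
            <goals>
                <goal>shade</goal>
            </goals>
            <configuration>
                <transformers>
                    <!-- Sets the Main-Class entry in the jar manifest; MyRandomForest is assumed from the code below -->
                    <transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
                        <mainClass>MyRandomForest</mainClass>
                    </transformer>
                </transformers>
            </configuration>
        </execution>
    </executions>
</plugin>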
Code
import java.util.HashMap;
import java.util.Map;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.RandomForest;
import org.apache.spark.mllib.tree.model.RandomForestModel;

import scala.Tuple2;

public class MyRandomForest {
    public static void main(String[] args) {
        SparkConf sparkConf = new SparkConf().setAppName("app").setMaster("local[1]");
        JavaSparkContext jsc = new JavaSparkContext(sparkConf);

        // Load the data: each line is "label,feature1 feature2 ..."
        String path = "D:\\IdeaProjects\\SparkMLlib\\src\\test\\java\\data4";
        JavaRDD<String> javaRDD = jsc.textFile(path);
        JavaRDD<LabeledPoint> data = javaRDD.map(new Function<String, LabeledPoint>() {
            @Override
            public LabeledPoint call(String line) throws Exception {
                String[] split = line.split(",");
                String[] arr = split[1].split(" ");
                double[] vectors = new double[arr.length];
                for (int i = 0; i < arr.length; i++) {
                    vectors[i] = Double.parseDouble(arr[i]);
                }
                return new LabeledPoint(Double.parseDouble(split[0]), Vectors.dense(vectors));
            }
        });

        // Split the data set into training and test data (70% / 30%)
        JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3});
        JavaRDD<LabeledPoint> training = splits[0];
        JavaRDD<LabeledPoint> testData = splits[1];

        // Random forest training parameters
        Integer numClasses = 2;                 // number of classes
        Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<>(); // empty: all features are continuous
        Integer numTrees = 1;                   // number of trees in the forest
        String featureSubsetStrategy = "auto";  // let the algorithm choose automatically (auto/all)
        String impurity = "gini";               // random forests support entropy, gini, variance; regression always uses variance
        Integer maxDepth = 10;                  // maximum tree depth
        Integer maxBins = 32;                   // maximum number of bins when splitting features
        Integer seed = 1000000;                 // sampling seed; the same seed gives the same samples

        // Train the model
        RandomForestModel model = RandomForest.trainClassifier(
                training, numClasses, categoricalFeaturesInfo, numTrees,
                featureSubsetStrategy, impurity, maxDepth, maxBins, seed);

        // Predict on the test data
        JavaPairRDD<Double, Double> predictionAndLabel = testData
                .mapToPair(p -> new Tuple2<>(model.predict(p.features()), p.label()));

        // Compute the test error rate
        double testErr = predictionAndLabel.filter(pl -> !pl._1().equals(pl._2())).count()
                / (double) testData.count();
        System.out.println("Test err:" + testErr);

        // Print the tree structure
        System.out.println(model.toDebugString());

        // Predict a new sample
        Vector v = Vectors.dense(new double[]{3, 8});
        System.out.println("Prediction: " + model.predict(v));

        jsc.stop();
    }
}
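The parsing logic above implies a particular layout for the data4 file, which is not shown in the original: one sample per line, with the numeric label first, then a comma, then space-separated feature values. A purely hypothetical two-feature example, consistent with numClasses = 2 and with the two-element vector used in the final prediction, might look like:

1,2.0 7.5
1,3.1 8.4
0,6.3 1.2
0,7.0 0.5

With data of this shape, Vectors.dense(new double[]{3, 8}) in the last step is simply a new two-feature sample being classified by the trained model.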