pom.xml
<dependencies> <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-core_2.12</artifactId> <version>2.4.0</version> </dependency> <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-streaming_2.12</artifactId> <version>2.4.0</version> </dependency> <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-mllib_2.12</artifactId> <version>2.4.0</version> </dependency> <dependency> <groupId>com.thoughtworks.paranamer</groupId> <artifactId>paranamer</artifactId> <version>2.8</version> </dependency> </dependencies> <!-- Compiler settings: Java 8 source/target, UTF-8 encoding (note: this plugin only compiles; it does not package an executable jar) --> <build> <plugins> <plugin> <groupId>org.apache.maven.plugins</groupId> <artifactId>maven-compiler-plugin</artifactId> <version>3.3</version> <configuration> <source>1.8</source> <target>1.8</target> <encoding>UTF-8</encoding> </configuration> </plugin> </plugins> <resources> <resource> <directory>src/main/resources</directory> <includes> <include>**/*.*</include> </includes> </resource> </resources> </build>
import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.DoubleFunction; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.*; import scala.Tuple2; import java.util.Arrays; public class Logistic { public static void main(String[] args) { SparkConf sparkConf = new SparkConf().setAppName("Regression").setMaster("local[1]"); JavaSparkContext sc = new JavaSparkContext(sparkConf); sc.setLogLevel("ERROR"); //加載數據 JavaRDD<String> data = sc.textFile("D:\\IdeaProjects\\SparkMLlib\\src\\test\\java\\data4"); JavaRDD<LabeledPoint> parsedData = data.map(line -> { String[] parts = line.split(","); double[] ds = Arrays.stream(parts[1].split(" ")) .mapToDouble(Double::parseDouble) .toArray(); return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(ds));//賦值f(x),Vectors.dense(Array(x)) }).cache(); int numIterations = 1000; //循環迭代修改參數的次數 //三種訓練方式 LinearRegressionModel model1 = new LinearRegressionWithSGD(1.0D, numIterations, 0.0D, 1.0D).setIntercept(true).run(parsedData.rdd());//使用截距.setIntercept(true) //LinearRegressionModel model1 = LinearRegressionWithSGD.train(parsedData.rdd(), numIterations);//不使用截距的方式 RidgeRegressionModel model2 = new RidgeRegressionWithSGD(1.0D, numIterations, 0.0D, 1.0D).run(parsedData.rdd()); LassoModel model3 = LassoWithSGD.train(parsedData.rdd(), numIterations);//使用默認的參數 //統計預測原始數據的方差 print(parsedData, model1); print(parsedData, model2); print(parsedData, model3); //預測一條新數據方法 double[] d = new double[]{0}; Vector v = Vectors.dense(d); System.out.println("預測結果爲:" + model1.predict(v)); System.out.println("預測結果爲:" + model2.predict(v)); System.out.println("預測結果爲:" + model3.predict(v)); } //用模型預測訓練數據,並計算模型的預測誤差 public static void print(JavaRDD<LabeledPoint> parsedData, GeneralizedLinearModel model) { 
JavaPairRDD<Double, Double> valuesAndPreds = parsedData.mapToPair(point -> { double prediction = model.predict(point.features()); return new Tuple2<>(point.label(), prediction); }); Double MSE = valuesAndPreds.mapToDouble(new DoubleFunction<Tuple2<Double, Double>>() { @Override public double call(Tuple2<Double, Double> doubleDoubleTuple2) throws Exception {//計算預測值與實際值差值的平方值的均值 System.out.println("實際值:" + doubleDoubleTuple2._1() + ", 預測值:" + doubleDoubleTuple2._2()); return Math.pow(doubleDoubleTuple2._1() - doubleDoubleTuple2._2(), 2); } }).mean(); System.out.println(model.getClass().getName() + " 方差 = " + MSE); } }