Spark 官方的 demo 並沒有說明如何修改決策樹訓練時使用的內存(即 Strategy 的 maxMemoryInMB 參數),所以自己琢磨了一下,用 Java 寫出了通過 Strategy 配置決策樹的例子,代碼如下:
導入的包
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.configuration.Strategy;
import org.apache.spark.mllib.tree.impurity.Gini;
import org.apache.spark.mllib.tree.impurity.Gini$;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
整體的代碼:
// Local-mode demo. To run on a cluster, use e.g. .setMaster("spark://spark-master:7077")
// and, if needed, settings such as spark.cores.max.
// NOTE: spark.executor.memory is NOT the decision-tree memory this demo is about (that is
// Strategy's maxMemoryInMB below). The previous value "30m" is also rejected at startup by
// Spark >= 1.6, which requires roughly 450 MB or more, so a realistic value is used here.
SparkConf conf = new SparkConf()
        .setAppName("test_spark")
        .set("spark.executor.memory", "512m")
        .setMaster("local");
JavaSparkContext sc = new JavaSparkContext(conf);

// Toy training set. For classification, labels must lie in [0, numClasses);
// only the labels 1 and 2 occur here.
List<LabeledPoint> labeledPoints = Arrays.asList(
        new LabeledPoint(1, Vectors.dense(2.0, 2.0, 2.0)),
        new LabeledPoint(2, Vectors.dense(1.0, 1.0, 1.0)),
        new LabeledPoint(1, Vectors.dense(1.0, 1.0, 1.0)),
        new LabeledPoint(2, Vectors.dense(1.0, 1.0, 1.0)),
        new LabeledPoint(2, Vectors.dense(1.0, 1.0, 1.0)),
        new LabeledPoint(2, Vectors.dense(1.0, 1.0, 1.0)),
        new LabeledPoint(2, Vectors.dense(1.0, 1.0, 1.0)),
        new LabeledPoint(2, Vectors.dense(1.0, 1.0, 1.0)));
JavaRDD<LabeledPoint> b = sc.parallelize(labeledPoints);

// Empty map: every feature is treated as continuous. The Strategy constructor takes a Scala map.
scala.collection.immutable.Map<Object, Object> categoricalFeaturesInfo =
        new scala.collection.immutable.HashMap<Object, Object>();

// Named values for the 13 positional Strategy constructor arguments (order matters).
Gini$ gini = Gini.instance();
int maxDepth = 5;
int numClasses = 3;              // labels 1 and 2 fit in [0, 3)
int maxBins = 32;
int minInstancesPerNode = 1;
double minInfoGain = 0.0;
int maxMemoryInMB = 512;         // max memory (MB) for histogram aggregation: the knob this demo adjusts
// Must be 1.0 for a single tree: with 0.1 each point is kept with ~10% probability, so the
// 8-point data set left less than one sample on average and the tree collapsed to a single leaf.
double subsamplingRate = 1.0;
boolean useNodeIdCache = false;
int checkpointInterval = 10;     // only consulted when useNodeIdCache is true; 10 is the documented default
Strategy strategy = new Strategy(
        org.apache.spark.mllib.tree.configuration.Algo.Classification(),
        gini,
        maxDepth,
        numClasses,
        maxBins,
        org.apache.spark.mllib.tree.configuration.QuantileStrategy.Sort(),
        categoricalFeaturesInfo,
        minInstancesPerNode,
        minInfoGain,
        maxMemoryInMB,
        subsamplingRate,
        useNodeIdCache,
        checkpointInterval);

DecisionTreeModel model = DecisionTree.train(b.rdd(), strategy);
System.out.println("model decision:" + model.toDebugString());

// Release the Spark context's resources once training is done.
sc.stop();
運行後的輸出如下(這是 Strategy 構造函數中 subsamplingRate 為 0.1 時得到的結果):
model decision:DecisionTreeModel classifier of depth 0 with 1 nodes
Predict: 2.0
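方法二:用下面這種方式初始化 Strategy 並修改內存同樣可行。上面的例子得到的樹只有一個節點,結果有點怪異,原因在於構造函數的第 11 個參數 subsamplingRate:當它為 0.1 時,單棵決策樹會以 10% 的機率對每條樣本做不放回抽樣,8 條數據平均不到 1 條真正參與訓練,因此只能得到一個葉子節點,把它設為默認值 1.0 即可。下面基於默認配置(subsamplingRate 為 1.0)的初始化方法能正確得到結果。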
// Method two: start from Spark's default classification settings and override only what is
// needed. defaultStrategy(String) is used instead of the misspelled defaultStategy(...),
// which is deprecated.
Strategy strategy = Strategy.defaultStrategy("Classification");
// Labels must lie in [0, numClasses); 5 covers the labels 1 and 2 used in the training data.
strategy.setNumClasses(5);
// Max memory (MB) for histogram aggregation during training: the setting this post is about.
strategy.setMaxMemoryInMB(512);
DecisionTreeModel model = DecisionTree.train(b.rdd(), strategy);
這裡的代碼僅僅是為了驗證它能否運行,各位需要修改的參數請參考官網文檔。
下面是 1.5.2 版本的 API 文檔,請對照自己的 Spark 版本查閱。
https://spark.apache.org/docs/1.5.2/api/java/org/apache/spark/mllib/tree/configuration/Strategy.html