[spark] mllib decision tree: changing the memory via Strategy, Java implementation

The official Spark demos don't show how to change the memory used for decision tree training, so I worked it out myself and ported that Strategy configuration demo to Java. The code follows:

The imports:

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.configuration.Strategy;
import org.apache.spark.mllib.tree.impurity.Gini;
import org.apache.spark.mllib.tree.impurity.Gini$;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;

import java.util.ArrayList;
import java.util.List;

The full code:

        SparkConf conf = new SparkConf().setAppName("test_spark")
                .set("spark.executor.memory", "30m")
//              .setMaster("spark://spark-master:7077");
                .setMaster("local");

//        SparkConf conf = new SparkConf().setAppName("LR");
//        conf.set("spark.executor.memory", "1024m");
//        conf.set("spark.cores.max", "2");
//        conf.set("spark.driver.allowMultipleContexts", "true");
//        String ML_MASTER = StringUtil.getProperty("ML_MASTER");
//        conf.setMaster(ML_MASTER);
        JavaSparkContext sc = new JavaSparkContext(conf);

        // A tiny hand-built training set: a label followed by a 3-dimensional feature vector.
        LabeledPoint pos1 = new LabeledPoint(1, Vectors.dense(2.0, 2.0, 2.0));
        LabeledPoint pos2 = new LabeledPoint(2, Vectors.dense(1.0, 1.0, 1.0));
        LabeledPoint pos3 = new LabeledPoint(1, Vectors.dense(1.0, 1.0, 1.0));
        LabeledPoint pos4 = new LabeledPoint(2, Vectors.dense(1.0, 1.0, 1.0));
        LabeledPoint pos5 = new LabeledPoint(2, Vectors.dense(1.0, 1.0, 1.0));
        LabeledPoint pos6 = new LabeledPoint(2, Vectors.dense(1.0, 1.0, 1.0));
        LabeledPoint pos7 = new LabeledPoint(2, Vectors.dense(1.0, 1.0, 1.0));
        LabeledPoint pos8 = new LabeledPoint(2, Vectors.dense(1.0, 1.0, 1.0));
        List<LabeledPoint> labeledPoints = new ArrayList<>();
        labeledPoints.add(pos1);
        labeledPoints.add(pos2);
        labeledPoints.add(pos3);
        labeledPoints.add(pos4);
        labeledPoints.add(pos5);
        labeledPoints.add(pos6);
        labeledPoints.add(pos7);
        labeledPoints.add(pos8);

        JavaRDD<LabeledPoint> b = sc.parallelize(labeledPoints);

        // Empty categorical-feature map: every feature is treated as continuous.
        scala.collection.immutable.Map<Object, Object> cate =
                new scala.collection.immutable.HashMap<Object, Object>();
        int maxDepth = 5;
        int maxBins = 32;
        // Gini.instance() exposes the Scala Gini singleton so it can be passed from Java.
        Gini$ gini = Gini.instance();
        // Full Strategy constructor (Spark 1.5.2). The positional arguments are:
        // algo, impurity, maxDepth, numClasses, maxBins, quantileCalculationStrategy,
        // categoricalFeaturesInfo, minInstancesPerNode, minInfoGain, maxMemoryInMB,
        // subsamplingRate, useNodeIdCache, checkpointInterval.
        // The 10th argument (512) is maxMemoryInMB, the memory setting this post is about.
        Strategy strategy = new Strategy(
                org.apache.spark.mllib.tree.configuration.Algo.Classification(),
                gini, maxDepth, 3, maxBins,
                org.apache.spark.mllib.tree.configuration.QuantileStrategy.Sort(),
                cate, 1, 0.0, 512, 0.1, false, 0);
        DecisionTreeModel model = DecisionTree.train(b.rdd(), strategy);

        System.out.println("model decision:" + model.toDebugString());


The output should look like this:

model decision:DecisionTreeModel classifier of depth 0 with 1 nodes
  Predict: 2.0
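
Beyond toDebugString, the trained model can also be inspected and queried programmatically. Below is a minimal sketch reusing the model variable and the Vectors import from the code above; DecisionTreeModel exposes depth(), numNodes() and predict():

        // Minimal sketch: programmatic inspection of the model built above.
        System.out.println("depth = " + model.depth() + ", nodes = " + model.numNodes());
        // Predict the label for a single feature vector.
        double label = model.predict(Vectors.dense(1.0, 1.0, 1.0));
        System.out.println("predicted label: " + label);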


Method 2: changing the memory through a Strategy built as below also works. The result from the example above is a bit odd, so I looked into another way of doing it; the initialization below gets correct results.

        // defaultStrategy gives sensible defaults; then override only what you need.
        // (defaultStategy, as spelled in some older examples, is the deprecated misspelled alias.)
        Strategy strategy = Strategy.defaultStrategy(
                org.apache.spark.mllib.tree.configuration.Algo.Classification());
        strategy.setNumClasses(5);        // labels here are 1 and 2, so numClasses just needs to be at least 3
        strategy.setMaxMemoryInMB(512);   // the memory setting this post is about
        DecisionTreeModel model = DecisionTree.train(b.rdd(), strategy);
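
To sanity-check this second model you can run it over the training data and compute a training error. A sketch, assuming Java 8 lambdas plus the b and model variables from above; it additionally needs import org.apache.spark.api.java.JavaPairRDD; and import scala.Tuple2;:

        // Sketch: pair each training point's prediction with its true label...
        JavaPairRDD<Double, Double> predictionAndLabel =
                b.mapToPair(p -> new Tuple2<>(model.predict(p.features()), p.label()));
        // ...and count how often they disagree.
        double trainErr = (double) predictionAndLabel
                .filter(pl -> !pl._1().equals(pl._2())).count() / b.count();
        System.out.println("training error: " + trainErr);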



All this does is check whether it runs. For the parameters you need to adjust for your own data, please see the official documentation.


This is the 1.5.2 API; match it against the docs for your own Spark version:

https://spark.apache.org/docs/1.5.2/api/java/org/apache/spark/mllib/tree/configuration/Strategy.html
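
Everything listed on that Strategy page is exposed through bean-style setters, so method 2 extends naturally to the other parameters. A sketch of the ones most commonly tuned, assuming the same 1.5.2 API as above; the values here are arbitrary placeholders, not recommendations:

        Strategy strategy = Strategy.defaultStrategy(
                org.apache.spark.mllib.tree.configuration.Algo.Classification());
        strategy.setNumClasses(3);             // number of label classes
        strategy.setImpurity(Gini.instance()); // or Entropy.instance()
        strategy.setMaxDepth(5);
        strategy.setMaxBins(32);
        strategy.setMaxMemoryInMB(512);        // the memory setting this post is about
        strategy.setMinInstancesPerNode(1);
        strategy.setMinInfoGain(0.0);
        // Java-friendly setter: map of feature index -> number of categories (empty = all continuous).
        strategy.setCategoricalFeaturesInfo(new java.util.HashMap<Integer, Integer>());
        DecisionTreeModel model = DecisionTree.train(b.rdd(), strategy);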


