基于ALS算法的电影推荐(java版)

基于ALS算法的最佳电影推荐(java版)

package spark;

import java.util.Arrays;
import java.util.List;

import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.recommendation.ALS;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.apache.spark.mllib.recommendation.Rating;
import org.apache.spark.storage.StorageLevel;

import scala.Tuple2;

/**
 * Movie recommendation demo built on Spark MLlib's ALS matrix factorization and the MovieLens ml-1m data.
 *
 * <p>Ratings are bucketed by {@code Timestamp % 10} into training (buckets 0-5, plus the personal ratings),
 * validation (6-7) and test (8-9) sets. Every (rank, lambda, numIter) combination from a small grid is trained.
 * The model with the lowest validation RMSE is evaluated on the test set, compared with a constant
 * mean-rating baseline, and used to recommend movies to one user.
 */
public class SparkALSDemo {

    private static final Logger LOGGER = Logger.getLogger(SparkALSDemo.class);

    /** MovieLens data directory used when no directory is passed as the first program argument. */
    private static final String DEFAULT_MOVIELENS_HOME = "F:/ml-1m";

    /** User id for whom recommendations are predicted and logged. */
    private static final int RECOMMEND_USER_ID = 100;

    /** Number of recommendations to log. */
    private static final int TOP_N = 10;

    /**
     * Entry point: configures logging, creates a local Spark context, runs the demo and always stops the context.
     *
     * @param args optional; {@code args[0]} overrides the MovieLens data directory (default {@code F:/ml-1m})
     */
    public static void main(String ... args) throws Exception {
        // Only show Spark warnings, and silence the embedded Jetty server.
        Logger.getLogger("org.apache.spark").setLevel(Level.WARN);
        // Jetty's packages live under org.eclipse.jetty (not org.apache.eclipse.jetty).
        // NOTE(review): Spark builds that shade Jetty use another package prefix; adjust if Jetty logs still appear.
        Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF);

        SparkConf sparkConf = new SparkConf().setAppName("MovieLensALS");
        sparkConf.setMaster("local[4]");
        JavaSparkContext jsc = new JavaSparkContext(sparkConf);
        try {
            String movielensHomeDir = args.length > 0 ? args[0] : DEFAULT_MOVIELENS_HOME;
            run(jsc, movielensHomeDir);
        } finally {
            // Release the context even when training or I/O fails.
            jsc.stop();
        }
    }

    /**
     * Runs the load / split / grid-search / evaluate / recommend pipeline on an existing context.
     *
     * @param jsc              live Spark context; it is not stopped by this method
     * @param movielensHomeDir directory containing ratings.dat, movies.dat and personalRatings.txt
     */
    private static void run(JavaSparkContext jsc, String movielensHomeDir) {
        // Each ratings.dat line is split on "::" into user, movie, rating and timestamp.
        // The key is timestamp % 10, which spreads the ratings over 10 buckets.
        // Cached because this RDD is scanned several times below.
        JavaRDD<Tuple2<Long, Rating>> ratings = jsc.textFile(movielensHomeDir + "/ratings.dat").map(line -> {
            String[] fields = line.split("::");
            return new Tuple2<Long, Rating>(Long.parseLong(fields[3]) % 10, new Rating(Integer.parseInt(fields[0]),
                    Integer.parseInt(fields[1]), Double.parseDouble(fields[2])));
        }).cache();

        // Personal ratings produced by a separate rating tool; the first three "::" fields are user, movie, rating.
        JavaRDD<Rating> myRatingsRDD = jsc.textFile(movielensHomeDir + "/personalRatings.txt").map(s -> {
            String[] sarray = s.split("::");
            return new Rating(Integer.parseInt(sarray[0]), Integer.parseInt(sarray[1]), Double.parseDouble(sarray[2]));
        });

        // Summary of the sample data.
        LOGGER.info("Got " + ratings.count() + " ratings from " + ratings.map(tupe -> tupe._2.user()).distinct().count() + " users " + ratings.map(tupe -> tupe._2.product()).distinct().count() + " movies");

        // Training set: buckets 0-5 plus the personal ratings.
        JavaRDD<Rating> training = ratings.filter(x -> x._1 < 6).map(tupe2 -> tupe2._2).union(myRatingsRDD)
                .repartition(4).persist(StorageLevel.MEMORY_ONLY());
        // Validation set: buckets 6-7.
        JavaRDD<Rating> validation = ratings.filter(x -> x._1 >= 6 && x._1 < 8).map(tupe2 -> tupe2._2).repartition(4)
                .persist(StorageLevel.MEMORY_ONLY());
        // Test set: buckets 8-9.
        JavaRDD<Rating> test = ratings.filter(x -> x._1 >= 8).map(tupe2 -> tupe2._2).persist(StorageLevel.MEMORY_ONLY());

        // Count once; these sizes are reused as RMSE denominators.
        long numValidation = validation.count();
        long numTest = test.count();
        LOGGER.info("Training: " + training.count() + " validation: " + numValidation + " test: " + numTest);

        // Grid search: train every (rank, lambda, numIter) combination and keep the lowest validation RMSE.
        List<Integer> ranks = Arrays.asList(8, 10, 12);
        List<Double> lambdas = Arrays.asList(0.1, 2.5, 5.0);
        List<Integer> numIters = Arrays.asList(10, 15, 20);
        MatrixFactorizationModel bestModel = null;
        double bestValidationRmse = Double.MAX_VALUE;
        int bestRank = 0;
        double bestLambda = -1.0;
        int bestNumIter = -1;
        for (int rank : ranks) {
            for (double lambda : lambdas) {
                for (int numIter : numIters) {
                    MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(training), rank, numIter, lambda);
                    double validationRmse = computeRMSEAverage(model, validation, numValidation);
                    LOGGER.info("RMSE (validation) = " + validationRmse + " for the model trained with rank = "
                            + rank + ", lambda = " + lambda + ", and numIter = " + numIter + ".");
                    if (validationRmse < bestValidationRmse) {
                        // The previous best is no longer needed; free its factors before replacing it.
                        releaseModel(bestModel);
                        bestModel = model;
                        bestValidationRmse = validationRmse;
                        bestRank = rank;
                        bestLambda = lambda;
                        bestNumIter = numIter;
                    } else {
                        releaseModel(model);
                    }
                }
            }
        }
        if (bestModel == null) {
            throw new IllegalStateException("No ALS model produced a comparable validation RMSE");
        }

        double testRmse = computeRMSEAverage(bestModel, test, numTest);
        LOGGER.info("The best model was trained with rank = " + bestRank + " and lambda = " + bestLambda    + ", and numIter = " + bestNumIter + ", and its RMSE on the test set is " + testRmse + ".");

        // Baseline: always predict the mean rating of the UNION of training and validation data,
        // and measure that predictor's RMSE on the test set.
        double meanRating = training.union(validation).mapToDouble(d -> d.rating()).mean();
        double baselineRmse = Math.sqrt(
                test.mapToDouble(x -> (meanRating - x.rating()) * (meanRating - x.rating())).sum() / numTest);
        double improvement = (baselineRmse - testRmse) / baselineRmse * 100;
        LOGGER.info("The best model improves the baseline by " + String.format("%1.2f", improvement) + "%.");

        // Movie catalogue: the first two "::" fields of movies.dat are id and title.
        JavaRDD<Tuple2<Integer, String>> movies = jsc.textFile(movielensHomeDir + "/movies.dat").map(line -> {
            String[] fields = line.split("::");
            return new Tuple2<Integer, String>(Integer.parseInt(fields[0]), fields[1]);
        });

        // Exclude movies that were already rated.
        // This covers everything in the personal ratings file plus the target user's own ratings.
        List<Integer> ratedMovieIds = myRatingsRDD.map(d -> d.product())
                .union(ratings.filter(x -> x._2.user() == RECOMMEND_USER_ID).map(x -> x._2.product()))
                .distinct()
                .collect();
        JavaRDD<Integer> candidates = movies.map(s -> s._1).filter(m -> !ratedMovieIds.contains(m));

        // Predict the target user's rating for every candidate and log the TOP_N highest.
        JavaRDD<Rating> rr = bestModel.predict(JavaPairRDD.fromJavaRDD(
                candidates.map(d -> new Tuple2<Integer, Integer>(RECOMMEND_USER_ID, d))))
                .sortBy(f -> f.rating(), false, 4);
        LOGGER.info("Movies recommended for you:");
        rr.take(TOP_N).forEach(a -> LOGGER.info("用户" + a.user() + "-[ " + a.product() + "]-[" + a.rating() + "]"));
    }

    /**
     * Unpersists the user and product factor RDDs of a model that is no longer needed.
     * Unpersisting an RDD that is not cached is a no-op.
     *
     * @param model the model to release; {@code null} is ignored
     */
    private static void releaseModel(MatrixFactorizationModel model) {
        if (model != null) {
            model.userFeatures().unpersist(false);
            model.productFeatures().unpersist(false);
        }
    }

    /**
     * Computes the root-mean-square error of {@code model}'s predictions on {@code data}:
     * {@code sqrt(sum((predicted - actual)^2) / n)}.
     *
     * <p>Only (user, product) pairs for which the model returns a prediction contribute to the sum. Pairs the
     * model cannot score are dropped by the join, but the denominator is still the caller-supplied {@code n}.
     *
     * @param model trained ALS model
     * @param data  ratings to score against
     * @param n     number of ratings to average over (callers pass {@code data.count()})
     * @return the RMSE; 0.0 if no pair could be scored
     * @throws IllegalArgumentException if {@code n <= 0}
     */
    public static double computeRMSEAverage(MatrixFactorizationModel model, JavaRDD<Rating> data, long n) {
        if (n <= 0) {
            throw new IllegalArgumentException("n must be positive, got " + n);
        }
        JavaRDD<Rating> jddRat = model.predict(JavaPairRDD.fromJavaRDD(
                data.map(d -> new Tuple2<Integer, Integer>(d.user(), d.product()))));
        // Key both predictions and actual ratings by (user, product) so they can be joined.
        JavaPairRDD<Tuple2<Integer, Integer>, Double> pre = jddRat.mapToPair(
                f -> new Tuple2<Tuple2<Integer, Integer>, Double>(
                        new Tuple2<Integer, Integer>(f.user(), f.product()), f.rating()));
        JavaPairRDD<Tuple2<Integer, Integer>, Double> rea = data.mapToPair(
                f -> new Tuple2<Tuple2<Integer, Integer>, Double>(
                        new Tuple2<Integer, Integer>(f.user(), f.product()), f.rating()));
        // Inner join, like an SQL INNER JOIN: yields (predicted, actual) pairs.
        JavaRDD<Tuple2<Double, Double>> d = pre.join(rea).values();
        // fold (not reduce) so an empty join yields 0.0 instead of throwing.
        double sumSquaredError = d.map(f -> (f._1 - f._2) * (f._1 - f._2)).fold(0.0, (a, b) -> a + b);
        return Math.sqrt(sumSquaredError / n);
    }
}

本文使用的是 http://files.grouplens.org/datasets/movielens/ 中 ml-1m.zip 的数据。下载后解压到本地，并把代码中的路径（F:/ml-1m）改为解压后的目录即可。同一目录下还需准备评分器生成的 personalRatings.txt（每行以 :: 分隔，前三列依次为 UserID、MovieID、Rating）。

发布了32 篇原创文章 · 获赞 2 · 访问量 2万+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章