http://spark.apache.org/docs/2.2.0/ml-collaborative-filtering.html
只依賴用戶與商品之間的交互數據、而不需要用戶和商品屬性信息的推薦算法,通常稱爲協同過濾算法。
例子:根據兩個用戶的年齡相同來判斷他們可能有相似的偏好,這不叫協同過濾。相反,根據兩個用戶播放過許多相同歌曲來判斷他們可能都喜歡某首歌,這才叫協同過濾。
Spark MLlib 的 ALS 算法要求用戶和產品 ID 必須落在 32 位整數(Int)的取值範圍內,因此大於 Integer.MAX_VALUE(2147483647)的 ID 都是非法的。
訓練出的模型可以保存到文件,之後也可以用 ALSModel.load 從文件中加載該模型。
package test

import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.ml.recommendation.{ALS, ALSModel}
import org.apache.spark.ml.Model

import scala.util.Try

/**
 * Created by othc on 2018-01-19.
 *
 * Trains a Spark ML ALS collaborative-filtering model on
 * (user id, artist id, play count) data, prints the RMSE on a held-out
 * 20% split, and saves the top-10 artists for every user and the top-10
 * users for every artist.
 */
object ALS1 {

  /** One training row: user id, canonical artist id, and the play count used as the rating. */
  case class Rating(userId: Int, artistId: Int, count: Float)

  /** Root directory under which each recommendation output gets its own subdirectory. */
  private val OutputRoot = "/usr/local/testdata"

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .config("spark.sql.warehouse.dir", "/usr/local/testdata/spark-warehouse")
      .appName("als")
      .getOrCreate()
    // Stop the session even if any stage of the pipeline throws.
    try run(spark) finally spark.stop()
  }

  /** Runs the load / train / evaluate / recommend pipeline on an existing session. */
  private def run(spark: SparkSession): Unit = {
    import spark.implicits._

    // Each line: "userId artistId playCount", separated by single spaces.
    val rawUserArtistData: Dataset[String] =
      spark.read.textFile("/usr/local/mldata/user_artist_data.txt")

    // Maps misspelled / non-canonical artist ids to the canonical artist id.
    val artistAlias = spark.read.textFile("/usr/local/mldata/artist_alias.txt")
      .flatMap(line => parseAlias(line))
      .rdd
      .collectAsMap()
    // Broadcast the alias map so executors share one read-only copy.
    val bArtistAlias = spark.sparkContext.broadcast(artistAlias)

    val trainData = rawUserArtistData.map { line =>
      val Array(userId, artistId, count) = line.split(" ").map(_.toInt)
      val canonicalArtistId = bArtistAlias.value.getOrElse(artistId, artistId)
      Rating(userId, canonicalArtistId, count.toFloat)
    }.toDF().cache()

    val Array(train, test) = trainData.randomSplit(Array(0.8, 0.2))

    // NOTE(review): play counts are implicit feedback; the Spark ML ALS guide
    // recommends setImplicitPrefs(true) for such data. Explicit mode is kept so
    // the RMSE against raw counts below stays comparable — confirm the intent.
    val als: ALS = new ALS()
      .setMaxIter(5)
      .setRegParam(0.01)
      .setUserCol("userId")
      .setItemCol("artistId")
      .setRatingCol("count")
    val model: ALSModel = als.fit(train)

    // Users/artists that appear only in the test split get NaN predictions;
    // "drop" removes those rows so the RMSE below is not NaN.
    model.setColdStartStrategy("drop")

    // To persist and reload the fitted model:
    //   model.write.overwrite().save(path)
    //   val reloaded: ALSModel = ALSModel.load(path)
    // (ALS.load would load an untrained estimator, not the fitted model.)

    val predictions: DataFrame = model.transform(test)
    val evaluator = new RegressionEvaluator()
      .setMetricName("rmse")
      .setLabelCol("count")
      .setPredictionCol("prediction")
    val rmse = evaluator.evaluate(predictions)
    println(s"Root-mean-square error = $rmse")

    // saveAsTextFile fails if its output directory already exists, so each
    // result set is written to its own subdirectory.

    // Top 10 artists recommended for each user.
    val userRecs: DataFrame = model.recommendForAllUsers(10)
    userRecs.rdd.saveAsTextFile(s"$OutputRoot/userRecs")

    // Top 10 users recommended for each artist.
    val artistRecs: DataFrame = model.recommendForAllItems(10)
    artistRecs.rdd.saveAsTextFile(s"$OutputRoot/artistRecs")

    userRecs.show()
    artistRecs.show()
  }

  /**
   * Parses one line of artist_alias.txt ("badId<TAB>canonicalId").
   *
   * @param line raw text line
   * @return the (badId, canonicalId) pair, or None when the first field is empty,
   *         the second field is missing, or either id is not an integer
   */
  private def parseAlias(line: String): Option[(Int, Int)] =
    line.split('\t') match {
      case Array(badId, goodId, _*) if badId.nonEmpty =>
        Try((badId.trim.toInt, goodId.trim.toInt)).toOption
      case _ => None
    }
}