環境版本:Spark 2.0.0、Scala 2.11.8
網上幾乎搜不到將Spark MLlib與Spark Streaming結合使用的例子,這讓我很疑惑:難道實現準實時預測還有其他更合理的方式?歡迎大佬在評論區指出。本篇博客的思路很簡單:先用Spark MLlib訓練並保存模型,再編寫Spark Streaming程序加載該模型,對從Kafka實時讀取的數據進行預測。需要注意的是,在使用Spark MLlib之前,我已用Python完成了數據查看與分析、數據清洗、特徵工程、數據集構造、模型訓練等工作,本篇直接使用Python構造好的數據集(CSV中每一列均為數值,最後一列為標籤)。
1 訓練並保存模型
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.tree.model.RandomForestModel
import org.apache.spark.rdd.RDD
/**
 * Trains a binary random-forest classifier with the RDD-based Spark MLlib API,
 * prints its metrics on a cross-validation split, and saves the model to HDFS.
 *
 * Expected input: a headerless CSV in which every column is numeric; the last
 * column is the label and all preceding columns are the features.
 *
 * Created by drguo on 2020/5/20 11:34.
 */
object RandomForestM {

  /** Number of label classes. MLlib requires labels in {0, ..., NumClasses - 1}. */
  private val NumClasses = 2

  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf()
      // Local mode; `*` uses as many worker threads as there are logical cores.
      .setMaster("local[*]")
      .setAppName("rf")
    val sc = new SparkContext(sparkConf)
    try {
      // Read the data. Blank lines (e.g. trailing newlines) are skipped so they
      // do not fail the `toDouble` conversion below.
      val rawData = sc.textFile("hdfs://xx:8020/model/data/xx.csv")
      val data = rawData.filter(_.trim.nonEmpty).map { line =>
        val values = line.split(",").map(_.toDouble)
        // `init` = every column except the last -> dense feature vector;
        // `last` = the label.
        LabeledPoint(values.last, Vectors.dense(values.init))
      }

      // 80% training / 10% cross-validation / 10% test.
      // The CV split is used to compare hyper-parameters (see getBestParam).
      // The test split is held out for a final, unbiased evaluation of the chosen
      // parameters; it is not evaluated here, so it is neither kept nor cached.
      val Array(trainData, cvData, _) = data.randomSplit(Array(0.8, 0.1, 0.1))
      trainData.cache()
      cvData.cache()

      // Arguments: training data, number of classes, categoricalFeaturesInfo
      // (empty = all features continuous), 20 trees, "auto" feature-subset
      // strategy, Gini impurity, maxDepth 4, maxBins 32.
      val model = RandomForest.trainClassifier(
        trainData, NumClasses, Map[Int, Int](), 20, "auto", "gini", 4, 32)

      val metrics = getMetrics(model, cvData)
      // Confusion matrix and overall accuracy on the CV split.
      println(metrics.confusionMatrix)
      println(metrics.accuracy)
      // (precision, recall) for each class label.
      (0 until NumClasses).foreach { target =>
        println((metrics.precision(target), metrics.recall(target)))
      }

      // NOTE(review): MLlib's save writes new directories under this path and is
      // expected to fail if it already exists — remove the old model first.
      model.save(sc, "hdfs://xx:8020/model/xxModel")
    } finally {
      sc.stop()
    }
  }

  /**
   * Scores every sample in `data` with `model` and wraps the
   * (prediction, true label) pairs in a MulticlassMetrics.
   *
   * @param model random-forest model to evaluate
   * @param data  labelled evaluation set (the cross-validation split in `main`)
   * @return metrics (confusion matrix, accuracy, per-class precision/recall, ...)
   */
  def getMetrics(model: RandomForestModel, data: RDD[LabeledPoint]): MulticlassMetrics = {
    val predictionsAndLabels = data.map(point => (model.predict(point.features), point.label))
    new MulticlassMetrics(predictionsAndLabels)
  }

  /**
   * Grid-searches impurity, maximum tree depth and bin count.
   *
   * For each combination, trains a 20-tree forest on `trainData` and scores it
   * on `cvData`. Prints ((impurity, depth, bins), accuracy), best accuracy first.
   *
   * @param trainData training split
   * @param cvData    cross-validation split used for scoring
   */
  def getBestParam(trainData: RDD[LabeledPoint], cvData: RDD[LabeledPoint]): Unit = {
    val evaluations =
      for {
        impurity <- Array("gini", "entropy")
        depth <- Array(1, 20)
        bins <- Array(10, 300)
      } yield {
        val model = RandomForest.trainClassifier(
          trainData, NumClasses, Map[Int, Int](), 20, "auto", impurity, depth, bins)
        ((impurity, depth, bins), getMetrics(model, cvData).accuracy)
      }
    evaluations.sortBy { case (_, accuracy) => -accuracy }.foreach(println)
  }
}
2 讀取並使用模型
import org.apache.kafka.clients.consumer.ConsumerRecord
import org.apache.kafka.common.serialization.StringDeserializer
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.tree.model.RandomForestModel
import org.apache.spark.{SparkConf, SparkContext, TaskContext}
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.dstream.InputDStream
import org.apache.spark.streaming.kafka010.ConsumerStrategies.Subscribe
import org.apache.spark.streaming.kafka010.LocationStrategies.PreferConsistent
import org.apache.spark.streaming.kafka010.{HasOffsetRanges, KafkaUtils, OffsetRange}
/**
 * Spark Streaming job that loads the random-forest model saved by RandomForestM
 * and scores records consumed from Kafka in 6-second micro-batches.
 *
 * Created by drguo on 2020/5/20 11:34.
 */
object ModelTest {
  import org.apache.spark.streaming.kafka010.CanCommitOffsets
  import scala.util.{Failure, Success, Try}

  private val brokers = "xx1:6667,xx2:6667,xx3:6667"

  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf()
      // Local mode; `*` uses as many worker threads as there are logical cores.
      .setMaster("local[*]")
      .setAppName("ModelTest")
    sparkConf.set("spark.sql.warehouse.dir", "file:///") // local run
    val sc = new SparkContext(sparkConf)
    // Load the model once on the driver and broadcast it, so executors fetch it
    // once rather than receiving a copy inside every task closure.
    val rfModel = sc.broadcast(RandomForestModel.load(sc, "hdfs://xx:8020/model/xxModel"))
    val ssc = new StreamingContext(sc, Seconds(6))
    val topics = Array("xx1", "xx2")
    val kafkaParams = Map[String, Object](
      "bootstrap.servers" -> brokers,
      "key.deserializer" -> classOf[StringDeserializer],
      "value.deserializer" -> classOf[StringDeserializer],
      "group.id" -> "hqc",
      "auto.offset.reset" -> "latest",
      // Offsets are committed manually below, only after a batch has been processed.
      "enable.auto.commit" -> (false: java.lang.Boolean)
    )
    val messages: InputDStream[ConsumerRecord[String, String]] = KafkaUtils.createDirectStream[String, String](
      ssc,
      PreferConsistent,
      Subscribe[String, String](topics, kafkaParams)
    )
    messages.foreachRDD { rdd =>
      // Offset ranges must be read from the RDD produced directly by the stream.
      val offsetRanges: Array[OffsetRange] = rdd.asInstanceOf[HasOffsetRanges].offsetRanges
      rdd.foreach { msg =>
        // Each ConsumerRecord carries its own topic.
        msg.topic match {
          case "xx1" => predictAndPrint(rfModel.value, msg.value)
          case _ => // "xx2" (and any other topic) is not scored yet
        }
      }
      // Commit after the batch has been processed. With auto-commit disabled and
      // no manual commit, group "hqc" would never store offsets, and a restart
      // would fall back to auto.offset.reset=latest, skipping unprocessed records.
      messages.asInstanceOf[CanCommitOffsets].commitAsync(offsetRanges)
    }
    ssc.start()
    ssc.awaitTermination()
  }

  /**
   * Cleans one raw "xx1" message, parses it as comma-separated doubles, and
   * prints the model's predicted label.
   *
   * Records that clean to "" are ignored. Records that fail to parse or predict
   * are logged to stderr and skipped, so one bad message cannot fail the batch.
   */
  private def predictAndPrint(model: RandomForestModel, raw: String): Unit = {
    // NOTE(review): KxxDataClean is not defined in this post; presumably it
    // returns a comma-separated feature string, or "" for unusable input — confirm.
    val line = KxxDataClean(raw)
    if (line != "") {
      Try(model.predict(Vectors.dense(line.split(",").map(_.toDouble)))) match {
        // println runs on the executor: console in local mode, executor logs on a cluster.
        case Success(preLabel) => println(preLabel)
        case Failure(e) => System.err.println(s"Skipping unparsable record: $line ($e)")
      }
    }
  }
}
pom.xml
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <groupId>com.xx</groupId>
    <artifactId>spark-xx-model</artifactId>
    <version>1.0-SNAPSHOT</version>
    <properties>
        <spark.version>2.0.0</spark.version>
        <!-- Must agree with the _2.11 suffix of the Spark/Kafka artifacts below. -->
        <scala.version>2.11.8</scala.version>
        <!-- Sources contain non-ASCII (Chinese) comments; do not rely on the platform encoding. -->
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
    </properties>
    <dependencies>
        <!-- Spark MLlib -->
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-mllib_2.11</artifactId>
            <version>${spark.version}</version>
            <scope>compile</scope>
        </dependency>
        <!-- Spark Streaming + Kafka -->
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-streaming_2.11</artifactId>
            <version>${spark.version}</version>
            <scope>compile</scope>
        </dependency>
        <dependency>
            <groupId>org.apache.kafka</groupId>
            <artifactId>kafka_2.11</artifactId>
            <version>0.10.0.0</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-streaming-kafka-0-10_2.11</artifactId>
            <version>${spark.version}</version>
        </dependency>
        <!-- MySQL -->
        <dependency>
            <groupId>mysql</groupId>
            <artifactId>mysql-connector-java</artifactId>
            <version>5.1.39</version>
        </dependency>
        <!-- Logging -->
        <dependency>
            <groupId>com.typesafe.scala-logging</groupId>
            <artifactId>scala-logging_2.11</artifactId>
            <version>3.7.2</version>
        </dependency>
        <dependency>
            <groupId>ch.qos.logback</groupId>
            <artifactId>logback-classic</artifactId>
            <version>1.2.3</version>
        </dependency>
    </dependencies>
    <build>
        <plugins>
            <!-- Maven compiles only Java by default; this plugin compiles the Scala sources. -->
            <plugin>
                <groupId>net.alchim31.maven</groupId>
                <artifactId>scala-maven-plugin</artifactId>
                <version>3.2.2</version>
                <configuration>
                    <scalaVersion>${scala.version}</scalaVersion>
                </configuration>
                <executions>
                    <execution>
                        <goals>
                            <goal>compile</goal>
                            <goal>testCompile</goal>
                        </goals>
                    </execution>
                </executions>
            </plugin>
            <!-- Builds a single jar that bundles all dependencies, during `package`. -->
            <plugin>
                <artifactId>maven-assembly-plugin</artifactId>
                <version>2.3</version>
                <configuration>
                    <classifier>dist</classifier>
                    <appendAssemblyId>true</appendAssemblyId>
                    <descriptorRefs>
                        <descriptorRef>jar-with-dependencies</descriptorRef>
                    </descriptorRefs>
                </configuration>
                <executions>
                    <execution>
                        <id>make-assembly</id>
                        <phase>package</phase>
                        <goals>
                            <goal>single</goal>
                        </goals>
                    </execution>
                </executions>
            </plugin>
        </plugins>
    </build>
</project>