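What is state management? Take word count as an example. Each batchInterval computes the word counts for the current batch only. How, then, do we keep the counts accumulating across batches? Spark Streaming provides two methods: updateStateByKey and mapWithState. mapWithState was added in version 1.6 and is currently marked experimental. According to the official claims, it performs up to 10 times better than updateStateByKey. Let's look at how each is implemented.
Example code: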
import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}

object UpdateStateByKeyDemo {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("UpdateStateByKeyDemo")
    val ssc = new StreamingContext(conf, Seconds(20))
    // updateStateByKey requires a checkpoint directory.
    ssc.checkpoint("/checkpoint/")
    val socketLines = ssc.socketTextStream("spark-master", 9999)

    socketLines.flatMap(_.split(",")).map(word => (word, 1))
      .updateStateByKey(
        // currValues: this batch's values for the key; preValue: the running total so far
        (currValues: Seq[Int], preValue: Option[Int]) => {
          val currValue = currValues.sum
          Some(currValue + preValue.getOrElse(0))
        }).print()

    // socketLines.flatMap(_.split(",")).map(word=>(word,1)).reduceByKey()
    ssc.start()
    ssc.awaitTermination()
    ssc.stop()
  }
}
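We know that map returns a MappedDStream, yet MappedDStream has no updateStateByKey method, and neither does its parent class DStream.
However, the companion object of DStream defines an implicit conversion function: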
implicit def toPairDStreamFunctions[K, V](stream: DStream[(K, V)])
    (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null):
  PairDStreamFunctions[K, V] = {
  new PairDStreamFunctions[K, V](stream)
}
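To make the mechanism concrete, here is a minimal sketch of the same pattern in plain Scala, with no Spark dependency. MiniStream, MiniPairFunctions and sumByKey are hypothetical names invented for this illustration; they only mimic the roles of DStream, PairDStreamFunctions and a pair-only method.

```scala
import scala.language.implicitConversions

// Hypothetical stand-ins for DStream and PairDStreamFunctions; not Spark classes.
class MiniStream[T](val items: Seq[T])

class MiniPairFunctions[K, V](stream: MiniStream[(K, V)]) {
  // Only reachable when the element type is a key-value pair.
  def sumByKey(implicit num: Numeric[V]): Map[K, V] =
    stream.items.groupBy(_._1).map { case (k, kvs) => k -> kvs.map(_._2).sum }
}

object MiniStream {
  // Defined in the companion object, so the compiler finds it without an import.
  implicit def toPairFunctions[K, V](stream: MiniStream[(K, V)]): MiniPairFunctions[K, V] =
    new MiniPairFunctions[K, V](stream)
}

object ImplicitDemo {
  val pairs = new MiniStream(Seq(("a", 1), ("b", 2), ("a", 3)))
  // MiniStream has no sumByKey; the compiler rewrites this call to
  // MiniStream.toPairFunctions(pairs).sumByKey
  val counts: Map[String, Int] = pairs.sumByKey
}
```

Because the conversion only applies to streams of pairs, calling sumByKey on a MiniStream[String] would fail to compile, which is the same reason updateStateByKey is only available after map has produced (word, 1) tuples.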
PairDStreamFunctions defines updateStateByKey:
def updateStateByKey[S: ClassTag](
    updateFunc: (Seq[V], Option[S]) => Option[S]
  ): DStream[(K, S)] = ssc.withScope {
  updateStateByKey(updateFunc, defaultPartitioner())
}
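It takes a function as its argument. Seq[V] holds the values for a key in the current batch, Option[S] holds the key's previous state (in the example, the cumulative count so far), and the return value is the new state.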
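The contract of the update function can be tried out on plain collections. The sketch below reuses the function from the word count example; runBatch is a hypothetical helper written for this illustration, and it only imitates the per-batch behaviour rather than reproducing Spark's implementation.

```scala
object UpdateFuncSketch {
  // Same shape as the function passed to updateStateByKey in the example above.
  val updateFunc: (Seq[Int], Option[Int]) => Option[Int] =
    (currValues, preValue) => Some(currValues.sum + preValue.getOrElse(0))

  // Hypothetical helper: applies updateFunc to one batch.
  // Every key seen so far is visited, including keys with no new values.
  def runBatch(state: Map[String, Int], batch: Seq[(String, Int)]): Map[String, Int] = {
    val grouped = batch.groupBy(_._1).map { case (k, kvs) => k -> kvs.map(_._2) }
    (state.keySet ++ grouped.keySet).flatMap { key =>
      // A result of None would drop the key from the state.
      updateFunc(grouped.getOrElse(key, Seq.empty[Int]), state.get(key)).map(key -> _)
    }.toMap
  }

  val afterBatch1 = runBatch(Map.empty, Seq(("spark", 1), ("hadoop", 1), ("spark", 1)))
  val afterBatch2 = runBatch(afterBatch1, Seq(("spark", 1)))
}
```

After the first batch the state is spark -> 2, hadoop -> 1; after the second it is spark -> 3, hadoop -> 1. Note that "hadoop" is still passed through the function in the second batch, with an empty Seq and Some(1).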
updateStateByKey eventually calls the following overload of the same name:
def updateStateByKey[S: ClassTag](
    updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)],
    partitioner: Partitioner,
    rememberPartitioner: Boolean
  ): DStream[(K, S)] = ssc.withScope {
  new StateDStream(self, ssc.sc.clean(updateFunc), partitioner, rememberPartitioner, None)
}
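This creates a StateDStream. Its compute method first fetches the RDD produced for the previous batch (the cumulative word counts from program start up to the previous batch), then fetches the RDD that the StateDStream's parent DStream computed for the current batch (the current batch's word counts). These are prevStateRDD and parentRDD respectively. It then calls: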
private [this] def computeUsingPreviousRDD (
    parentRDD : RDD[(K, V)], prevStateRDD : RDD[(K, S)]) = {
  // Define the function for the mapPartition operation on cogrouped RDD;
  // first map the cogrouped tuple to tuples of required type,
  // and then apply the update function
  val updateFuncLocal = updateFunc
  val finalFunc = (iterator: Iterator[(K, (Iterable[V], Iterable[S]))]) => {
    val i = iterator.map(t => {
      val itr = t._2._2.iterator
      val headOption = if (itr.hasNext) Some(itr.next()) else None
      (t._1, t._2._1.toSeq, headOption)
    })
    updateFuncLocal(i)
  }
  val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner)
  val stateRDD = cogroupedRDD.mapPartitions(finalFunc, preservePartitioning)
  Some(stateRDD)
}
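The two RDDs are cogrouped, and the function passed to updateStateByKey is then applied to the result. This cogroup is expensive: every batch touches every key in the accumulated state, whether or not that key received new data.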
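A rough way to see the cost is to model cogroup on ordinary collections and count how many keys the update function would be called on. The cogroup function below is a hypothetical simplification written for this post, not Spark's RDD.cogroup, and the key counts are made-up illustration data.

```scala
object CogroupCostSketch {
  // Hypothetical model of cogroup on plain collections: every key from either side
  // yields one (key, newValues, previousState) tuple.
  def cogroup(parent: Seq[(String, Int)], prevState: Map[String, Int])
      : Seq[(String, Seq[Int], Option[Int])] = {
    val grouped = parent.groupBy(_._1)
    (prevState.keySet ++ grouped.keySet).toSeq.sorted.map { key =>
      (key, grouped.getOrElse(key, Seq.empty[(String, Int)]).map(_._2), prevState.get(key))
    }
  }

  // 1000 keys of accumulated state, but only one key arrives in this batch.
  val prevState: Map[String, Int] = (1 to 1000).map(i => s"word$i" -> 1).toMap
  val batch = Seq(("word1", 1))

  val cogrouped = cogroup(batch, prevState)
  val touchedKeys = cogrouped.size                  // keys the update function is invoked on
  val changedKeys = cogrouped.count(_._2.nonEmpty)  // keys that actually received data
}
```

Here touchedKeys is 1000 while changedKeys is 1, so the work per batch grows with the total size of the state rather than with the size of the batch.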
That is why Spark 1.6 introduced mapWithState.
Example code:
import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, State, StateSpec, StreamingContext}

object mapWithStateTest {
  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf().setAppName("StatefulNetworkWordCount").setMaster("local[2]")
    val ssc = new StreamingContext(sparkConf, Seconds(5))
    ssc.checkpoint(".")

    // Initial state RDD for mapWithState operation
    val initialRDD = ssc.sparkContext.parallelize(List(("hello", 1), ("world", 1)))
    val lines = ssc.socketTextStream("spark-master", 9999)
    val words = lines.flatMap(_.split(" "))
    val wordDstream = words.map(x => (x, 1))

    // Adds this record's count to the stored state and emits (word, running total).
    val mappingFunc = (word: String, one: Option[Int], state: State[Int]) => {
      val sum = one.getOrElse(0) + state.getOption.getOrElse(0)
      val output = (word, sum)
      state.update(sum)
      output
    }

    // initialRDD was defined but never used in the original listing; it is passed here.
    val stateDstream = wordDstream.mapWithState(
      StateSpec.function(mappingFunc).initialState(initialRDD))
    stateDstream.print()
    ssc.start()
    ssc.awaitTermination()
  }
}
mapWithState takes a StateSpec object as its argument. The StateSpec wraps the state management function.
Inside mapWithState, a MapWithStateDStreamImpl object is created:
def mapWithState[StateType: ClassTag, MappedType: ClassTag](
    spec: StateSpec[K, V, StateType, MappedType]
  ): MapWithStateDStream[K, V, StateType, MappedType] = {
  new MapWithStateDStreamImpl[K, V, StateType, MappedType](
    self,
    spec.asInstanceOf[StateSpecImpl[K, V, StateType, MappedType]]
  )
}
MapWithStateDStreamImpl in turn creates an InternalMapWithStateDStream, and its compute method calls the getOrCompute method of that InternalMapWithStateDStream:
private val internalStream =
  new InternalMapWithStateDStream[KeyType, ValueType, StateType, MappedType](dataStream, spec)

override def slideDuration: Duration = internalStream.slideDuration

override def dependencies: List[DStream[_]] = List(internalStream)

override def compute(validTime: Time): Option[RDD[MappedType]] = {
  internalStream.getOrCompute(validTime).map { _.flatMap[MappedType] { _.mappedData } }
}
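Let's first look at the getOrCompute method of InternalMapWithStateDStream.
InternalMapWithStateDStream does not implement getOrCompute itself; the method is inherited from its parent class DStream.
getOrCompute eventually calls the compute method of InternalMapWithStateDStream: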
/** Method that generates a RDD for the given time */
override def compute(validTime: Time): Option[RDD[MapWithStateRDDRecord[K, S, E]]] = {
  // Get the previous state or create a new empty state RDD
  val prevStateRDD = getOrCompute(validTime - slideDuration) match {
    case Some(rdd) =>
      if (rdd.partitioner != Some(partitioner)) {
        // If the RDD is not partitioned the right way, let us repartition it using the
        // partition index as the key. This is to ensure that state RDD is always partitioned
        // before creating another state RDD using it
        MapWithStateRDD.createFromRDD[K, V, S, E](
          rdd.flatMap { _.stateMap.getAll() }, partitioner, validTime)
      } else {
        rdd
      }
    case None =>
      MapWithStateRDD.createFromPairRDD[K, V, S, E](
        spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)),
        partitioner,
        validTime
      )
  }

  // Compute the new state RDD with previous state RDD and partitioned data RDD
  // Even if there is no data RDD, use an empty one to create a new state RDD
  val dataRDD = parent.getOrCompute(validTime).getOrElse {
    context.sparkContext.emptyRDD[(K, V)]
  }
  val partitionedDataRDD = dataRDD.partitionBy(partitioner)
  val timeoutThresholdTime = spec.getTimeoutInterval().map { interval =>
    (validTime - interval).milliseconds
  }
  Some(new MapWithStateRDD(
    prevStateRDD, partitionedDataRDD, mappingFunction, validTime, timeoutThresholdTime))
}
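The interplay between getOrCompute in the parent class and compute in the subclass, including the lookup of the previous batch's state, can be imitated in a few lines. This is only a sketch under simplifying assumptions: MiniDStream and MiniStateStream are hypothetical classes, time is a plain Long, and the "RDD" is an ordinary Map of word counts.

```scala
import scala.collection.mutable

// Hypothetical miniature of the getOrCompute/compute split; not Spark code.
abstract class MiniDStream[T] {
  private val generated = mutable.HashMap.empty[Long, T]

  // Subclasses describe how to build the result for one batch time.
  protected def compute(time: Long): Option[T]

  // Returns the cached result for `time`, computing and caching it on a miss.
  final def getOrCompute(time: Long): Option[T] =
    generated.get(time).orElse {
      val result = compute(time)
      result.foreach(r => generated.put(time, r))
      result
    }
}

// State for batch t is derived from the state for batch t - slide plus the batch's data.
class MiniStateStream(data: Map[Long, Seq[String]], slide: Long, zeroTime: Long)
    extends MiniDStream[Map[String, Int]] {
  var computeCalls = 0

  protected def compute(time: Long): Option[Map[String, Int]] =
    if (time <= zeroTime) None // before the stream started there is no state
    else {
      computeCalls += 1
      val prev = getOrCompute(time - slide).getOrElse(Map.empty[String, Int])
      val words = data.getOrElse(time, Seq.empty[String])
      Some(words.foldLeft(prev)((m, w) => m.updated(w, m.getOrElse(w, 0) + 1)))
    }
}

object GetOrComputeDemo {
  val stream =
    new MiniStateStream(Map(5L -> Seq("a", "b"), 10L -> Seq("a")), slide = 5L, zeroTime = 0L)
  val atTen = stream.getOrCompute(10L) // computes batch 5 first, then batch 10
  val again = stream.getOrCompute(10L) // served from the cache
  val calls = stream.computeCalls
}
```

Asking for time 10 computes time 5 first, and the second request for time 10 is answered from the cache, so compute runs only twice in total.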
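The compute method of InternalMapWithStateDStream thus builds a MapWithStateRDD from the previous state, prevStateRDD, and the RDD that the MappedDStream computed for the current batch. The compute method of MapWithStateRDD is as follows: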
override def compute(
    partition: Partition, context: TaskContext): Iterator[MapWithStateRDDRecord[K, S, E]] = {
  val stateRDDPartition = partition.asInstanceOf[MapWithStateRDDPartition]
  val prevStateRDDIterator = prevStateRDD.iterator(
    stateRDDPartition.previousSessionRDDPartition, context)
  val dataIterator = partitionedDataRDD.iterator(
    stateRDDPartition.partitionedDataRDDPartition, context)

  val prevRecord = if (prevStateRDDIterator.hasNext) Some(prevStateRDDIterator.next()) else None
  val newRecord = MapWithStateRDDRecord.updateRecordWithData(
    prevRecord,
    dataIterator,
    mappingFunction,
    batchTime,
    timeoutThresholdTime,
    removeTimedoutData = doFullScan // remove timedout data only when full scan is enabled
  )
  Iterator(newRecord)
}
Each partition of a MapWithStateRDD corresponds to a single MapWithStateRDDRecord object, which maintains two data structures:
case class MapWithStateRDDRecord[K, S, E](
    var stateMap: StateMap[K, S], var mappedData: Seq[E])
They store the state and the values returned by mappingFunction, respectively.
The updateRecordWithData method is shown below:
def updateRecordWithData[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
    prevRecord: Option[MapWithStateRDDRecord[K, S, E]],
    dataIterator: Iterator[(K, V)],
    mappingFunction: (Time, K, Option[V], State[S]) => Option[E],
    batchTime: Time,
    timeoutThresholdTime: Option[Long],
    removeTimedoutData: Boolean
  ): MapWithStateRDDRecord[K, S, E] = {
  // Create a new state map by cloning the previous one (if it exists) or by creating an empty one
  val newStateMap = prevRecord.map { _.stateMap.copy() }. getOrElse { new EmptyStateMap[K, S]() }

  val mappedData = new ArrayBuffer[E]
  val wrappedState = new StateImpl[S]()

  // Call the mapping function on each record in the data iterator, and accordingly
  // update the states touched, and collect the data returned by the mapping function
  dataIterator.foreach { case (key, value) =>
    wrappedState.wrap(newStateMap.get(key))
    val returned = mappingFunction(batchTime, key, Some(value), wrappedState)
    if (wrappedState.isRemoved) {
      newStateMap.remove(key)
    } else if (wrappedState.isUpdated
        || (wrappedState.exists && timeoutThresholdTime.isDefined)) {
      newStateMap.put(key, wrappedState.get(), batchTime.milliseconds)
    }
    mappedData ++= returned
  }

  // Get the timed out state records, call the mapping function on each and collect the
  // data returned
  if (removeTimedoutData && timeoutThresholdTime.isDefined) {
    newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) =>
      wrappedState.wrapTimingOutState(state)
      val returned = mappingFunction(batchTime, key, None, wrappedState)
      mappedData ++= returned
      newStateMap.remove(key)
    }
  }

  MapWithStateRDDRecord(newStateMap, mappedData)
}
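updateRecordWithData maintains the state values and returns a MapWithStateRDDRecord. Note that the mapping function is invoked only for the keys present in the current batch's data, plus any keys whose state has timed out.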
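The same bookkeeping can be walked through on plain collections. The sketch below is a simplified imitation, not Spark's code: MiniState is a hypothetical stand-in for State[S], the state map is an immutable Map of key -> (value, last update time in ms), the batch time argument of the mapping function is dropped, and timed-out keys are always removed (there is no doFullScan switch).

```scala
import scala.collection.mutable

// Hypothetical, simplified stand-in for Spark's State[S]; not the real class.
class MiniState[S](private var value: Option[S]) {
  var updated = false
  var removed = false
  def getOption: Option[S] = value
  def update(s: S): Unit = { value = Some(s); updated = true }
  def remove(): Unit = { value = None; removed = true }
}

object UpdateRecordSketch {
  type StateMap = Map[String, (Int, Long)]

  def updateRecord(
      prev: StateMap,
      data: Seq[(String, Int)],
      mappingFunc: (String, Option[Int], MiniState[Int]) => Option[(String, Int)],
      batchTime: Long,
      timeoutThreshold: Option[Long]): (StateMap, Seq[(String, Int)]) = {
    var newState = prev // immutable map, so reassigning the var plays the role of copy()
    val mapped = mutable.ArrayBuffer.empty[(String, Int)]

    // Only keys present in this batch's data are visited.
    data.foreach { case (key, value) =>
      val wrapped = new MiniState[Int](newState.get(key).map(_._1))
      val returned = mappingFunc(key, Some(value), wrapped)
      if (wrapped.removed) {
        newState = newState - key
      } else if (wrapped.updated || (wrapped.getOption.isDefined && timeoutThreshold.isDefined)) {
        newState = newState.updated(key, (wrapped.getOption.get, batchTime))
      }
      mapped ++= returned
    }

    // Keys last updated before the threshold get one final call with no value,
    // and are then dropped from the state.
    timeoutThreshold.foreach { threshold =>
      newState.filter { case (_, (_, t)) => t < threshold }.foreach { case (key, (v, _)) =>
        mapped ++= mappingFunc(key, None, new MiniState[Int](Some(v)))
        newState = newState - key
      }
    }
    (newState, mapped.toSeq)
  }

  val mappingFunc = (word: String, one: Option[Int], state: MiniState[Int]) => {
    val sum = one.getOrElse(0) + state.getOption.getOrElse(0)
    if (one.isDefined) state.update(sum) // skip the update when the state is timing out
    Option((word, sum))
  }

  val (state1, out1) =
    updateRecord(Map.empty, Seq(("a", 1), ("b", 1), ("a", 1)), mappingFunc, 5000L, None)
  val (state2, out2) =
    updateRecord(state1, Seq(("a", 1)), mappingFunc, 10000L, Some(6000L))
}
```

In the second batch only "a" arrives, so the mapping function runs once for "a" and once more for "b", whose state was last updated at 5000 ms and therefore falls below the 6000 ms threshold. The emitted data is (a, 3) followed by (b, 1), and "b" is gone from the resulting state.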
[Figure: flowchart of the mapWithState computation process]
Notes:
1. DT Big Data Dream Factory (DT大數據夢工廠) WeChat official account: DT_Spark
2. IMF 8 p.m. big data hands-on sessions, YY live channel: 68917580
3. Sina Weibo: http://www.weibo.com/ilovepains