Version information
Spark version 2.3.3
JDK 1.8
IDEA 2019
MacBook Pro
Spark's shuffle connects the two adjacent stages of a job.
Apart from the first stage, whose input is read from HDFS, HBase, Hive, and so on,
every other stage has to fetch its input data through a ShuffleReader.
ShuffleReader
ShuffleReader is a trait. Note in its comment below that "mappers" is plural:
the reader pulls data from multiple locations.
/**
 * Obtained inside a reduce task to read combined records from the mappers.
 */
private[spark] trait ShuffleReader[K, C] {
  /** Read the combined key-values for this reduce task */
  def read(): Iterator[Product2[K, C]]

  /**
   * Close this reader.
   * TODO: Add this back when we make the ShuffleReader a developer API that others can implement
   * (at which point this will likely be necessary).
   */
  // def stop(): Unit
}
ShuffleReader has only one implementation class:
BlockStoreShuffleReader
Let's look at the concrete implementation of its read method.
It starts by directly constructing a ShuffleBlockFetcherIterator:
val wrappedStreams = new ShuffleBlockFetcherIterator(
  context,
  blockManager.shuffleClient,
  blockManager,
  mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
  serializerManager.wrapStream,
  // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
  SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
  SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue),
  SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
  SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
  SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true))
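As an aside, the throttling parameters read here are ordinary Spark configuration entries. A small, purely illustrative sketch of tuning them through SparkConf (the values are examples, not recommendations):
import org.apache.spark.SparkConf

// Spark 2.3 configuration keys read by the constructor call above; values are illustrative only
val fetchTuning = new SparkConf()
  .set("spark.reducer.maxSizeInFlight", "96m")              // maxBytesInFlight
  .set("spark.reducer.maxReqsInFlight", "64")               // maxReqsInFlight
  .set("spark.reducer.maxBlocksInFlightPerAddress", "32")   // maxBlocksInFlightPerAddress
  .set("spark.maxRemoteBlockSizeFetchToMem", "200m")        // maxReqSizeShuffleToMem
  .set("spark.shuffle.detectCorrupt", "true")               // detectCorrupt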
Let's look at what this class does:
/**
 * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block
 * manager. For remote blocks, it fetches them using the provided BlockTransferService.
 *
 * This creates an iterator of (BlockID, InputStream) tuples so the caller can handle blocks
 * in a pipelined fashion as they are received.
 *
 * The implementation throttles the remote fetches so they don't exceed maxBytesInFlight to avoid
 * using too much memory.
 *
 * @param context [[TaskContext]], used for metrics update
 * @param shuffleClient [[ShuffleClient]] for fetching remote blocks
 * @param blockManager [[BlockManager]] for reading local blocks
 * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]].
 *                        For each block we also require the size (in bytes as a long field) in
 *                        order to throttle the memory usage.
 * @param streamWrapper A function to wrap the returned input stream.
 * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point.
 * @param maxReqsInFlight max number of remote requests to fetch blocks at any given point.
 * @param maxBlocksInFlightPerAddress max number of shuffle blocks being fetched at any given point
 *                                    for a given remote host:port.
 * @param maxReqSizeShuffleToMem max size (in bytes) of a request that can be shuffled to memory.
 * @param detectCorrupt whether to detect any corruption in fetched blocks.
 */
private[spark]
final class ShuffleBlockFetcherIterator(
    context: TaskContext,
    shuffleClient: ShuffleClient,
    blockManager: BlockManager,
    blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
    streamWrapper: (BlockId, InputStream) => InputStream,
    maxBytesInFlight: Long,
    maxReqsInFlight: Int,
    maxBlocksInFlightPerAddress: Int,
    maxReqSizeShuffleToMem: Long,
    detectCorrupt: Boolean)
  extends Iterator[(BlockId, InputStream)] with DownloadFileManager with Logging
We won't go into this class's implementation details.
It reads the blocks written on the shuffle map side and returns (BlockId, InputStream) pairs.
Clearly there are many streams involved:
a map-side task's final output is just two files, one index file and one data file,
so each map task contributes one (BlockId, InputStream) to this reduce task,
and even a moderately large job easily runs four or five hundred map tasks.
So we start with many streams, which are then merged into a single stream:
val serializerInstance = dep.serializer.newInstance()

// Create a key/value iterator for each stream
val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) =>
  // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
  // NextIterator. The NextIterator makes sure that close() is called on the
  // underlying InputStream when all records have been read.
  serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
}
The flatMap method on Scala collections should be familiar to everyone;
flatMap on Iterator works in much the same way. Here is its source:
/** Creates a new iterator by applying a function to all values produced by this iterator
 *  and concatenating the results.
 *
 *  @param f the function to apply on each element.
 *  @return  the iterator resulting from applying the given iterator-valued function
 *           `f` to each value produced by this iterator and concatenating the results.
 *  @note    Reuse: $consumesAndProducesIterator
 */
def flatMap[B](f: A => GenTraversableOnce[B]): Iterator[B] = new AbstractIterator[B] {
  private var cur: Iterator[B] = empty
  private def nextCur() { cur = f(self.next()).toIterator }
  def hasNext: Boolean = {
    // Equivalent to cur.hasNext || self.hasNext && { nextCur(); hasNext }
    // but slightly shorter bytecode (better JVM inlining!)
    while (!cur.hasNext) {
      if (!self.hasNext) return false
      nextCur()
    }
    true
  }
  def next(): B = (if (hasNext) cur else empty).next()
}
Looking at the inner methods:
hasNext in turn calls hasNext on the sub-iterators (the comment even notes the loop is equivalent to a recursive call),
and next delegates to the current sub-iterator's next.
If we trace through the code by hand,
it's easy to see that this flatMap merges multiple streams into one:
it iterates the first stream, then the second, and so on until the last one,
so the whole thing becomes a single stream.
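A tiny standalone sketch (plain Scala, not Spark code) that mimics this: flatMap stitches several sub-iterators into one logical stream and pulls from them lazily, one after another.
// Three small "streams"; flatMap concatenates them without reading ahead
val streams = Iterator(Iterator(1, 2), Iterator(3), Iterator(4, 5))
val merged: Iterator[Int] = streams.flatMap(identity)

println(merged.toList)  // List(1, 2, 3, 4, 5): first stream, then the second, then the last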
Back in BlockStoreShuffleReader, the metrics bookkeeping comes next; all of these metrics show up in the Spark UI.
// Update the context task metrics for each record read.
val readMetrics = context.taskMetrics.createTempShuffleReadMetrics()
val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
  recordIter.map { record =>
    readMetrics.incRecordsRead(1)
    record
  },
  context.taskMetrics().mergeShuffleReadMetrics())

// An interruptible iterator must be used here in order to support task cancellation
val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
So far the merged stream has not triggered any computation, because nothing has consumed it yet (nobody has called hasNext or next on it).
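A minimal illustration of that laziness in plain Scala (the names are made up, nothing here is Spark code): wrapping an iterator with map reads nothing; elements are produced only when someone finally consumes the iterator.
// A source iterator that logs whenever an element is actually produced
val source = Iterator.tabulate(3) { i => println(s"producing $i"); i }

val wrapped = source.map(identity)  // nothing printed yet: no one has consumed it
wrapped.foreach(println)            // only now do the "producing ..." lines appear
Back in the reader, the next step decides who finally consumes this stream: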
val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
  if (dep.mapSideCombine) {
    // We are reading values that are already combined
    val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
    dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
  } else {
    // We don't know the value type, but also don't care -- the dependency *should*
    // have made sure its compatible w/ this aggregator, which will convert the value
    // type to the combined type C
    val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
    dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
  }
} else {
  require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
  interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
}
Aggregation triggers the computation:
org.apache.spark.Aggregator#combineCombinersByKey
def combineCombinersByKey(
    iter: Iterator[_ <: Product2[K, C]],
    context: TaskContext): Iterator[(K, C)] = {
  val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners)
  combiners.insertAll(iter)
  updateMetrics(context, combiners)
  combiners.iterator
}
combiners.insertAll(iter) drains the entire iterator; note that the data we fetch already all belongs to the same reduce partition.
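A minimal sketch (my own simplification, not Spark's ExternalAppendOnlyMap) of why this step forces the whole stream to be consumed: insertAll-style code walks the entire iterator before anything is returned, so every block of this partition has to be fetched first.
// Hypothetical simplification of the insert-then-iterate pattern used above
def insertAllSketch[K, C](records: Iterator[(K, C)], mergeCombiners: (C, C) => C): Iterator[(K, C)] = {
  val combiners = scala.collection.mutable.HashMap.empty[K, C]
  while (records.hasNext) {                 // drains the input completely
    val (k, c) = records.next()
    combiners(k) = combiners.get(k).map(mergeCombiners(_, c)).getOrElse(c)
  }
  combiners.iterator
}
After aggregation (or when no aggregator is defined), the reader finally checks whether a key ordering was requested: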
// Sort the output if there is a sort ordering defined.
val resultIter = dep.keyOrdering match {
  case Some(keyOrd: Ordering[K]) =>
    // Create an ExternalSorter to sort the data.
    val sorter =
      new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
    sorter.insertAll(aggregatedIter)
    context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
    context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
    context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
    // Use completion callback to stop sorter if task was finished/cancelled.
    context.addTaskCompletionListener(_ => {
      sorter.stop()
    })
    CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
  case None =>
    aggregatedIter
}
If the keys need to be sorted, computation is triggered as well; just like aggregation,
all the data has to be pulled to the local side before the requirement can be met:
sorter.insertAll(aggregatedIter)
ShuffledRDD
A ShuffleReader usually fetches its data inside an RDD's compute method:
org.apache.spark.rdd.ShuffledRDD#compute
override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
  val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
  SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
    .read()
    .asInstanceOf[Iterator[(K, C)]]
}
In production, ShuffledRDD is the one you run into most often.
With that, we have a rough picture of how a ShuffleReader fetches data.
Personally, I think one point is worth remembering:
if aggregation or key ordering is required, the data is forced to be pulled to the local side and processed there;
by contrast, in other cases the data can be fetched and iterated incrementally, a bit at a time.
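To close with a concrete (if trivial) example: reduceByKey builds a ShuffledRDD with an Aggregator and map-side combine enabled, so computing its partitions goes through the read() path above and, on the reduce side, through combineCombinersByKey. A minimal runnable demo (a local sketch, names of my own choosing):
import org.apache.spark.{SparkConf, SparkContext}

object ShuffleReadDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("shuffle-read-demo"))
    val pairs = sc.parallelize(Seq(("a", 1), ("b", 2), ("a", 3)), 2)
    // reduceByKey produces a ShuffledRDD; collect() triggers the shuffle read described above
    val summed = pairs.reduceByKey(_ + _)
    summed.collect().foreach(println)
    sc.stop()
  }
}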