Spark — ShuffleReader過程

Shuffle Reader

  在之前的博客中,分析了shuffle map端的操作,map最終會將輸出文件信息封裝爲一個MapStatus發送給Driver,然後ResultTask或ShuffleMapTask在拉取數據的時候,會先去Driver上拉取自己要讀取數據的信息,比如在哪個節點上,以及在文件中的位置。下面我們來分析一下ShuffleReader,首先Map操作結束之後產生的RDD是ShuffledRDD,它會調用ShuffleManager的getReader()方法,這個方法裏面傳入了上一個stage的信息,拉取文件信息的offset,接着調用它的read()方法:

ShuffledRDD的compute()和BlockStoreShuffleReader的read()方法
 override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
	 // ResultTask或ShuffleMapTask,在生成ShuffledRDD並處理的時候
    // 會調用它的compute方法,來計算當前這個RDD的partition的數據
    val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
    // 這裏調用ShuffleManager的getReader的read()方法,拉取ResultTask或ShuffleMapTask所需的數據
    SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
      .read()
      .asInstanceOf[Iterator[(K, C)]]
}

override def read(): Iterator[Product2[K, C]] = {
    // BlockStoreReader實例化的時候,傳入的參數會獲取MapOutputTracker對象,
    // 調用其getMapSizesByExecutorId方法,創建一個Iterator,用於遍歷待獲取數據的位置信息。
    // 注意傳入的參數,shuffleId,代表上一個stage;
    // startPartition:是當前需要的數據在輸出文件中的起始offset,endPartition:是結束offset
    // 通過這兩個限制從MapOutputTracker上拉取所需信息在節點上的位置信息
    // 在實例化ShuffleBlockFetcherIterator的時候,會調用它的initialize()方法,
    // 在這個方法裏面,會根據拉取到的文件位置信息去對應的worker節點的BlockManager上拉取數據。
    val blockFetcherItr = new ShuffleBlockFetcherIterator(
      context,
      blockManager.shuffleClient,
      blockManager,
      // 獲取數據的位置信息
      mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
      // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
      SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)

    // 創建數據輸入流讀取數據,以及是否需要解壓等
    val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
      blockManager.wrapForCompression(blockId, inputStream)
    }
    // 創建序列化實例
    val ser = Serializer.getSerializer(dep.serializer)
    val serializerInstance = ser.newInstance()

    // Create a key/value iterator for each stream
    // 將讀取到的數據進行反序列化操作
    val recordIter = wrappedStreams.flatMap { wrappedStream =>
      // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
      // NextIterator. The NextIterator makes sure that close() is called on the
      // underlying InputStream when all records have been read.
      serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
    }

    // Update the context task metrics for each record read.
    // 下面就是對數據的一些操作,比如是否需要聚合,排序等等
    val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
    val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
      recordIter.map(record => {
        readMetrics.incRecordsRead(1)
        record
      }),
      context.taskMetrics().updateShuffleReadMetrics())

    // An interruptible iterator must be used here in order to support task cancellation
    val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)

    val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
      if (dep.mapSideCombine) {
        // We are reading values that are already combined
        val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
        dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
      } else {
        // We don't know the value type, but also don't care -- the dependency *should*
        // have made sure its compatible w/ this aggregator, which will convert the value
        // type to the combined type C
        val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
        dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
      }
    } else {
      require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
      interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
    }

    // Sort the output if there is a sort ordering defined.
    dep.keyOrdering match {
      case Some(keyOrd: Ordering[K]) =>
        // Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled,
        // the ExternalSorter won't spill to disk.
        val sorter =
          new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = Some(ser))
        sorter.insertAll(aggregatedIter)
        context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
        context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
        context.internalMetricsToAccumulators(
          InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes)
        CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
      case None =>
        aggregatedIter
    }
  }

  這裏最重要的在實例化ShuffleBlockFetcherIterator()的時候,就會去遠程讀取數據,這裏面有兩個重要的方法,一個是獲取要拉取文件的信息getMapSizesByExecutorId(),還有一個是ShuffleBlockFetcherIterator在實例化的時候調用的initialize()方法,下面我們先分析如何拉取當前ResultTask(或ShuffleMapTask)所需信息的位置:

MapOutputTracker的getMapSizesByExecutorId()方法

  首先我們看一下源碼:

 def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
      : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
    logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
    // 獲取數據的位置信息
    val statuses = getStatuses(shuffleId)
    // Synchronize on the returned array because, on the driver, it gets mutated in place
    statuses.synchronized {
      // 將獲取到的數據存儲到BlockManager上。
      return MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses)
    }
  }

  其實這個方法裏面封裝了兩個兩個子方法,一個是獲取數據位置信息的getStatus;還有一個就是將獲取到的信息提取出來放入隊列中。
下面我們先看getStatus()

getStatuses
private def getStatuses(shuffleId: Int): Array[MapStatus] = {
    // 獲取shuffleId在輸出文件中的每個partition寫入位置offset,
    // 看一下當前緩存是否有之前拉取的數據
    val statuses = mapStatuses.get(shuffleId).orNull
    // 如果不存在
    if (statuses == null) {
      logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
      val startTime = System.currentTimeMillis
      var fetchedStatuses: Array[MapStatus] = null
      // 有可能其他的ResultTask在拉取這個shuffleId的數據,等待對方拉取完成
      fetching.synchronized {
        // Someone else is fetching it; wait for them to be done
        while (fetching.contains(shuffleId)) {
          try {
            // 等待喚醒
            fetching.wait()
          } catch {
            case e: InterruptedException =>
          }
        }

        // Either while we waited the fetch happened successfully, or
        // someone fetched it in between the get and the fetching.synchronized.
        // 再次獲取一遍
        fetchedStatuses = mapStatuses.get(shuffleId).orNull
        if (fetchedStatuses == null) {
          // We have to do the fetch, get others to wait for us.
          // 將其加入等待隊列,下一次優先進行數據拉取
          fetching += shuffleId
        }
      }

      // 獲得當前拉取數據的權限
      if (fetchedStatuses == null) {
        // We won the race to fetch the statuses; do so
        logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint)
        // This try-finally prevents hangs due to timeouts:
        try {
          // 發送GetMapOutputStatuses消息,從MapOutputTracker上拉取數據
          val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId))
          // 將獲取的數據反序列化
          fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
          logInfo("Got the output locations")
          // 將拉取到的數據
          mapStatuses.put(shuffleId, fetchedStatuses)
        } finally {
          fetching.synchronized {
            // 清除當前等待的shuffleId
            fetching -= shuffleId
            // 喚醒其他線程
            fetching.notifyAll()
          }
        }
      }
      logDebug(s"Fetching map output statuses for shuffle $shuffleId took " +
        s"${System.currentTimeMillis - startTime} ms")

      if (fetchedStatuses != null) {
        return fetchedStatuses
      } else {
        logError("Missing all output locations for shuffle " + shuffleId)
        throw new MetadataFetchFailedException(
          shuffleId, -1, "Missing all output locations for shuffle " + shuffleId)
      }
    } else {
      // 假如緩存有之前拉取的數據,那麼直接返回
      return statuses
    }
  }

  首先看一下當前緩存是否已經包含這個shuffleId輸出文件信息,假如包含,那麼就這就返回即可。假設沒有,如果fetching等待隊列中包含當前需要拉取的shuffleId,先阻塞在這邊等待其他ResultTask(或ShuffleMapTask)獲取完成;被喚醒以後接着獲取一次status,假設還沒有獲取到,那麼就開始獲取。開啓拉取數據信息,使用的是askTracker()方法,參數是GetMapOutputStatus信息,它向Driver的MapOutputTracker發送這條信息,去獲取當前這個ShuffleId的輸出文件信息,Driver上的MapOutputTracker接收到這條信息後,就會獲取當前這shuffleId的相關信息,然後在將獲取到的信息發送給當前這個ResultTask(或ShuffleMapTask)。這裏fetchedStatuses就是Driver端MapOutputTracker發送過來的待獲取數據的位置信息。然後將數據反序列化存入map緩存中;接着在喚醒其他等待線程。
  在獲取到需要拉取數據的位置信息之後,就調用convertMapStatuses()解析剛剛獲取到的位置信息,將要拉取的位置信息提取出來,放入隊列中,並返回。
  上面這個就獲取到了需要拉取數據的位置信息,那麼下一步就是去拉取數據,拉取數據的過程就在實例化ShuffleBlockFetcherIterator的時候,調用的initialize()方法中。

ShuffleBlockFetcherIterator的初始化方法initialize()
/**
    *   將這個方法作爲入口,開始拉取ResultTask對應的多份數據
    */
  private[this] def initialize(): Unit = {
    // Add a task completion callback (called in both success case and failure case) to cleanup.
    context.addTaskCompletionListener(_ => cleanup())

    // Split local and remote blocks.
    // 切分本地和遠程Block
    val remoteRequests = splitLocalRemoteBlocks()
    // Add the remote requests into our queue in a random order
    // 切分完Block之後,進行隨機排序操作
    fetchRequests ++= Utils.randomize(remoteRequests)

    // Send out initial requests for blocks, up to our maxBytesInFlight
    // 循環往復拉取數據,只要發現數據還沒有拉取完,就發送請求到遠程拉取數據
    // 這裏有一個參數比較重要,就是maxBytesInFlight,代表ResultTask最多能拉取多少數據
    // 到本地,就要開始進行自定義的reduce算子的處理
    fetchUpToMaxBytes()

    val numFetches = remoteRequests.size - fetchRequests.size
    logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime))

    // Get Local Blocks
    // 獲取本地的數據
    fetchLocalBlocks()
    logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime))
  }

  首先對將要拉取的數據信息進行區分,切分爲本地和遠程拉取,首先拉取遠程worker節點上的數據,fetchUpToMaxBytes(),它會不斷的拉取數據,直到數據拉取完或者當前拉取的緩存以及滿了(默認48M,maxBytesInFlight),然後接着調用fetchLocalBlocks(),拉取在本地節點上的數據。這樣這個ResultTask(或ShuffleMapTask)的數據就拉取到本地緩存了。這裏我們先不對fetchUpToMaxBytes和fetchLocalBlocks做詳細的分析了。
  總結一下,這裏主要是和Driver端的MapOutputTracker進行通信,獲取當前ResultTask(或ShuffleMapTask)要拉取的文件的位置信息,從獲取到的文件位置信息裏提取出當前這個Task所需的位置信息,然後通過BlockManager去遠程或本地拉取需要的信息這裏有個參數需要注意一下(spark.reducer.maxSizeInFlight,默認48M,代表當前reduce端最大能存儲的拉取數據緩存大小)。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章