Spark Task Internals and Source Code Analysis

① Task execution flow diagram


② Task source code analysis

Executor.scala

  /**
    * This is where a task actually runs on the executor.
    */
  class TaskRunner(
      execBackend: ExecutorBackend,
      val taskId: Long,
      val attemptNumber: Int,
      taskName: String,
      serializedTask: ByteBuffer)
    extends Runnable {
          ...
    override def run(): Unit = {
      ...

      try {
        // Deserialize the task payload (files, jars, properties, task bytes)
        val (taskFiles, taskJars, taskProps, taskBytes) =
          Task.deserializeWithDependencies(serializedTask)

        // Must be set before updateDependencies() is called, in case fetching dependencies
        // requires access to properties contained within (e.g. for access control).
        Executor.taskDeserializationProps.set(taskProps)

        // Copy the dependent files and jars into the directory the task reads them from
        updateDependencies(taskFiles, taskJars)

        // Deserialize the Task object itself
        task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
        task.localProperties = taskProps
        task.setTaskMemoryManager(taskMemoryManager)

        // If this task has been killed before we deserialized it, let's quit now. Otherwise,
        // continue executing the task.
        if (killed) {
          // Throw an exception rather than returning, because returning within a try{} block
          // causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl
          // exception will be caught by the catch block, leading to an incorrect ExceptionFailure
          // for the task.
          throw new TaskKilledException
        }

        logDebug("Task " + taskId + "'s epoch is " + task.epoch)
        env.mapOutputTracker.updateEpoch(task.epoch)

        // Run the actual task and measure its runtime.
        // Task start timestamp
        taskStart = System.currentTimeMillis()
        taskStartCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
          threadMXBean.getCurrentThreadCpuTime
        } else 0L
        var threwException = true

        val value = try {

          // For a ShuffleMapTask, res is a MapStatus.
          // If the next stage runs another ShuffleMapTask, it will ask the MapOutputTracker
          // for this ShuffleMapTask's output locations; a ResultTask does the same.
          val res = task.run(
            taskAttemptId = taskId,
            attemptNumber = attemptNumber,
            metricsSystem = env.metricsSystem)
          threwException = false
          res
        } finally {
          val releasedLocks = env.blockManager.releaseAllLocksForTask(taskId)
          val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()

          if (freedMemory > 0 && !threwException) {
            val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId"
            if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) {
              throw new SparkException(errMsg)
            } else {
              logWarning(errMsg)
            }
          }

          if (releasedLocks.nonEmpty && !threwException) {
            val errMsg =
              s"${releasedLocks.size} block locks were not released by TID = $taskId:\n" +
                releasedLocks.mkString("[", ", ", "]")
            if (conf.getBoolean("spark.storage.exceptionOnPinLeak", false)) {
              throw new SparkException(errMsg)
            } else {
              logWarning(errMsg)
            }
          }
        }

        // Task finish timestamp
        val taskFinish = System.currentTimeMillis()
        val taskFinishCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
          threadMXBean.getCurrentThreadCpuTime
        } else 0L

        // If the task has been killed, let's fail it.
        if (task.killed) {
          throw new TaskKilledException
        }

        // Serialize and wrap the result (e.g. the MapStatus), since it must be sent to the driver
        val resultSer = env.serializer.newInstance()
        val beforeSerialization = System.currentTimeMillis()
        val valueBytes = resultSer.serialize(value)
        val afterSerialization = System.currentTimeMillis()

        // Deserialization happens in two parts: first, we deserialize a Task object, which
        // includes the Partition. Second, Task.run() deserializes the RDD and function to be run.
        // Compute the task's statistics: deserialization time, JVM GC time,
        // and result-serialization time. All of these show up in the Spark UI.
        task.metrics.setExecutorDeserializeTime(
          (taskStart - deserializeStartTime) + task.executorDeserializeTime)
        task.metrics.setExecutorDeserializeCpuTime(
          (taskStartCpu - deserializeStartCpuTime) + task.executorDeserializeCpuTime)
        // We need to subtract Task.run()'s deserialization time to avoid double-counting
        task.metrics.setExecutorRunTime((taskFinish - taskStart) - task.executorDeserializeTime)
        task.metrics.setExecutorCpuTime(
          (taskFinishCpu - taskStartCpu) - task.executorDeserializeCpuTime)
        task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime)
        task.metrics.setResultSerializationTime(afterSerialization - beforeSerialization)

        // Note: accumulator updates must be collected after TaskMetrics is updated
        val accumUpdates = task.collectAccumulatorUpdates()
        // TODO: do not serialize value twice
        val directResult = new DirectTaskResult(valueBytes, accumUpdates)
        val serializedDirectResult = ser.serialize(directResult)
        val resultSize = serializedDirectResult.limit

        // directSend = sending directly back to the driver
        val serializedResult: ByteBuffer = {
          if (maxResultSize > 0 && resultSize > maxResultSize) {
            logWarning(s"Finished $taskName (TID $taskId). Result is larger than maxResultSize " +
              s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " +
              s"dropping it.")
            ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize))
          } else if (resultSize > maxDirectResultSize) {
            val blockId = TaskResultBlockId(taskId)
            env.blockManager.putBytes(
              blockId,
              new ChunkedByteBuffer(serializedDirectResult.duplicate()),
              StorageLevel.MEMORY_AND_DISK_SER)
            logInfo(
              s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)")
            ser.serialize(new IndirectTaskResult[Any](blockId, resultSize))
          } else {
            logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver")
            serializedDirectResult
          }
        }

        // Hand the result to the executor's CoarseGrainedExecutorBackend via statusUpdate()
        execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
      ...
  }
}
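The three-way branch at the end of run() is worth isolating. Below is a minimal, self-contained sketch of that routing decision; the default thresholds are assumptions modelled on Spark 2.x's spark.driver.maxResultSize (1 GiB) and spark.task.maxDirectResultSize (1 MiB), not values taken from the excerpt above.

  // Hedged sketch of how a finished task's result travels back to the driver
  sealed trait ResultRoute
  case object DropTooLarge    extends ResultRoute // IndirectTaskResult carrying only the size
  case object ViaBlockManager extends ResultRoute // IndirectTaskResult + BlockManager.putBytes
  case object DirectToDriver  extends ResultRoute // DirectTaskResult inline in the status-update RPC

  def routeResult(resultSize: Long,
                  maxResultSize: Long = 1L << 30,      // assumed default, 1 GiB
                  maxDirectResultSize: Long = 1L << 20 // assumed default, 1 MiB
                 ): ResultRoute =
    if (maxResultSize > 0 && resultSize > maxResultSize) DropTooLarge
    else if (resultSize > maxDirectResultSize) ViaBlockManager
    else DirectToDriver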
private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) {
    // Hadoop configuration, created lazily
    lazy val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)

    // Guard against concurrent access from multiple task threads:
    // tasks run concurrently inside the executor process, and the code below
    // touches shared state such as currentFiles and currentJars
    synchronized {
      // Fetch missing dependencies: iterate over the files to pull
      for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
        logInfo("Fetching " + name + " with timestamp " + timestamp)
        // Fetch file with useCache mode, close cache for local mode.
        Utils.fetchFile(name, new File(SparkFiles.getRootDirectory()), conf,
          env.securityManager, hadoopConf, timestamp, useCache = !isLocal)
        currentFiles(name) = timestamp
      }

      // Then iterate over the jars to pull
      for ((name, timestamp) <- newJars) {
        val localName = name.split("/").last
        val currentTimeStamp = currentJars.get(name)
          .orElse(currentJars.get(localName))
          .getOrElse(-1L)

        // Fetch only if the locally cached jar is older than the requested timestamp
        if (currentTimeStamp < timestamp) {
          logInfo("Fetching " + name + " with timestamp " + timestamp)
          // Fetch file with useCache mode, close cache for local mode.
          Utils.fetchFile(name, new File(SparkFiles.getRootDirectory()), conf,
            env.securityManager, hadoopConf, timestamp, useCache = !isLocal)
          currentJars(name) = timestamp
          // Add it to our class loader
          val url = new File(SparkFiles.getRootDirectory(), localName).toURI.toURL
          if (!urlClassLoader.getURLs().contains(url)) {
            logInfo("Adding " + url + " to class loader")
            urlClassLoader.addURL(url)
          }
        }
      }
    }
}
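For context, here is how the (name, timestamp) maps consumed above get populated from user code. A small, self-contained sketch; the file path is hypothetical and must exist on the driver machine.

  import org.apache.spark.{SparkConf, SparkContext, SparkFiles}

  object DependencyDemo {
    def main(args: Array[String]): Unit = {
      val sc = new SparkContext(
        new SparkConf().setAppName("deps-demo").setMaster("local[2]"))

      // Driver side: registers a (name, timestamp) entry shipped with every
      // task description and mirrored by updateDependencies() above
      sc.addFile("/tmp/lookup.csv") // hypothetical path

      val sizes = sc.parallelize(1 to 4).map { _ =>
        // Executor side: resolves to the local copy that Utils.fetchFile
        // placed under SparkFiles.getRootDirectory()
        new java.io.File(SparkFiles.get("lookup.csv")).length
      }.collect()

      println(sizes.mkString(","))
      sc.stop()
    }
  }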

Task.scala

  final def run(
      taskAttemptId: Long,
      attemptNumber: Int,
      metricsSystem: MetricsSystem): T = {

    SparkEnv.get.blockManager.registerTask(taskAttemptId)

    // The task execution context: holds task-wide data such as the attempt
    // number, the owning stage, and the RDD partition being processed
    context = new TaskContextImpl(
      stageId,
      partitionId,
      taskAttemptId,
      attemptNumber,
      taskMemoryManager,
      localProperties,
      metricsSystem,
      metrics)
    TaskContext.setTaskContext(context)
    taskThread = Thread.currentThread()

    if (_killed) {
      kill(interruptThread = false)
    }

    new CallerContext("TASK", appId, appAttemptId, jobId, Option(stageId), Option(stageAttemptId),
      Option(taskAttemptId), Option(attemptNumber)).setCurrentContext()

    try {
      // Invoke the abstract method; the subclass provides the actual work
      runTask(context)
    } catch {
      case e: Throwable =>
        // Catch all errors; run task failure callbacks, and rethrow the exception.
        try {
          context.markTaskFailed(e)
        } catch {
          case t: Throwable =>
            e.addSuppressed(t)
        }
        throw e
    } finally {
      // Call the task completion callbacks.
      context.markTaskCompleted()
      try {
        Utils.tryLogNonFatalError {
          // Release memory used by this thread for unrolling blocks
          SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP)
          SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.OFF_HEAP)
          // Notify any tasks waiting for execution memory to be freed to wake up and try to
          // acquire memory again. This makes impossible the scenario where a task sleeps forever
          // because there are no other tasks left to notify it. Since this is safe to do but may
          // not be strictly necessary, we should revisit whether we can remove this in the future.
          val memoryManager = SparkEnv.get.memoryManager
          memoryManager.synchronized { memoryManager.notifyAll() }
        }
      } finally {
        TaskContext.unset()
      }
    }
  }
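The TaskContextImpl installed here is exactly what user code later retrieves with TaskContext.get(), and the callbacks fired by markTaskFailed()/markTaskCompleted() are the ones registered through it. A small sketch (the helper name is ours, not Spark's):

  import org.apache.spark.TaskContext
  import org.apache.spark.rdd.RDD

  // Hypothetical helper: tag each element with the task's stage and partition
  def tagWithTaskInfo(rdd: RDD[String]): RDD[(Int, Int, String)] =
    rdd.mapPartitions { iter =>
      val ctx = TaskContext.get() // the TaskContextImpl set in Task.run()
      // Fired by markTaskCompleted() in the finally block above
      ctx.addTaskCompletionListener { _: TaskContext => () }
      iter.map(x => (ctx.stageId(), ctx.partitionId(), x))
    }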
  /**
    * Task is only a template (abstract) class: it packages data and operations
    * common to all tasks. The concrete logic lives entirely in the subclasses,
    * ShuffleMapTask and ResultTask.
    */
  def runTask(context: TaskContext): T

ShuffleMapTask.scala

// A ShuffleMapTask splits the elements of an RDD into multiple buckets,
// using the partitioner specified by the ShuffleDependency (HashPartitioner by default)
private[spark] class ShuffleMapTask(
   ...
  // Unlike ResultTask, ShuffleMapTask's runTask returns a MapStatus
  override def runTask(context: TaskContext): MapStatus = {
    // Deserialize the RDD using the broadcast variable.
    val threadMXBean = ManagementFactory.getThreadMXBean
    val deserializeStartTime = System.currentTimeMillis()
    val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
      threadMXBean.getCurrentThreadCpuTime
    } else 0L

    // Deserialize the data this task is to process
    /*
       Question: many tasks run concurrently in an executor and the data may not
              all live on one machine, yet every task in a stage processes the
              same RDD. How does a task obtain exactly what it needs to run?
       Answer: through a broadcast variable (taskBinary) holding the serialized
              RDD and its ShuffleDependency
     */
    val ser = SparkEnv.get.closureSerializer.newInstance()
    val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
      ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
    _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime
    _executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
      threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
    } else 0L

    var writer: ShuffleWriter[Any, Any] = null
    try {
      // Get the ShuffleManager
      val manager = SparkEnv.get.shuffleManager
      // ... and from it a ShuffleWriter for this partition
      writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)

      // First, call the RDD's iterator() method, passing in the partition this
      // task is to process; the core computation happens inside iterator().
      // The processed partition data then flows through the ShuffleWriter, which
      // routes each record to its bucket via the partitioner (HashPartitioner by default)
      writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])

      // Return the MapStatus, which records where this ShuffleMapTask's output
      // is stored, i.e. the relevant BlockManager information
      writer.stop(success = true).get
    } catch {
      case e: Exception =>
        try {
          if (writer != null) {
            writer.stop(success = false)
          }
        } catch {
          case e: Exception =>
            log.debug("Could not stop writer", e)
        }
        throw e
    }
  }
  ...
}
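The bucketing described in the header comment is easy to reproduce. A self-contained sketch of the default HashPartitioner's placement logic (modelled on HashPartitioner and Utils.nonNegativeMod, not copied from them):

  // Each key deterministically maps to one reduce-side bucket, which is what
  // makes every map task's shuffle output line up with the same reducers
  def nonNegativeMod(x: Int, mod: Int): Int = {
    val raw = x % mod
    if (raw < 0) raw + mod else raw
  }

  def bucketFor(key: Any, numPartitions: Int): Int = key match {
    case null => 0 // HashPartitioner routes null keys to partition 0
    case k    => nonNegativeMod(k.hashCode, numPartitions)
  }

  // e.g. bucketFor("spark", 4) picks the same bucket on every map task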

ResultTask.scala

  override def runTask(context: TaskContext): U = {
    // Deserialize the RDD and the func using the broadcast variables.
    val threadMXBean = ManagementFactory.getThreadMXBean
    val deserializeStartTime = System.currentTimeMillis()
    val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
      threadMXBean.getCurrentThreadCpuTime
    } else 0L
    val ser = SparkEnv.get.closureSerializer.newInstance()
    val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](
      ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
    _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime
    _executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
      threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
    } else 0L
    // Apply the user-defined function to the partition's iterator
    func(context, rdd.iterator(partition, context))
  }
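What func is depends on the action that triggered the job; the driver serializes it into taskBinary together with the RDD. As a hedged sketch modelled on collect(): the per-partition function ignores the TaskContext and simply drains the iterator.

  import org.apache.spark.TaskContext

  // Hedged sketch: for collect() over an RDD[String], U is Array[String]
  val collectFunc: (TaskContext, Iterator[String]) => Array[String] =
    (_, iter) => iter.toArray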

RDD.scala

  final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
    if (storageLevel != StorageLevel.NONE) {
      getOrCompute(split, context)
    } else {
      // Compute the RDD partition (or read it back from a checkpoint)
      computeOrReadCheckpoint(split, context)
    }
  }
  private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] =
  {
    if (isCheckpointedAndMaterialized) {
      firstParent[T].iterator(split, context)
    } else {
      compute(split, context)
    }
  }
  @DeveloperApi
  def compute(split: Partition, context: TaskContext): Iterator[T]
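Every concrete RDD supplies compute() (and getPartitions()); iterator() merely wraps them with the cache and checkpoint checks shown above. A minimal, hypothetical subclass to make the contract concrete:

  import org.apache.spark.{Partition, SparkContext, TaskContext}
  import org.apache.spark.rdd.RDD

  // Hypothetical RDD yielding 0 until n, striped across `slices` partitions
  class StripedRangeRDD(sc: SparkContext, n: Int, slices: Int)
    extends RDD[Int](sc, Nil) { // Nil: no parent dependencies

    override protected def getPartitions: Array[Partition] =
      (0 until slices).map { i =>
        new Partition { override def index: Int = i }
      }.toArray

    // compute() is the only place data is actually produced; persisting this
    // RDD would route iterator() through getOrCompute() instead
    override def compute(split: Partition, context: TaskContext): Iterator[Int] =
      (split.index until n by slices).iterator
  }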

MapPartitionsRDD.scala

  /**
    * Applies the user-defined operator (function) to one partition of the parent RDD.
    *
    * f: Spark wraps the user's operator with some extra bookkeeping, but
    * conceptually this is "run the user-defined operator over the partition".
    *
    * Returns the new RDD's partition data (a transformation is exactly that:
    * producing the data of the corresponding partition of the new RDD).
    */
  override def compute(split: Partition, context: TaskContext): Iterator[U] =
    f(context, split.index, firstParent[T].iterator(split, context))

CoarseGrainedExecutorBackend.scala

  override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) {
    // Send a StatusUpdate message to the driver's scheduler backend endpoint
    val msg = StatusUpdate(executorId, taskId, state, data)
    driver match {
      case Some(driverRef) => driverRef.send(msg)
      case None => logWarning(s"Drop $msg because has not yet connected to driver")
    }
  }
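The message itself is a small case class defined alongside the other CoarseGrainedClusterMessages; roughly (hedged sketch; the ByteBuffer is wrapped in a SerializableBuffer so it survives Java serialization):

  // Hedged sketch of the message shape (Spark 2.x)
  case class StatusUpdate(
      executorId: String,
      taskId: Long,
      state: TaskState,
      data: SerializableBuffer)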

CoarseGrainedSchedulerBackend.scala

override def receive: PartialFunction[Any, Unit] = {

      // This is Spark 2.x; the handling here differs somewhat from Spark 1.x,
      // which did this accounting per TaskSet rather than per executor

      // Handle the event reported when a task changes state
      case StatusUpdate(executorId, taskId, state, data) =>
        // A real Spark job can fail for all kinds of reasons (task lost, etc.);
        // scheduler.statusUpdate() handles those cases, e.g. removing the
        // executor and putting the task on the failed queue
        scheduler.statusUpdate(taskId, state, data.value)

        // If the task reached a terminal state
        if (TaskState.isFinished(state)) {

          // Look up this executor's bookkeeping data and free the task's slot
          executorDataMap.get(executorId) match {
            case Some(executorInfo) =>
              // Release the CPU cores held by the task
              executorInfo.freeCores += scheduler.CPUS_PER_TASK
              // If there are still tasks waiting to run, keep scheduling them
              makeOffers(executorId)
            case None =>
              // Ignoring the update since we don't know about the executor.
              logWarning(s"Ignored task status update ($taskId state $state) " +
                s"from unknown executor with ID $executorId")
          }
        }
}
private def makeOffers(executorId: String) {
      // Filter out executors under killing
      if (executorIsAlive(executorId)) {
        val executorData = executorDataMap(executorId)
        val workOffers = IndexedSeq(
          new WorkerOffer(executorId, executorData.executorHost, executorData.freeCores))
        launchTasks(scheduler.resourceOffers(workOffers))
      }
    }
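The offer carries just enough for the TaskScheduler to match pending tasks against free slots; its shape is roughly (hedged sketch):

  // Hedged sketch: one free-resource advertisement per executor
  case class WorkerOffer(executorId: String, host: String, cores: Int)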