Spark的Checkpoint源码和机制

Spark的Checkpoint源码和机制

1 Overview

A checkpoint creates a known good point from which the SQL Server Database Engine can start applying changes contained in the log during recovery after an unexpected shutdown or crash.

在流式计算里,需要高容错的机制来确保程序的稳定和健壮。从源码中看看,在 Spark 中,Checkpoint 到底做了什么。在源码中搜索,可以在 Streaming 包中的 Checkpoint

作为 Spark 程序的入口,我们首先关注一下 SparkContext 里关于 Checkpoint 是怎么写的。SparkContext 我们知道,定义了很多 Spark 内部的对象的引用。可以找到 Checkpoint 的文件夹路径是这么定义的。

我们从一段简单的代码开始看一下checkpoint 。spark版本2.1.1

  val sparkConf = new SparkConf().setAppName("streaming").setMaster("local[*]")
    val sparkSession = SparkSession.builder().config(sparkConf).getOrCreate()
    val sc = sparkSession.sparkContext
    val ssc = new StreamingContext(sc,Seconds(5))
    //设置检查点目录
    ssc.checkpoint("./streaming_checkpoint")

看一下ssc.checkpoint

  def checkpoint(directory: String) {
    if (directory != null) {
      val path = new Path(directory)
      val fs = path.getFileSystem(sparkContext.hadoopConfiguration)
      fs.mkdirs(path)
      val fullPath = fs.getFileStatus(path).getPath().toString
      sc.setCheckpointDir(fullPath)
      checkpointDir = fullPath
    } else {
      checkpointDir = null
    }
  }

看一下sc.setCheckpointDir

// 定义 checkpointDir
private[spark] var checkpointDir: Option[String] = None
/**

Set the directory under which RDDs are going to be checkpointed. The directory must
be a HDFS path if running on a cluster.
*/
def setCheckpointDir(directory: String) {
// If we are running on a cluster, log a warning if the directory is local.
// Otherwise, the driver may attempt to reconstruct the checkpointed RDD from
// its own local file system, which is incorrect because the checkpoint files
// are actually on the executor machines.
// 如果运行的是 cluster 模式,当设置本地文件夹的时候,会报 warning
// 道理很简单,被创建出来的文件夹路径实际上是 executor 本地的文件夹路径,不是不行,
// 只是有点不合理,Checkpoint 的东西最好还是放在分布式的文件系统中
if (!isLocal && Utils.nonLocalPaths(directory).isEmpty) {
logWarning("Spark is not running in local mode, therefore the checkpoint directory " +
s"must not be on the local filesystem. Directory ‘$directory’ " +
“appears to be on the local filesystem.)
}

checkpointDir = Option(directory).map { dir =>
// 显然文件夹名就是 UUID.randoUUID() 生成的
val path = new Path(dir, UUID.randomUUID().toString)
val fs = path.getFileSystem(hadoopConfiguration)
fs.mkdirs(path)
fs.getFileStatus(path).getPath.toString
}
}

关于 setCheckpointDir 被那些类调用了,可以看以下截图。除了常见的 StreamingContext 中需要使用(因为容错性是流式计算的基本保证),另外的就是一些需要反复迭代计算使用 RDD 的场景,包括各种机器学习算法的时候,图中可以看到像 ALS, Decision Tree 等等算法,这些算法往往需要反复使用 RDD,遇到大的数据集用 Cache 就没有什么意义了,所以一般会用 Checkpoint。

此处我只计划深挖一下 spark core 里的代码。推荐大家一个 IDEA 的功能,下图右下方可以将你搜索的关键词的代码输出到外部文件中,到时候可以打开自己看看 spark core 中关于 Checkpoint 的代码是怎么组织的。

继续找找 Checkpoint 的相关信息,可以看到 runJob 方法的最后是一个 rdd.toCheckPoint() 的使用。runJob 我们知道是触发 action 的一个方法,那么我们进入 doCheckpoint() 看看。

/**
 * Run a function on a given set of partitions in an RDD and pass the results to the given
 * handler function. This is the main entry point for all actions in Spark.
 */
def runJob[T, U: ClassTag](
    rdd: RDD[T],
    func: (TaskContext, Iterator[T]) => U,
    partitions: Seq[Int],
    resultHandler: (Int, U) => Unit): Unit = {
  if (stopped.get()) {
    throw new IllegalStateException("SparkContext has been shutdown")
  }
  val callSite = getCallSite
  val cleanedFunc = clean(func)
  logInfo("Starting job: " + callSite.shortForm)
  if (conf.getBoolean("spark.logLineage", false)) {
    logInfo("RDD's recursive dependencies:\n" + rdd.toDebugString)
  }
  dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, resultHandler, localProperties.get)
  progressBar.foreach(_.finishAll())
  rdd.doCheckpoint()
}

 

然后基本就发现了 Checkpoint 的核心方法了。而 doCheckpoint()RDD 的私有方法,所以这里基本可以回答最开始提出的问题,我们在说 Checkpoint 的时候,到底是 Checkpoint 什么。答案就是 RDD。

private[spark] def doCheckpoint(): Unit = {
  RDDOperationScope.withScope(sc, "checkpoint", allowNesting = false, ignoreParent = true) {
    // 该rdd是否已经调用doCheckpoint,如果还没有,则开始处理
    if (!doCheckpointCalled) {
      // 判断RDDCheckpointData是否已经定义了,如果已经定义了
      doCheckpointCalled = true
      if (checkpointData.isDefined) {
        // 查看是否需要把该rdd的所有依赖即血缘全部checkpoint
        if (checkpointAllMarkedAncestors) {
          // Linestage上的每一个rdd递归调用该方法
          dependencies.foreach(_.rdd.doCheckpoint())
        }
        // 调用RDDCheckpointData的checkpoint方法
        checkpointData.get.checkpoint()
      } else {
        dependencies.foreach(_.rdd.doCheckpoint())
      }
    }
  }
}

 

上面代码可以看到,需要判断一下一个变量 checkpointData 是否为空。那么它是这么被定义的。

private[spark] var checkpointData: Option[RDDCheckpointData[T]] = None

然后看看 RDDCheckPointData 是个什么样的数据结构。

/**
 * This class contains all the information related to RDD checkpointing. Each instance of this
 * class is associated with an RDD. It manages process of checkpointing of the associated RDD,
 * as well as, manages the post-checkpoint state by providing the updated partitions,
 * iterator and preferred locations of the checkpointed RDD.
 */
private[spark] abstract class RDDCheckpointData[T: ClassTag](@transient private val rdd: RDD[T])
  extends Serializable {
  import CheckpointState._
  // The checkpoint state of the associated RDD.
  protected var cpState = Initialized
  // The RDD that contains our checkpointed data
  // 显然,这个就是被 Checkpoint 的 RDD 的数据
  private var cpRDD: Option[CheckpointRDD[T]] = None
  // TODO: are we sure we need to use a global lock in the following methods?
  /**
   * Return whether the checkpoint data for this RDD is already persisted.
   */
  def isCheckpointed: Boolean = RDDCheckpointData.synchronized {
    cpState == Checkpointed
  }
  /**
   * Materialize this RDD and persist its content.
   * This is called immediately after the first action invoked on this RDD has completed.
   */
  final def checkpoint(): Unit = {
    // Guard against multiple threads checkpointing the same RDD by
    // atomically flipping the state of this RDDCheckpointData
    RDDCheckpointData.synchronized {
      if (cpState == Initialized) {
        cpState = CheckpointingInProgress
      } else {
        return
      }
    }
    val newRDD = doCheckpoint()
    // Update our state and truncate the RDD lineage
    // 可以看到 cpRDD 在此处被赋值,通过 newRDD 来生成,而生成的方法是 doCheckpointa()
    RDDCheckpointData.synchronized {
      cpRDD = Some(newRDD)
      cpState = Checkpointed
      rdd.markCheckpointed()
    }
  }
  /**
   * Materialize this RDD and persist its content.
   *
   * Subclasses should override this method to define custom checkpointing behavior.
   * @return the checkpoint RDD created in the process.
   */
   // 这个是 Checkpoint RDD 的抽象方法
  protected def doCheckpoint(): CheckpointRDD[T]
  /**
   * Return the RDD that contains our checkpointed data.
   * This is only defined if the checkpoint state is `Checkpointed`.
   */
  def checkpointRDD: Option[CheckpointRDD[T]] = RDDCheckpointData.synchronized { cpRDD }
  /**
   * Return the partitions of the resulting checkpoint RDD.
   * For tests only.
   */
  def getPartitions: Array[Partition] = RDDCheckpointData.synchronized {
    cpRDD.map(_.partitions).getOrElse { Array.empty }
  }
}

根据注释,可以知道这个类涵盖了 RDD Checkpoint 的所有信息。除了控制 Checkpoint 的过程,还会处理之后的状态变更。说到 Checkpoint 的状态变更,我们看看是如何定义的。

/**
 * Enumeration to manage state transitions of an RDD through checkpointing
 *
 * [ Initialized --{@literal >} checkpointing in progress --{@literal >} checkpointed ]
 */
private[spark] object CheckpointState extends Enumeration {
  type CheckpointState = Value
  val Initialized, CheckpointingInProgress, Checkpointed = Value
}

显然 Checkpoint 的过程分为初始化[Initialized] -> 正在 Checkpoint[CheckpointingInProgress] -> 结束 Checkpoint[Checkpointed] 三种状态。

图片.png

关于 RDDCheckpointData 有两个实现,分别分析一下。

  1. LocalRDDCheckpointData: RDD 会被保存到 Executor 本地文件系统中,以减少保存到分布式容错性文件系统的巨额开销,因此 Local 形式的 Checkpoint 是基于持久化来做的,没有写到外部分布式文件系统。
  2. ReliableRDDCheckpointData: Reliable 很好理解,就是把 RDD Checkpoint 到可依赖的文件系统,言下之意就是 Driver 重启的时候也可以从失败的时间点进行恢复,无需再走一次 RDD 的转换过程。

接着查看RDDCheckpointData的checkpoint方法,如下:

final def checkpoint(): Unit = {
  // 将checkpoint的状态从Initialized置为CheckpointingInProgress
  RDDCheckpointData.synchronized {
    if (cpState == Initialized) {
      cpState = CheckpointingInProgress
    } else {
      return
    }
  }
  // 调用子类的doCheckpoint,我们以ReliableCheckpointRDD为例,创建一个新的CheckpointRDD
  val newRDD = doCheckpoint()

  // 将checkpoint状态置为Checkpointed状态,并且改变rdd之前的依赖,设置父rdd为新创建的CheckpointRDD
  RDDCheckpointData.synchronized {
    cpRDD = Some(newRDD)
    cpState = Checkpointed
    rdd.markCheckpointed()
  }
}

1.1 LocalRDDCheckpointData

LocalRDDCheckpointData 中的核心方法 doCheckpoint()。需要保证 RDD 用了 useDisk 级别的持久化。需要运行一个 Spark 任务来重新构建这个 RDD。最终 new 一个 LocalCheckpointRDD 实例。

/**
 * Ensure the RDD is fully cached so the partitions can be recovered later.
 */
protected override def doCheckpoint(): CheckpointRDD[T] = {
  val level = rdd.getStorageLevel
// Assume storage level uses disk; otherwise memory eviction may cause data loss
assume(level.useDisk, s"Storage level $level is not appropriate for local checkpointing")

// Not all actions compute all partitions of the RDD (e.g. take). For correctness, we
// must cache any missing partitions. TODO: avoid running another job here (SPARK-8582).
val action = (tc: TaskContext, iterator: Iterator[T]) => Utils.getIteratorSize(iterator)
val missingPartitionIndices = rdd.partitions.map(_.index).filter { i =>
!SparkEnv.get.blockManager.master.contains(RDDBlockId(rdd.id, i))
}
if (missingPartitionIndices.nonEmpty) {
rdd.sparkContext.runJob(rdd, action, missingPartitionIndices)
}

new LocalCheckpointRDDT
}

1.2 ReliableRDDCheckpointData

这个是写外部文件系统的 Checkpoint 类。

/**
 * Materialize this RDD and write its content to a reliable DFS.
 * This is called immediately after the first action invoked on this RDD has completed.
 */
protected override def doCheckpoint(): CheckpointRDD[T] = {
  val newRDD = ReliableCheckpointRDD.writeRDDToCheckpointDirectory(rdd, cpDir)
// Optionally clean our checkpoint files if the reference is out of scope
if (rdd.conf.getBoolean(“spark.cleaner.referenceTracking.cleanCheckpoints”, false)) {
rdd.context.cleaner.foreach { cleaner =>
cleaner.registerRDDCheckpointDataForCleanup(newRDD, rdd.id)
}
}

logInfo(s"Done checkpointing RDD ${rdd.id} to $cpDir, new parent is RDD ${newRDD.id}")
newRDD
}

可以看到核心方法是通过 ReliableCheckpointRDD.writeRDDToCheckpointDirectory() 来写 newRDD。这个方法代码逻辑非常清晰,同样是起一个 Spark 任务把 RDD 生成之后按 Partition 来写到文件系统中。

def writeRDDToCheckpointDirectory[T: ClassTag](
    originalRDD: RDD[T],
    checkpointDir: String,
    blockSize: Int = -1): ReliableCheckpointRDD[T] = {
  val checkpointStartTimeNs = System.nanoTime()

  val sc = originalRDD.sparkContext

  // Create the output path for the checkpoint
  // 创建checkpoint输出目录
  val checkpointDirPath = new Path(checkpointDir)
  // 获取HDFS文件系统API接口
  val fs = checkpointDirPath.getFileSystem(sc.hadoopConfiguration)
  // 创建目录
  if (!fs.mkdirs(checkpointDirPath)) {
    throw new SparkException(s"Failed to create checkpoint path $checkpointDirPath")
  }

  // Save to file, and reload it as an RDD
  // 将配置文件信息广播到所有节点
  val broadcastedConf = sc.broadcast(
    new SerializableConfiguration(sc.hadoopConfiguration))
  // TODO: This is expensive because it computes the RDD again unnecessarily (SPARK-8582)
  // 重新启动一个job,将rdd的分区数据写入HDFS
  sc.runJob(originalRDD,
    writePartitionToCheckpointFile[T](checkpointDirPath.toString, broadcastedConf) _)
  // 如果rdd的partitioner不为空,则将partitioner写入checkpoint目录
  if (originalRDD.partitioner.nonEmpty) {
    writePartitionerToCheckpointDir(sc, originalRDD.partitioner.get, checkpointDirPath)
  }

  val checkpointDurationMs =
    TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - checkpointStartTimeNs)
  logInfo(s"Checkpointing took $checkpointDurationMs ms.")
  // 创建一个CheckpointRDD,该分区数目应该和原始的rdd的分区数是一样的
  val newRDD = new ReliableCheckpointRDD[T](
    sc, checkpointDirPath.toString, originalRDD.partitioner)
  if (newRDD.partitions.length != originalRDD.partitions.length) {
    throw new SparkException(
      s"Checkpoint RDD $newRDD(${newRDD.partitions.length}) has different " +
        s"number of partitions from original RDD $originalRDD(${originalRDD.partitions.length})")
  }
  newRDD
}

2 Checkpoint尝试

Spark 的 Checkpoint 机制通过上文在源码上分析了一下,那么也可以在 Local 模式下实践一下。利用 spark-shell 来简单尝试一下就好了。

scala> val data = sc.parallelize(List(1, 2, 3))
data: org.apache.spark.rdd.RDD[Int] = ParallelCollectionRDD[0] at parallelize at <console>:24
scala> sc.setCheckpointDir("/tmp")
scala> data.checkpoint
scala> data.count
res2: Long = 3

 

从以上代码示例上可以看到,首先构建一个 rdd,并且设置 Checkpoint 文件夹,因为是 Local 模式,所以可以设定本地文件夹做尝试。

# list 一下 /tmp 目录,发现 Checkpoint 的文件夹</span>/tmp ls
73d8442e-a375-401c-b1fc-84284e25b89c

<span class="hljs-comment"><span class="hljs-comment"># tree 一下 Checkpoint 文件夹看看是什么结构的,可以看到默认构建的 rdd 四个分区都被 checkpoint 了</span>/tmp tree 73d8442e-a375-401c-b1fc-84284e25b89c
73d8442e-a375-401c-b1fc-84284e25b89c
└── rdd-0
    ├── part-00000
    ├── part-00001
    ├── part-00002
    └── part-00003

1 directory, 4 files

=============================================

读取checkpoint原理分析

详细分析
Spark RDD主要由Dependency、Partition、Partitioner组成,Partition是其中之一。一份待处理的原始数据会被按照相应的逻辑(例如jdbc和hdfs的split逻辑)切分成n份,每份数据对应到RDD中的一个Partition,Partition的数量决定了task的数量,影响着程序的并行度。

**1. ** 我们从Partition入手,Partition源码如下:

/**
 * An identifier for a partition in an RDD.
 */
trait Partition extends Serializable {
  /** Get the partition's index within its parent RDD*/
  def index: Int

  // A better default implementation of HashCode
  override def hashCode(): Int = index

  override def equals(other: Any): Boolean = super.equals(other)
}

Partition的定义很简单。Partition和RDD是伴生的,所以每一种RDD都有其对应的Partition实现,所以,分析Partition主要是分析其子类。

  1. 在RDD.scala中,定义了很多方法,如:
//输入一个partition,对其代表的数据进行计算
@DeveloperApi
def compute(split: Partition, context: TaskContext): Iterator[T]
//数据如何被split的逻辑
protected def getPartitions: Array[Partition]
//这个RDD的依赖——它的父RDD
protected def getDependencies: Seq[Dependency[_]] = deps
protected def getPreferredLocations(split: Partition): Seq[String] = Nil
@transient val partitioner: Option[Partitioner] = None

其中的第二个方法,getPartitions()是数据源如何被切分的逻辑,返回值正是Partition,第一个方法compute()是消费切割后的Partition的方法,我们从getPartitions和compute方法入手。

  1. RDD 是通过 iterator 来进行计算:每当 Task 运行的时候会调用 RDD 的 Compute 方法进行计算,而 Compute 方法会调用 iterator 方法。iterator()方法在RDD.scala中的源码如下:
  final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
    if (storageLevel != StorageLevel.NONE) {
      // 如果StorageLevel不为空,表示该RDD已经持久化过了,可能是在内存,也有可能是在磁盘,
      // 如果是磁盘获取的,需要把block缓存在内存中
      getOrCompute(split, context)
    } else {
      // 进行rdd partition的计算或者根据checkpoint读取数据
      computeOrReadCheckpoint(split, context)
    }
  }
  1. 这个方法是 final 级别【不能覆写但可以被子类去使用】,先看持久化的逻辑,我们可以看getOrCompute方法,这个方法从内存或者磁盘获取,如果从磁盘获取需要将block缓存到内存:
  private[spark] def getOrCompute(partition: Partition, context: TaskContext): Iterator[T] = {
      
    val blockId = RDDBlockId(id, partition.index)  // 根据rdd id创建RDDBlockId
      
    var readCachedBlock = true                     // 是否从缓存的block读取
      
    SparkEnv.get.blockManager.getOrElseUpdate(blockId, storageLevel, elementClassTag, () => {
        
      // 如果执行到了这,说明没有获取到block,readCachedBlock设置成false,表示不能从cache中读取。
      readCachedBlock = false
      // 需要调用该函数重新计算或者从checkpoint读取
      computeOrReadCheckpoint(partition, context)
        
    }) match {
      // 获取到了结果直接返回
      case Left(blockResult) =>   // 如果从cache读取block
        if (readCachedBlock) {
          val existingMetrics = context.taskMetrics().inputMetrics
          existingMetrics.incBytesRead(blockResult.bytes)
          new InterruptibleIterator[T](context, blockResult.data.asInstanceOf[Iterator[T]]) {
            override def next(): T = {
              existingMetrics.incRecordsRead(1)
              delegate.next()
            }
          }
        } else {
          new InterruptibleIterator(context, blockResult.data.asInstanceOf[Iterator[T]])
        }
      case Right(iter) =>
        new InterruptibleIterator(context, iter.asInstanceOf[Iterator[T]])
    }
  }
  1. 其中getOrElseUpdate方法做了什么:如果指定的block存在,则直接获取,否则调用makeIterator方法去计算block,然后持久化最后返回值,代码如下:
 def getOrElseUpdate[T](
      blockId: BlockId,
      level: StorageLevel,
      classTag: ClassTag[T],
      makeIterator: () => Iterator[T]): Either[BlockResult, Iterator[T]] = {
    // get方法 尝试从本地获取数据,如果获取不到则从远端获取
    get[T](blockId)(classTag) match {
      case Some(block) =>
        return Left(block)
      case _ =>
        // Need to compute the block.
    }
    // 如果本地化和远端都没有获取到数据,则调用makeIterator计算,最后将结果写入block
    doPutIterator(blockId, makeIterator, level, classTag, keepReadLock = true) match {
      case None =>    // 表示写入成功
        val blockResult = getLocalValues(blockId).getOrElse {      // 从本地获取数据块
          releaseLock(blockId)
          throw new SparkException(s"get() failed for block $blockId even though we held a lock")
        }
        releaseLock(blockId)
        Left(blockResult)
      case Some(iter) =>                 // 如果写入失败
        // 如果put操作失败,表示可能是因为数据太大,无法写入内存,又无法被磁盘drop,因此我们需要返回这个iterator给调用者以至于他们能够做出决定这个值是什么,怎么办
       Right(iter)
    }
  }
  1. 通过get方法获取数据,先从本地获取数据,如果没有则从远端获取,代码如下:
  def get[T: ClassTag](blockId: BlockId): Option[BlockResult] = {
    // 从本地获取block
    val local = getLocalValues(blockId)
    // 如果本地获取到了则返回
    if (local.isDefined) {
      logInfo(s"Found block $blockId locally")
      return local
    }
    // 如果本地没有获取到则从远端获取
    val remote = getRemoteValues[T](blockId)
    // 如果远端获取到了则返回,没有返回None
    if (remote.isDefined) {
      logInfo(s"Found block $blockId remotely")
      return remote
    }
    None
  }
  1. 如何从本地获取block的逻辑在getLocalValues方法中,这个方法会从本地获取block,如果存在返回BlockResult,不存在返回None;如果storage level是磁盘,则还需将得到的block缓存到内存存储,方便下次读取,具体如下:
def getLocalValues(blockId: BlockId): Option[BlockResult] = {
  logDebug(s"Getting local block $blockId")
  // 调用block info manager,锁定该block,然后读取block,返回该block 元数据block info
  blockInfoManager.lockForReading(blockId) match {
    // 没有读取到则返回None
    case None =>
      logDebug(s"Block $blockId was not found")
      None
    // 读取到block元数据
    case Some(info) =>
      val level = info.level    // 获取存储级别storage level
      logDebug(s"Level for block $blockId is $level")
      val taskAttemptId = Option(TaskContext.get()).map(_.taskAttemptId())
      if (level.useMemory && memoryStore.contains(blockId)) {   // 如果使用内存,且内存memory store包含这个block id
        // 判断是不是storage level是不是反序列化的,如果是反序列化的,则调用MemoryStore的getValues方法
        // 否则调用MemoryStore的getBytes然后反序列输入流返回数据作为迭代器
        val iter: Iterator[Any] = if (level.deserialized) {
          memoryStore.getValues(blockId).get
        } else {
          serializerManager.dataDeserializeStream(
            blockId, memoryStore.getBytes(blockId).get.toInputStream())(info.classTag)
        }
        val ci = CompletionIterator[Any, Iterator[Any]](iter, {
          releaseLock(blockId, taskAttemptId)
        })
        // 构建一个BlockResult对象返回,这个对象包括数据,读取方式以及字节大小
        Some(new BlockResult(ci, DataReadMethod.Memory, info.size))
      } else if (level.useDisk && diskStore.contains(blockId)) {  // 如果使用磁盘存储,且disk store包含这个block则从磁盘获取,并且把结果放入内存
        val diskData = diskStore.getBytes(blockId)  // 调用DiskStore的getBytes方法,如果需要反序列化,则进行反序列
        val iterToReturn: Iterator[Any] = {
          if (level.deserialized) {
            val diskValues = serializerManager.dataDeserializeStream(
              blockId,
              diskData.toInputStream())(info.classTag)
            // 尝试将从磁盘读的溢写的值加载到内存,方便后续快速读取
            maybeCacheDiskValuesInMemory(info, blockId, level, diskValues)
          } else {
            // 如果不需要反序列化,首先将读取到的流加载到内存,方便后续快速读取
            val stream = maybeCacheDiskBytesInMemory(info, blockId, level, diskData)
              .map { _.toInputStream(dispose = false) }
              .getOrElse { diskData.toInputStream() }
            // 然后再返回反序列化之后的数据
            serializerManager.dataDeserializeStream(blockId, stream)(info.classTag)
          }
        }
        // 构建BlockResult返回
        val ci = CompletionIterator[Any, Iterator[Any]](iterToReturn, {
          releaseLockAndDispose(blockId, diskData, taskAttemptId)
        })
        Some(new BlockResult(ci, DataReadMethod.Disk, info.size))
      } else {
        // 处理本地读取block失败,报告driver这是一个无效的block,将会删除这个block
        handleLocalReadFailure(blockId)
      }
  }
}
  1. 远端读取,即从block所存放的其他block manager(其他节点)获取block,逻辑如下:
private def getRemoteValues[T: ClassTag](blockId: BlockId): Option[BlockResult] = {
  val ct = implicitly[ClassTag[T]]
  getRemoteBytes(blockId).map { 
      // 将远程fetch的结果进行反序列化,然后构建BlockResult返回
      data => val values =
      serializerManager.dataDeserializeStream(blockId, data.toInputStream(dispose = true))(ct)
    new BlockResult(values, DataReadMethod.Network, data.size)
  }
}

其中获取获取数据的方法getRemoteBytes,逻辑如下:

def getRemoteBytes(blockId: BlockId): Option[ChunkedByteBuffer] = {
  logDebug(s"Getting remote block $blockId")
  require(blockId != null, "BlockId is null")
  var runningFailureCount = 0
  var totalFailureCount = 0
  // 首先根据blockId获取当前block存在在哪些block manager上
  val locations = getLocations(blockId)
  // 最大允许的获取block的失败次数为该block对应的block manager数量
  val maxFetchFailures = locations.size
  var locationIterator = locations.iterator
  while (locationIterator.hasNext) { // 开始遍历block manager
    val loc = locationIterator.next()
    logDebug(s"Getting remote block $blockId from $loc")
    val data = try {
      // 通过调用BlockTransferSerivce的fetchBlockSync方法从远端获取block
      blockTransferService.fetchBlockSync(
        loc.host, loc.port, loc.executorId, blockId.toString).nioByteBuffer()
    } catch {
      case NonFatal(e) =>
        runningFailureCount += 1
        totalFailureCount += 1
        // 如果总的失败数量大于了阀值则返回None
        if (totalFailureCount >= maxFetchFailures) {
          logWarning(s"Failed to fetch block after $totalFailureCount fetch failures. " +
            s"Most recent failure cause:", e)
          return None
        }

        logWarning(s"Failed to fetch remote block $blockId " +
          s"from $loc (failed attempt $runningFailureCount)", e)

        if (runningFailureCount >= maxFailuresBeforeLocationRefresh) {
          locationIterator = getLocations(blockId).iterator
          logDebug(s"Refreshed locations from the driver " +
            s"after ${runningFailureCount} fetch failures.")
          runningFailureCount = 0
        }
        null
    }

    // 成功的话,返回ChunkedByteBuffer
    if (data != null) {
      return Some(new ChunkedByteBuffer(data))
    }
    logDebug(s"The value of block $blockId is null")
  }
  logDebug(s"Block $blockId not found")
  None
}
  1. 另一个分支checkpiont

根据上面的iterator()的另一个分支:如果block没有被持久化,即storage level为None,我们就需要进行计算或者从Checkpoint读取数据;如果已经checkpoint了,则调用ietrator去读取block数据,否则调用Parent的RDD的compute方法。

private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] =
{
    // 当前rdd是否已经checkpoint和物理化了,如果已经checkpoint,则调用第一个parent rdd的iterator方法获取
  if (isCheckpointedAndMaterialized) {
    firstParent[T].iterator(split, context)
  } else {
    //否则调用rdd的compute方法开始计算,返回一个Iterator对象
    compute(split, context)
  }
}

看一下ReliableRDDCheckpointData的compute实现方式

ReliableRDDCheckpointData
//读取与给定分区关联的检查点文件的内容。
  override def compute(split: Partition, context: TaskContext): Iterator[T] = {
    val file = new Path(checkpointPath, ReliableCheckpointRDD.checkpointFileName(split.index))
    ReliableCheckpointRDD.readCheckpointFile(file, broadcastedConf, context)
  }
  
 /**
   * Read the content of the specified checkpoint file.
   * 读取指定检查点文件的内容。
   */
  def readCheckpointFile[T](
      path: Path,
      broadcastedConf: Broadcast[SerializableConfiguration],
      context: TaskContext): Iterator[T] = {
    val env = SparkEnv.get
    val fs = path.getFileSystem(broadcastedConf.value.value)
    val bufferSize = env.conf.getInt("spark.buffer.size", 65536)
    val fileInputStream = fs.open(path, bufferSize)
    val serializer = env.serializer.newInstance()
    val deserializeStream = serializer.deserializeStream(fileInputStream)

    // Register an on-task-completion callback to close the input stream.
    context.addTaskCompletionListener(context => deserializeStream.close())

    deserializeStream.asIterator.asInstanceOf[Iterator[T]]
  }

看一下LocalCheckpointRDD的compute

//抛出异常,找不到checkpoint。只有在original RDD未被显式地持久化或一个executor丢失了才会有这样的情况。然而,在正常情况下,original RDD将完全缓存,因此应已计算所有分区,并且在块存储中可用。
  override def compute(partition: Partition, context: TaskContext): Iterator[T] = {
    throw new SparkException(
      s"Checkpoint block ${RDDBlockId(rddId, partition.index)} not found! Either the executor " +
      s"that originally checkpointed this partition is no longer alive, or the original RDD is " +
      s"unpersisted. If this problem persists, you may consider using `rdd.checkpoint()` " +
      s"instead, which is slower than local checkpointing but more fault-tolerant.")
  }

3 Summary

检查点(本质是通过将RDD写入Disk做检查点)是为了通过lineage做容错的辅助。lineage过长会造成容错成本过高。这样就不如在中间阶段做检查点容错,假设之后有节点出现故障而丢失分区。从做检查点的RDD开始重做Lineage,就会降低开销。

建议:做检查点的RDD最好是已缓存在内存中,否则保存检查点的过程还需要重新计算,产生I/O开销。
checkpoint基础
checkpoint 可能遇到的问题参考这里
参考:https://www.jianshu.com/p/a75d0439c2f9
参考:https://blog.csdn.net/changshuchao/article/details/88634555
参考:https://www.cnblogs.com/small-k/p/8909942.html

本文只是学习用,很多都是从官网、个人博客摘的,如有问题请留言,我会处理的。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章