Spark記錄(四):Dataset.count()方法源碼剖析

因最近工作中涉及較多的Spark相關功能,所以趁週末閒來無事,研讀一下Dataset的count方法。Spark版本3.2.0

1、方法入口:

  def count(): Long = withAction("count", groupBy().count().queryExecution) { plan =>
    plan.executeCollect().head.getLong(0)
  }

可以看到,count方法調用的是withAction方法,入參有三個:字符串count、調用方法獲取到的QueryExecution、一個函數。注:此處就是對Scala函數式編程的應用,將函數作爲參數來傳遞

2、第二個參數QueryExecution的獲取流程

 2.1、首先看groupBy()方法:

1   def groupBy(cols: Column*): RelationalGroupedDataset = {
2     RelationalGroupedDataset(toDF(), cols.map(_.expr), RelationalGroupedDataset.GroupByType)
3   }

groupBy方法是用於分組聚合的,一般用法是groupBy之後加上agg聚合函數,對分組之後的每組數據進行聚合,入參爲Column類型的可變長度參數。

但上面count方法中調用時未傳任何入參,產生的效果就是****

groupBy方法只有一行代碼,生成並返回了一個RelationalGroupedDataset的對象,而且此處是用伴生對象的簡略寫法創建出來的,該行代碼其實質是調用了RelationalGroupedDataset的伴生對象中的apply方法,三個入參。

注:RelationalGroupedDataset 類是用於處理聚合操作的,內部封裝了對agg方法的處理,以及一些統計函數sum、max等的實現。

2.1.1、逐一看下RelationalGroupedDataset的三個入參:

首先是toDF()方法,方法體如下,可見就是重新創建了一個Dataset[Row]對象,即DataFrame

  def toDF(): DataFrame = new Dataset[Row](queryExecution, RowEncoder(schema))

然後是cols.map(_.expr),即遍歷執行每個Column的expr表達式,因爲此處未傳入cols,故可忽略。

最後傳入的是 RelationalGroupedDataset.GroupByType,起了標識的作用。因爲RelationalGroupedDataset類的方法除了groupBy調用之外,還有Cube、Rollup、Pivot等都會調用,爲與其他幾種區別開,故傳入了GroupByType。

2.1.2、初探 RelationalGroupedDataset 類

apply方法:

1   def apply(
2       df: DataFrame,
3       groupingExprs: Seq[Expression],
4       groupType: GroupType): RelationalGroupedDataset = {
5     new RelationalGroupedDataset(df, groupingExprs, groupType: GroupType)
6   }

類的定義:

class RelationalGroupedDataset protected[sql](
    private[sql] val df: DataFrame,
    private[sql] val groupingExprs: Seq[Expression],
    groupType: RelationalGroupedDataset.GroupType) {
......
}

可見沒有多餘的邏輯,只是單純的創建了一個對象。至於這個對象如何使用的,還需繼續追溯它裏面的count方法,即Dataset.count()中調用的groupBy().count()。

2.2、groupBy().count(),即 RelationalGroupedDataset.count():

  def count(): DataFrame = toDF(Seq(Alias(Count(Literal(1)).toAggregateExpression(), "count")()))

2.2.1、其中Alias(Count(Literal(1)).toAggregateExpression(), "count")的作用,就是生成 count(1) as count 這樣的一個統計函數的表達式。

2.2.2、然後toDF方法,如下所示:

 1 private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = {
 2     val aggregates = if (df.sparkSession.sessionState.conf.dataFrameRetainGroupColumns) { // 是否保留分組的主鍵列,默認true
 3       groupingExprs match { // 若保留,則將分組的主鍵列拼到聚合表達式的前面
 4         // call `toList` because `Stream` can't serialize in scala 2.13
 5         case s: Stream[Expression] => s.toList ++ aggExprs
 6         case other => other ++ aggExprs
 7       }
 8     } else {
 9       aggExprs
10     }
11 
12     val aliasedAgg = aggregates.map(alias) // 處理設置別名的表達式
13 
14     groupType match {
15       case RelationalGroupedDataset.GroupByType =>
16         Dataset.ofRows(df.sparkSession, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan)) // ***
17       case RelationalGroupedDataset.RollupType =>
18         Dataset.ofRows(
19           df.sparkSession, Aggregate(Seq(Rollup(groupingExprs.map(Seq(_)))),
20             aliasedAgg, df.logicalPlan))
21       case RelationalGroupedDataset.CubeType =>
22         Dataset.ofRows(
23           df.sparkSession, Aggregate(Seq(Cube(groupingExprs.map(Seq(_)))),
24             aliasedAgg, df.logicalPlan))
25       case RelationalGroupedDataset.PivotType(pivotCol, values) =>
26         val aliasedGrps = groupingExprs.map(alias)
27         Dataset.ofRows(
28           df.sparkSession, Pivot(Some(aliasedGrps), pivotCol, values, aggExprs, df.logicalPlan))
29     }
30   }

重點是第16行,進入ofRows方法中可以看到,其實就是又新建了一個Dataset[Row],並將加上count(1)表達式之後新生成的Aggregate執行計劃傳入。

至此,groupBy().count().queryExecution得到的就是一個count(1)的執行計劃了。

3、第三個參數,也是一個函數式參數:

{ plan =>
    plan.executeCollect().head.getLong(0)
  }

該參數入參是一個plan,返回值long類型,推測是獲取最終count值的,暫時放一放,後面調用到的時候再來研究。

4、看完三個參數,下面進入withAction方法:

1 private def withAction[U](name: String, qe: QueryExecution)(action: SparkPlan => U) = {
2     SQLExecution.withNewExecutionId(qe, Some(name)) {
3       qe.executedPlan.resetMetrics()
4       action(qe.executedPlan)
5     }
6   }

又是使用了科裏化傳參,第三個參數同樣是一個函數,在裏面調用了action這個函數參數。繼續追蹤withNewExecutionId方法:

 5、SQLExecution.withNewExecutionId

該方法代碼較多,下面先看一下它的主體結構。裏面省略的若干行代碼,實際是作爲一個函數參數傳入了withActive方法。

def withNewExecutionId[T](
      queryExecution: QueryExecution,
      name: Option[String] = None)(body: => T): T = queryExecution.sparkSession.withActive {
... // 省略若干代碼
}

而withActive方法如下,實際是將當前的SparkSession存入了本地線程變量中,方便後面的獲取。然後執行了函數block,而返回值就是外層withNewExecutionId方法中函數體的返回值。

private[sql] def withActive[T](block: => T): T = {
    val old = SparkSession.activeThreadSession.get()
    SparkSession.setActiveSession(this)
    try block finally {
      SparkSession.setActiveSession(old)
    }
  }

 下面回到外層的函數體:

5.1、SQLExecution.withNewExecutionId函數體第一部分

1     val sparkSession = queryExecution.sparkSession
2     val sc = sparkSession.sparkContext
3     val oldExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY)
4     val executionId = SQLExecution.nextExecutionId
5     sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString)
6     executionIdToQueryExecution.put(executionId, queryExecution)

先設置了一下executionId,該ID是一個線程安全的自增序列,每次加1,。設置給SparkContext之後,又將id與QueryExecution的映射關係存入Map中。

5.2、SQLExecution.withNewExecutionId函數體第二部分

第二部分主要是判斷若sql長度過長,需要進行截斷處理,無甚要點。

5.3、SQLExecution.withNewExecutionId函數體第三部分,代碼如下:

 1       withSQLConfPropagated(sparkSession) {
 2         var ex: Option[Throwable] = None
 3         val startTime = System.nanoTime()
 4         try {
 5           sc.listenerBus.post(SparkListenerSQLExecutionStart(
 6             executionId = executionId,
 7             description = desc,
 8             details = callSite.longForm,
 9             physicalPlanDescription = queryExecution.explainString(planDescriptionMode),
10             sparkPlanInfo = SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan),
11             time = System.currentTimeMillis()))
12           body
13         } catch {
14           case e: Throwable =>
15             ex = Some(e)
16             throw e
17         } finally {
18           val endTime = System.nanoTime()
19           val event = SparkListenerSQLExecutionEnd(executionId, System.currentTimeMillis())
20           event.executionName = name
21           event.duration = endTime - startTime
22           event.qe = queryExecution
23           event.executionFailure = ex
24           sc.listenerBus.post(event)
25         }
26       }

起頭的 withSQLConfPropagated 方法,同樣還是科裏化的方式傳參,方法裏面將配置參數替換爲新的配置參數,執行完之後再將老參數存回去。

再然後是try裏面的一個post方法,finally裏面一個post方法,用於發送SQLExecution執行開始和結束的通知消息。

最後是核心函數調用,body。即前面一直引而未看的方法。

下面再返回頭來好好研究一下此處的body函數,函數體是:

{
      qe.executedPlan.resetMetrics()
      action(qe.executedPlan)
    }

qe變量即上面2.2中返回的groupBy().count().queryExecution

而action的函數體是:

{ plan =>
    plan.executeCollect().head.getLong(0)
  }

那麼內部具體是怎麼實現的呢?今天時間不早了,改日再搞它。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章