1、Background
Spark Codegen is an optimization applied after CBO & RBO that implements the low-level logic of physical operators by generating code for them directly.
It comes in two forms: Expression-level and WholeStage-level Codegen.
2、Examples
① Expression level: take an example commonly cited online: x + (1 + 2)
Expressed as a Scala expression tree:
Add(Attribute(x), Add(Literal(1), Literal(2)))
The syntax tree has the outer Add as its root, with Attribute(x) and the inner Add(Literal(1), Literal(2)) as its children.
The conventional way to evaluate this tree is to interpret it recursively:
tree.transformUp {
  case Attribute(idx) => Literal(row.getValue(idx))       // look the attribute up in the current row
  case Add(Literal(c1), Literal(c2)) => Literal(c1 + c2)  // fold an addition of two literals
  case Literal(c) => Literal(c)                           // a literal is already a value
}
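For a concrete feel of what this interpretation costs, here is a self-contained sketch of the same idea in plain Scala; the case classes mirror the snippet above but are simplified stand-ins, not Catalyst's actual Expression hierarchy:

sealed trait Expression
case class Attribute(idx: Int) extends Expression
case class Literal(value: Int) extends Expression
case class Add(left: Expression, right: Expression) extends Expression

// Recursive interpretation: every row pays for pattern matching and recursive calls.
def eval(e: Expression, row: Array[Int]): Int = e match {
  case Attribute(idx) => row(idx)
  case Literal(v)     => v
  case Add(l, r)      => eval(l, row) + eval(r, row)
}

val tree = Add(Attribute(0), Add(Literal(1), Literal(2)))
println(eval(tree, Array(5))) // 8, i.e. x + (1 + 2) with x = 5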
Evaluating the tree this way involves a lot of extra work for every row: type pattern matching, virtual function calls, object creation, and so on, and this overhead far exceeds the cost of evaluating the expression itself.
To eliminate the overhead, Spark Codegen assembles the Java source code that evaluates the expression directly and compiles it just in time. This happens in three steps:
- Code generation. Java code is generated from the syntax tree and wrapped in a wrapper class:
... // class wrapper
row.getValue(idx) + (1 + 2)
... // class wrapper
- Just-in-time compilation. The Janino framework compiles the generated code into a class file (a minimal sketch of this and the next step follows this list).
- Load and execute. Finally, the compiled class is loaded and executed.
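A minimal sketch of those two steps in Scala, assuming Janino's SimpleCompiler API: a hand-written Java source string is compiled at runtime and then loaded and invoked via reflection. The class name GeneratedEval and the call path are illustrative only; Spark's real wrapper classes and invocation differ.

import org.codehaus.janino.SimpleCompiler

// Step 1 (code generation) would have produced a Java source string such as:
val source =
  """public class GeneratedEval {
    |  public int eval(int x) { return x + (1 + 2); }
    |}""".stripMargin

// Step 2: JIT-compile the Java source to bytecode with Janino.
val compiler = new SimpleCompiler()
compiler.cook(source)

// Step 3: load the compiled class and execute it (via reflection, for brevity).
val clazz = compiler.getClassLoader.loadClass("GeneratedEval")
val instance = clazz.getDeclaredConstructor().newInstance().asInstanceOf[AnyRef]
println(clazz.getMethod("eval", classOf[Int]).invoke(instance, Int.box(5))) // 8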
Compared with interpreted evaluation, this yields an order-of-magnitude performance improvement.
② WholeStage level: a slightly more involved example, analyzed in detail.
With the rough picture from the example above in mind, consider the following query:
SQL:
select * from test.zyz where id=1;
Table definition:
CREATE TABLE `test.zyz`(
`id` int,
`name` string)
PARTITIONED BY (
`pt` int)
ROW FORMAT SERDE
'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe'
STORED AS INPUTFORMAT
'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat'
OUTPUTFORMAT
'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat'
Execution plan (the leading *(1) marks operators fused into WholeStageCodegen stage 1):
Found 1 WholeStageCodegen subtrees.
== Subtree 1 / 1 ==
*(1) Project [id#11, name#12, pt#13]
+- *(1) Filter (isnotnull(id#11) && (id#11 = 1))
+- *(1) FileScan parquet test.zyz[id#11,name#12,pt#13] Batched: true, Format: Parquet, Location: CatalogFileIndex[hdfs://nameservice1/user/hive/warehouse/test.db/zyz], PartitionCount: 1, PartitionFilters: [], PushedFilters: [IsNotNull(id), EqualTo(id,1)], ReadSchema: struct<id:int,name:string>
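The plan and the generated Java code below can be reproduced roughly as follows (a sketch meant for spark-shell, where the spark session already exists; adapt to your environment):

import org.apache.spark.sql.execution.debug._ // adds debug()/debugCodegen() to Dataset

val df = spark.sql("select * from test.zyz where id = 1")
df.explain()      // prints the physical plan shown above
df.debugCodegen() // prints "Found N WholeStageCodegen subtrees" plus the generated Java source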
Generated code (note how FileScan, Filter, and Project are fused into a single loop over columnar batches):
public Object generate(Object[] references) {
return new GeneratedIteratorForCodegenStage1(references);
}
// codegenStageId=1
final class GeneratedIteratorForCodegenStage1 extends org.apache.spark.sql.execution.BufferedRowIterator {
private Object[] references;
private scala.collection.Iterator[] inputs;
// SQLMetrics: accumulates scan time
private long scan_scanTime_0;
// batch index
private int scan_batchIdx_0;
// three columns in total (id, name, pt)
private org.apache.spark.sql.execution.vectorized.OnHeapColumnVector[] scan_mutableStateArray_2 = new org.apache.spark.sql.execution.vectorized.OnHeapColumnVector[3];
private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] scan_mutableStateArray_3 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[3];
// a single slot holding the ColumnarBatch currently being processed
private org.apache.spark.sql.vectorized.ColumnarBatch[] scan_mutableStateArray_1 = new org.apache.spark.sql.vectorized.ColumnarBatch[1];
private scala.collection.Iterator[] scan_mutableStateArray_0 = new scala.collection.Iterator[1];
public GeneratedIteratorForCodegenStage1(Object[] references) {
this.references = references;
}
public void init(int index, scala.collection.Iterator[] inputs) {
partitionIndex = index;
this.inputs = inputs;
scan_mutableStateArray_0[0] = inputs[0];
// UnsafeRowWriters for 3 fields, each with 32 bytes of initial variable-length buffer
scan_mutableStateArray_3[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(3, 32);
scan_mutableStateArray_3[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(3, 32);
scan_mutableStateArray_3[2] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(3, 32);
}
private void scan_nextBatch_0() throws java.io.IOException {
long getBatchStart = System.nanoTime();
// fetch the next batch of data
if (scan_mutableStateArray_0[0].hasNext()) {
scan_mutableStateArray_1[0] = (org.apache.spark.sql.vectorized.ColumnarBatch)scan_mutableStateArray_0[0].next();
((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(scan_mutableStateArray_1[0].numRows());
scan_batchIdx_0 = 0;
scan_mutableStateArray_2[0] = (org.apache.spark.sql.execution.vectorized.OnHeapColumnVector) scan_mutableStateArray_1[0].column(0);
scan_mutableStateArray_2[1] = (org.apache.spark.sql.execution.vectorized.OnHeapColumnVector) scan_mutableStateArray_1[0].column(1);
scan_mutableStateArray_2[2] = (org.apache.spark.sql.execution.vectorized.OnHeapColumnVector) scan_mutableStateArray_1[0].column(2);
}
// accumulate the time spent fetching this batch
scan_scanTime_0 += System.nanoTime() - getBatchStart;
}
protected void processNext() throws java.io.IOException {
// no current batch yet: fetch the first one
if (scan_mutableStateArray_1[0] == null) {
scan_nextBatch_0();
}
// process rows as long as a batch is available
while (scan_mutableStateArray_1[0] != null) {
// number of rows in this batch and how many remain to process
int scan_numRows_0 = scan_mutableStateArray_1[0].numRows();
int scan_localEnd_0 = scan_numRows_0 - scan_batchIdx_0;
// iterate over the remaining rows
for (int scan_localIdx_0 = 0; scan_localIdx_0 < scan_localEnd_0; scan_localIdx_0++) {
int scan_rowIdx_0 = scan_batchIdx_0 + scan_localIdx_0;
do {
// null check on id (the pushed-down IsNotNull filter), then evaluate id = 1
boolean scan_isNull_0 = scan_mutableStateArray_2[0].isNullAt(scan_rowIdx_0);
int scan_value_0 = scan_isNull_0 ? -1 : (scan_mutableStateArray_2[0].getInt(scan_rowIdx_0));
if (!(!scan_isNull_0)) continue;
boolean filter_value_2 = false;
filter_value_2 = scan_value_0 == 1;
if (!filter_value_2) continue;
((org.apache.spark.sql.execution.metric.SQLMetric) references[2] /* numOutputRows */).add(1);
// read the remaining columns (name, pt) with null checks
boolean scan_isNull_1 = scan_mutableStateArray_2[1].isNullAt(scan_rowIdx_0);
UTF8String scan_value_1 = scan_isNull_1 ? null : (scan_mutableStateArray_2[1].getUTF8String(scan_rowIdx_0));
boolean scan_isNull_2 = scan_mutableStateArray_2[2].isNullAt(scan_rowIdx_0);
int scan_value_2 = scan_isNull_2 ? -1 : (scan_mutableStateArray_2[2].getInt(scan_rowIdx_0));
scan_mutableStateArray_3[2].reset();
scan_mutableStateArray_3[2].zeroOutNullBytes();
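// constant false: the preceding filter already guarantees that id is not null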
if (false) {
scan_mutableStateArray_3[2].setNullAt(0);
} else {
scan_mutableStateArray_3[2].write(0, scan_value_0);
}
if (scan_isNull_1) {
scan_mutableStateArray_3[2].setNullAt(1);
} else {
scan_mutableStateArray_3[2].write(1, scan_value_1);
}
if (scan_isNull_2) {
scan_mutableStateArray_3[2].setNullAt(2);
} else {
scan_mutableStateArray_3[2].write(2, scan_value_2);
}
// assemble the three fields into an UnsafeRow and emit it
append((scan_mutableStateArray_3[2].getRow()));
} while(false);
// shouldStop() signals that the caller should consume the buffered row; remember the position in the batch and return
if (shouldStop()) { scan_batchIdx_0 = scan_rowIdx_0 + 1; return; }
}
scan_batchIdx_0 = scan_numRows_0;
scan_mutableStateArray_1[0] = null;
scan_nextBatch_0();
}
// report the accumulated scan time metric (nanoseconds converted to milliseconds)
((org.apache.spark.sql.execution.metric.SQLMetric) references[1] /* scanTime */).add(scan_scanTime_0 / (1000 * 1000));
scan_scanTime_0 = 0;
}
}
3、The Spark Codegen framework
The Spark Codegen framework has three core building blocks.
① Core interfaces and classes
1. CodegenSupport (interface)
An operator that implements this interface can assemble its own logic into Java code. Key methods (a simplified sketch appears after this list):
produce() // emits the Java code that produces this node's output rows
consume() // emits the Java code with which this node consumes the rows supplied by its upstream node
Implementations include, but are not limited to: ProjectExec, FilterExec, HashAggregateExec, SortMergeJoinExec.
2. WholeStageCodegenExec (class)
One of the implementations of CodegenSupport. It represents the fusion of all adjacent CodegenSupport operators within a stage; the generated code wraps the execution logic of every fused operator into a single wrapper class, and that wrapper class is what gets handed to Janino for just-in-time compilation.
3. InputAdapter (class)
Another implementation of CodegenSupport; a glue class that connects a WholeStageCodegenExec node to upstream nodes that do not implement CodegenSupport.
4. BufferedRowIterator (abstract class)
The parent class of the Java code generated by WholeStageCodegenExec (note the extends clause in the example above). Key methods:
public InternalRow next() // return the next row
public void append(InternalRow row) // append a row to the output buffer
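A simplified sketch of how produce() and consume() fit together (these are not Spark's exact signatures; CodegenContext and ExprCode are real Catalyst types, while the trait itself is illustrative):

import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}

// produce() is driven top-down from WholeStageCodegenExec; consume() walks back up,
// letting each parent splice its per-row logic into the child's driving loop.
trait SimpleCodegenSupport {
  private var parent: SimpleCodegenSupport = null

  // Emit the Java code that produces this node's rows (the outer driving loop).
  final def produce(ctx: CodegenContext, parent: SimpleCodegenSupport): String = {
    this.parent = parent
    doProduce(ctx)
  }

  // Called from inside doProduce() once this node has a row's columns ready:
  // ask the parent for the Java code that consumes them.
  final def consume(ctx: CodegenContext, input: Seq[ExprCode]): String =
    parent.doConsume(ctx, input)

  def doProduce(ctx: CodegenContext): String
  def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String
}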
② CodegenContext
The core class that manages code generation. Its main responsibilities are listed below (a small usage sketch follows the list):
1. Name management. Guarantees there are no variable-name collisions within a scope.
2. Variable management. Maintains class-level variables, decides how each variable should be declared (as an individual field or packed into an array of that type), and maintains variable-initialization logic.
3. Method management. Maintains the class's methods.
4. Inner-class management. Maintains inner classes.
5. Common-subexpression management. Tracks identical subexpressions so they are not evaluated repeatedly.
6. Size management. Keeps methods, classes, and the number of class-level variables from growing too large by splitting them where necessary, e.g. breaking an expression block into several functions, or distributing function and variable definitions across several inner classes.
7. Dependency management. Maintains the external objects the generated class depends on, such as broadcast variables, utility objects, and metric objects.
8. Template management. Provides common code templates such as genComp and nullSafeExec.
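A tiny illustration of point 1 (name management), assuming only that CodegenContext exposes freshName(); the remaining responsibilities are handled by other methods on the same class:

import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext

val ctx = new CodegenContext
val a = ctx.freshName("value") // e.g. "value_0"
val b = ctx.freshName("value") // e.g. "value_1" -- never clashes with the first
println(s"int $a = 0; int $b = $a + 1;")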
③ The Produce-Consume pattern
Adjacent operators generate code through the Produce-Consume pattern.
Produce generates the skeleton code for the overall processing; for example, the skeleton generated for an aggregation looks like this:
if (!initialized) {
  # create a hash map, then build the aggregation hash map
  # call child.produce()
  initialized = true;
}
while (hashmap.hasNext()) {
  row = hashmap.next();
  # build the aggregation results
  # create variables for results
  # call consume(), which will call parent.doConsume()
  if (shouldStop()) return;
}
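Consume, in turn, is how each parent splices its per-row logic into that skeleton. A toy, dependency-free illustration of the pattern (the ToyScan/ToyFilter/ToyProject names are invented for this sketch and are not Spark's API), fusing Scan -> Filter -> Project into one generated loop; it can be pasted into a Scala REPL:

trait ToyOp {
  var parent: ToyOp = null
  def child: ToyOp = null
  def produce(parent: ToyOp): String = { this.parent = parent; doProduce() }
  def consume(row: String): String = parent.doConsume(row)
  def doProduce(): String = child.produce(this) // default: let the child drive the loop
  def doConsume(row: String): String
}

// Leaf operator: emits the driving loop, then asks the parent chain to consume each row.
case class ToyScan() extends ToyOp {
  override def doProduce(): String =
    "while (iterator.hasNext()) {\n  InternalRow row = iterator.next();\n" +
      consume("row") + "\n}"
  def doConsume(row: String): String = sys.error("a scan has no child to consume from")
}

// Filter: wraps everything the parent emits in a condition check.
case class ToyFilter(override val child: ToyOp, cond: String) extends ToyOp {
  def doConsume(row: String): String = s"  if (!($cond)) continue;\n" + consume(row)
}

// Project (top of the pipeline): emits the final append and calls consume() no further.
case class ToyProject(override val child: ToyOp, cols: String) extends ToyOp {
  def doConsume(row: String): String = s"  append(project($row, $cols));"
}

val plan = ToyProject(ToyFilter(ToyScan(), "row.getInt(0) == 1"), "id, name, pt")
println(plan.produce(parent = null))
// Prints a single fused loop:
//   while (iterator.hasNext()) {
//     InternalRow row = iterator.next();
//     if (!(row.getInt(0) == 1)) continue;
//     append(project(row, id, name, pt));
//   }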
4、References:
Spark Codegen淺析: https://mp.weixin.qq.com/s/77hSyE-Tcf9VKiWLeeMWKQ