An Analysis of How Spark Codegen Works

1. Background

Spark Codegen is an optimization that kicks in after rule-based and cost-based optimization (RBO & CBO): instead of interpreting each operator, the operators' low-level logic is implemented as generated code.
It comes in two flavors: expression-level codegen and whole-stage codegen.

2. Examples

① Expression level: a common example borrowed from the web: x + (1 + 2)

Expressed in Scala as an expression tree:

Add(Attribute(x), Add(Literal(1), Literal(2)))

The syntax tree looks like this:

[Figure: syntax tree of Add(Attribute(x), Add(Literal(1), Literal(2)))]

The conventional way to evaluate this tree is to interpret it recursively:

tree.transformUp {
  case Attribute(idx) => Literal(row.getValue(idx))       // bind column references to values from the current row
  case Add(Literal(c1), Literal(c2)) => Literal(c1 + c2)  // fold the addition of two literals
  case Literal(c) => Literal(c)                           // literals evaluate to themselves
}
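
For context, here is a minimal, self-contained sketch of the kind of expression AST and interpreted evaluation the snippet above stands for (the class names mirror the example; they are simplified stand-ins, not Catalyst's real Expression classes, and row access is reduced to an IndexedSeq lookup):

// Simplified stand-ins for Catalyst's expression nodes.
sealed trait Expr
case class Literal(value: Int) extends Expr
case class Attribute(idx: Int) extends Expr            // reference to the idx-th column of the input row
case class Add(left: Expr, right: Expr) extends Expr

// Interpreted evaluation: the tree is walked for every input row,
// paying for pattern matching and allocation on each call.
def eval(expr: Expr, row: IndexedSeq[Int]): Int = expr match {
  case Literal(v)     => v
  case Attribute(idx) => row(idx)
  case Add(l, r)      => eval(l, row) + eval(r, row)
}

// x + (1 + 2), with x bound to column 0
val tree: Expr = Add(Attribute(0), Add(Literal(1), Literal(2)))
println(eval(tree, IndexedSeq(5)))                     // prints 8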

Evaluating the tree this way involves a lot of work that has nothing to do with the arithmetic itself: type pattern matching, virtual function calls, object creation, and so on. This overhead far exceeds the cost of evaluating the expression.
To eliminate it, Spark Codegen stitches together the Java source that evaluates the expression directly and compiles it just in time. This happens in three steps:

  1. Code generation. Java source is generated from the syntax tree and wrapped in a wrapper class:
... // class wrapper
row.getValue(idx) + (1 + 2)
... // class wrapper
  2. Just-in-time compilation. The generated source is compiled into a class file with the Janino framework (see the sketch after this list).
  3. Load and execute. The compiled class is loaded and executed.
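
To make the Janino step concrete, here is a rough sketch of compiling and loading a generated snippet in memory (the class name GeneratedEval and its source are made up for illustration; Spark drives this through its own CodeGenerator utilities rather than calling Janino directly like this):

import org.codehaus.janino.SimpleCompiler

// Hypothetical generated Java source for x + (1 + 2).
val source =
  """public class GeneratedEval {
    |    public int eval(int x) {
    |        return x + (1 + 2);
    |    }
    |}""".stripMargin

val compiler = new SimpleCompiler()
compiler.cook(source)                                  // compile the Java source in memory
val clazz = compiler.getClassLoader.loadClass("GeneratedEval")
val instance = clazz.getDeclaredConstructor().newInstance()
println(clazz.getMethod("eval", classOf[Int]).invoke(instance, Integer.valueOf(5)))  // prints 8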

The optimization yields an order-of-magnitude performance improvement.

② WholeStage level: a slightly more involved example, analyzed in detail

The example above should give a rough idea of what codegen does; now let's look at the following query.
SQL:

select * from test.zyz where id=1;

Table definition:

CREATE TABLE `test.zyz`(
  `id` int,
  `name` string)
PARTITIONED BY (
  `pt` int)
ROW FORMAT SERDE
  'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe'
STORED AS INPUTFORMAT
  'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat'
OUTPUTFORMAT
  'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat'

Execution plan:

Found 1 WholeStageCodegen subtrees.
== Subtree 1 / 1 ==
*(1) Project [id#11, name#12, pt#13]
+- *(1) Filter (isnotnull(id#11) && (id#11 = 1))
   +- *(1) FileScan parquet test.zyz[id#11,name#12,pt#13] Batched: true, Format: Parquet, Location: CatalogFileIndex[hdfs://nameservice1/user/hive/warehouse/test.db/zyz], PartitionCount: 1, PartitionFilters: [], PushedFilters: [IsNotNull(id), EqualTo(id,1)], ReadSchema: struct<id:int,name:string>
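
Output like the plan above and the generated code below can be dumped from Spark itself. A rough sketch, assuming a SparkSession named spark as in spark-shell (exact output format varies across Spark versions):

import org.apache.spark.sql.execution.debug._   // adds debugCodegen() to Dataset

val df = spark.sql("select * from test.zyz where id = 1")
df.explain()       // physical plan; operators prefixed with * run inside whole-stage codegen
df.debugCodegen()  // prints "Found N WholeStageCodegen subtrees" plus the generated Java source

// Equivalent in SQL: EXPLAIN CODEGEN select * from test.zyz where id = 1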

Generated code:

public Object generate(Object[] references) {
        return new GeneratedIteratorForCodegenStage1(references);
    }

    // codegenStageId=1
    final class GeneratedIteratorForCodegenStage1 extends org.apache.spark.sql.execution.BufferedRowIterator {
        private Object[] references;
        private scala.collection.Iterator[] inputs;
        // SQLMetric used to record scan time
        private long scan_scanTime_0;
        // batch index
        private int scan_batchIdx_0;
        // 3 output columns
        private org.apache.spark.sql.execution.vectorized.OnHeapColumnVector[] scan_mutableStateArray_2 = new org.apache.spark.sql.execution.vectorized.OnHeapColumnVector[3];
        private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] scan_mutableStateArray_3 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[3];
        // the current columnar batch (one at a time) and the single input iterator
        private org.apache.spark.sql.vectorized.ColumnarBatch[] scan_mutableStateArray_1 = new org.apache.spark.sql.vectorized.ColumnarBatch[1];
        private scala.collection.Iterator[] scan_mutableStateArray_0 = new scala.collection.Iterator[1];

        public GeneratedIteratorForCodegenStage1(Object[] references) {
            this.references = references;
        }

        public void init(int index, scala.collection.Iterator[] inputs) {
            partitionIndex = index;
            this.inputs = inputs;
            scan_mutableStateArray_0[0] = inputs[0];

            // UnsafeRowWriters for the 3 fields, each with 32 bytes of initial buffer space
            scan_mutableStateArray_3[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(3, 32);
            scan_mutableStateArray_3[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(3, 32);
            scan_mutableStateArray_3[2] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(3, 32);

        }

        private void scan_nextBatch_0() throws java.io.IOException {
            long getBatchStart = System.nanoTime();
            // fetch the next columnar batch
            if (scan_mutableStateArray_0[0].hasNext()) {
                scan_mutableStateArray_1[0] = (org.apache.spark.sql.vectorized.ColumnarBatch)scan_mutableStateArray_0[0].next();
                ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(scan_mutableStateArray_1[0].numRows());
                scan_batchIdx_0 = 0;
                scan_mutableStateArray_2[0] = (org.apache.spark.sql.execution.vectorized.OnHeapColumnVector) scan_mutableStateArray_1[0].column(0);
                scan_mutableStateArray_2[1] = (org.apache.spark.sql.execution.vectorized.OnHeapColumnVector) scan_mutableStateArray_1[0].column(1);
                scan_mutableStateArray_2[2] = (org.apache.spark.sql.execution.vectorized.OnHeapColumnVector) scan_mutableStateArray_1[0].column(2);
            }
            // accumulate the time spent fetching this batch
            scan_scanTime_0 += System.nanoTime() - getBatchStart;
        }

        protected void processNext() throws java.io.IOException {
            // no current batch yet: fetch one
            if (scan_mutableStateArray_1[0] == null) {
                scan_nextBatch_0();
            }
            // process batches while data is available
            while (scan_mutableStateArray_1[0] != null) {
                // number of rows in this batch and how many remain to be processed
                int scan_numRows_0 = scan_mutableStateArray_1[0].numRows();
                int scan_localEnd_0 = scan_numRows_0 - scan_batchIdx_0;
                // iterate over the remaining rows
                for (int scan_localIdx_0 = 0; scan_localIdx_0 < scan_localEnd_0; scan_localIdx_0++) {
                    int scan_rowIdx_0 = scan_batchIdx_0 + scan_localIdx_0;
                    do {
                        // null check on the filter column (id)
                        boolean scan_isNull_0 = scan_mutableStateArray_2[0].isNullAt(scan_rowIdx_0);
                        int scan_value_0 = scan_isNull_0 ? -1 : (scan_mutableStateArray_2[0].getInt(scan_rowIdx_0));

                        if (!(!scan_isNull_0)) continue;

                        boolean filter_value_2 = false;
                        filter_value_2 = scan_value_0 == 1;
                        if (!filter_value_2) continue;

                        ((org.apache.spark.sql.execution.metric.SQLMetric) references[2] /* numOutputRows */).add(1);

                        // read the remaining columns (name, pt), tracking nulls
                        boolean scan_isNull_1 = scan_mutableStateArray_2[1].isNullAt(scan_rowIdx_0);
                        UTF8String scan_value_1 = scan_isNull_1 ? null : (scan_mutableStateArray_2[1].getUTF8String(scan_rowIdx_0));
                        boolean scan_isNull_2 = scan_mutableStateArray_2[2].isNullAt(scan_rowIdx_0);
                        int scan_value_2 = scan_isNull_2 ? -1 : (scan_mutableStateArray_2[2].getInt(scan_rowIdx_0));
                        scan_mutableStateArray_3[2].reset();

                        scan_mutableStateArray_3[2].zeroOutNullBytes();

                        if (false) {
                            scan_mutableStateArray_3[2].setNullAt(0);
                        } else {
                            scan_mutableStateArray_3[2].write(0, scan_value_0);
                        }

                        if (scan_isNull_1) {
                            scan_mutableStateArray_3[2].setNullAt(1);
                        } else {
                            scan_mutableStateArray_3[2].write(1, scan_value_1);
                        }

                        if (scan_isNull_2) {
                            scan_mutableStateArray_3[2].setNullAt(2);
                        } else {
                            scan_mutableStateArray_3[2].write(2, scan_value_2);
                        }
                        // write the three fields into an UnsafeRow and emit it
                        append((scan_mutableStateArray_3[2].getRow()));

                    } while(false);
                    // shouldStop() asks whether the consumer needs more rows right now; if not, save progress and return
                    if (shouldStop()) { scan_batchIdx_0 = scan_rowIdx_0 + 1; return; }
                }
                scan_batchIdx_0 = scan_numRows_0;
                scan_mutableStateArray_1[0] = null;
                scan_nextBatch_0();
            }
            // record the accumulated scan time in the SQLMetric
            ((org.apache.spark.sql.execution.metric.SQLMetric) references[1] /* scanTime */).add(scan_scanTime_0 / (1000 * 1000));
            scan_scanTime_0 = 0;
        }
    }

3. The Spark Codegen Framework

The Spark Codegen framework has three core building blocks.

① Core interfaces and classes

1. CodegenSupport (trait)
Operators that implement this trait can render their own logic as Java source. Key methods:

produce() // emits the Java code that produces this node's output rows
consume() // emits the Java code that consumes one input row from the upstream node

Implementations include, but are not limited to: ProjectExec, FilterExec, HashAggregateExec, SortMergeJoinExec.
2. WholeStageCodegenExec (class)
One implementation of CodegenSupport. It fuses all adjacent operators within a stage that implement CodegenSupport; the generated code wraps the execution logic of every fused operator into a single wrapper class, which is then handed to Janino for just-in-time compilation.
3. InputAdapter (class)
Another implementation of CodegenSupport; a glue class that connects a WholeStageCodegenExec node to an upstream node that does not support codegen.
4. BufferedRowIterator (abstract class)
The parent class of the Java code generated by WholeStageCodegenExec (a minimal driver sketch follows this list). Key methods:

public InternalRow next() // returns the next row
public void append(InternalRow row) // appends a row to the output buffer
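
As a rough illustration of how the generated iterator is driven (a hand-written sketch, not Spark's actual driver code, which lives in WholeStageCodegenExec's partition-processing loop):

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.BufferedRowIterator

// Hypothetical driver: hand the child iterators to the generated class, then pull rows out.
def drive(generated: BufferedRowIterator,
          partitionIndex: Int,
          inputs: Array[Iterator[InternalRow]]): Unit = {
  generated.init(partitionIndex, inputs)   // called once per partition
  while (generated.hasNext) {              // hasNext() runs processNext() until rows are buffered
    val row: InternalRow = generated.next()
    // ... pass `row` on to the downstream consumer
  }
}
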
② CodegenContext

The core class that manages code generation. Its main responsibilities:

1. Name management: guarantees there are no variable-name collisions within a scope.
2. Variable management: maintains class-level variables, decides how each variable should be declared (as a standalone field or packed into a typed array), and maintains variable initialization logic (see the sketch after this list).
3. Method management: maintains the methods of the generated class.
4. Inner-class management: maintains inner classes.
5. Common-subexpression management: tracks identical subexpressions so they are not computed more than once.
6. Size management: keeps methods, classes, and the number of class fields from growing too large by splitting them where necessary, e.g. splitting an expression block into several functions, or spreading function and field definitions across several inner classes.
7. Dependency management: maintains the external objects the generated class depends on, such as broadcast objects, utility objects, and metric objects.
8. Template management: provides reusable code templates such as genComp and nullSafeExec.
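
A small illustration of the naming and variable management (a sketch; CodegenContext is internal API, and addMutableState's signature differs across Spark versions, so it is only hinted at in a comment):

import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext

val ctx = new CodegenContext

// Name management: each call returns an identifier that is unique within the generated class.
val v1 = ctx.freshName("value")   // e.g. "value_0"
val v2 = ctx.freshName("value")   // e.g. "value_1"

// Variable management is exposed through methods such as
//   ctx.addMutableState(<java type>, <name>, <init code>)
// whose exact signature depends on the Spark version.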

③ Produce-Consume Pattern

Adjacent operators generate code through the produce-consume pattern.
produce() generates the skeleton code for the overall processing loop. For example, an aggregation produces a framework like this:

if (!initialized) {
  # create a hash map, then build the aggregation hash map
  # call child.produce()
  initialized = true;
}
while (hashmap.hasNext()) {
  row = hashmap.next();
  # build the aggregation results
  # create variables for results
  # call consume(), which will call parent.doConsume()
  if (shouldStop()) return;
}
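
consume() works in the opposite direction: each parent operator contributes the code that handles one row produced by its child, and the fragments are nested bottom-up. For the Scan -> Filter -> Project plan from section 2, the composition looks conceptually like this (a hand-written sketch with hypothetical helpers, not Spark's API or actual output):

// Project contributes the code that writes the selected columns and emits the row.
def projectConsume(rowVar: String): String =
  s"""// Project: emit the row
     |rowWriter.write(0, $rowVar);
     |append(rowWriter.getRow());""".stripMargin

// Filter wraps its predicate around whatever its parent (Project) contributes.
def filterConsume(rowVar: String, parentConsume: String => String): String =
  s"""// Filter: id = 1
     |if (!($rowVar == 1)) continue;
     |${parentConsume(rowVar)}""".stripMargin

// Scan's produce() drives the row loop and calls consume(), which climbs up the plan.
println(filterConsume("scan_value_0", projectConsume))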


4. References

Spark Codegen淺析 https://mp.weixin.qq.com/s/77hSyE-Tcf9VKiWLeeMWKQ
