[spark]Spark UDT with Codegen UDF

This post shows how to define a custom data type, Point, as a Spark SQL user-defined type (UDT), and how to implement an Add operation on Point in two ways: Add1, which uses CodegenFallback and is evaluated in interpreted mode, and Add2, which participates in whole-stage code generation by emitting Java source from doGenCode.


build.sbt

name := "PointUdt"

version := "0.1"

scalaVersion := "2.12.11"

libraryDependencies += "org.apache.spark" %% "spark-core" % "3.0.0-preview2"
libraryDependencies += "org.apache.spark" %% "spark-sql" % "3.0.0-preview2"

PointUdtTest.scala

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.log4j.Logger
import org.apache.log4j.Level
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.types._

package org.apache.spark.sql.udt.point {

  import org.apache.spark.sql.catalyst.InternalRow
  import org.apache.spark.sql.catalyst.expressions.ExpectsInputTypes
  import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}

  @SQLUserDefinedType(udt = classOf[PointUDT])
  class Point(val x: Double, val y: Double) extends Serializable {
    override def hashCode(): Int = 31 * (31 * x.hashCode()) + y.hashCode()

    override def equals(other: Any): Boolean = other match {
      case that: Point => this.x == that.x && this.y == that.y
      case _ => false
    }

    override def toString(): String = s"($x, $y)"
  }

  // PointUDT tells Catalyst how a Point is stored internally: as an array<double> of length 2, [x, y].
  class PointUDT extends UserDefinedType[Point] {
    override def sqlType: DataType = ArrayType(DoubleType, false)

    override def serialize(obj: Point): GenericArrayData = {
      val output = new Array[Double](2)
      output(0) = obj.x
      output(1) = obj.y
      new GenericArrayData(output)
    }

    override def deserialize(datum: Any): Point = {
      datum match {
        case values: ArrayData => new Point(values.getDouble(0), values.getDouble(1))
      }
    }

    override def userClass: Class[Point] = classOf[Point]
  }

  // Add1 mixes in CodegenFallback, so it is evaluated in interpreted mode: eval() receives the
  // serialized UDT values as ArrayData, adds the coordinates, and serializes the resulting Point.
  case class Add1(inputExpr: Seq[Expression]) extends Expression with ExpectsInputTypes with CodegenFallback {
    override def nullable: Boolean = false

    override def eval(input: InternalRow): Any = {
      val left = inputExpr(0).eval(input).asInstanceOf[ArrayData]
      val right = inputExpr(1).eval(input).asInstanceOf[ArrayData]
      val result_x = left.getDouble(0) + right.getDouble(0)
      val result_y = left.getDouble(1) + right.getDouble(1)
      val result = new Point(result_x, result_y)
      new PointUDT().serialize(result)
    }

    override def dataType: DataType = new PointUDT

    override def inputTypes: Seq[AbstractDataType] = Seq(new PointUDT, new PointUDT)

    override def children: Seq[Expression] = inputExpr
  }

  // Add2 skips the interpreted path and instead emits Java source via doGenCode, so the addition
  // is inlined into the whole-stage generated code.
  case class Add2(inputExpr: Seq[Expression]) extends Expression {
    import org.apache.spark.sql.catalyst.expressions.codegen._
    import org.apache.spark.sql.catalyst.expressions.codegen.Block._

    override def nullable: Boolean = false

    // eval is deliberately left unimplemented: Add2 is only evaluated through generated code.
    override def eval(input: InternalRow): Any = ???

    override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
      // Generate code for both children first, then splice their result variables into the
      // Java source that adds the coordinates and re-serializes the resulting Point.
      val left_code = inputExpr(0).genCode(ctx)
      val right_code = inputExpr(1).genCode(ctx)
      ev.copy(
        code = code"""
          ${left_code.code}
          ${right_code.code}

          ${CodeGenerator.javaType(DoubleType)} ${ev.value}_x = ${left_code.value}.getDouble(0) + ${right_code.value}.getDouble(0);
          ${CodeGenerator.javaType(DoubleType)} ${ev.value}_y = ${left_code.value}.getDouble(1) + ${right_code.value}.getDouble(1);
          org.apache.spark.sql.udt.point.Point ${ev.value}_p = new org.apache.spark.sql.udt.point.Point(${ev.value}_x, ${ev.value}_y);
          org.apache.spark.sql.udt.point.PointUDT ${ev.value}_u = new org.apache.spark.sql.udt.point.PointUDT();
          ${CodeGenerator.javaType(ArrayType(DoubleType, false))} ${ev.value} = ${ev.value}_u.serialize(${ev.value}_p);
          """,
        isNull = FalseLiteral)
    }

    override def dataType: DataType = new PointUDT

    override def children: Seq[Expression] = inputExpr
  }

}

object PointUdtTest {
  def main(args: Array[String]): Unit = {
    Logger.getLogger("org").setLevel(Level.WARN)
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("point_udt_test")
      .getOrCreate()

    import org.apache.spark.sql.udt.point._

    val data = Seq(
      Row(1, new Point(1, 1), new Point(10, 10)),
      Row(2, new Point(2, 2), new Point(20, 20)),
      Row(3, new Point(3, 3), new Point(30, 30)),
      Row(4, new Point(4, 4), new Point(40, 40)),
      Row(5, new Point(5, 5), new Point(50, 50))
    )

    val rdd_d = spark.sparkContext.parallelize(data)
    val schema = StructType(Array(StructField("idx", IntegerType, false), StructField("point1", new PointUDT, false), StructField("point2", new PointUDT, false)))
    val df = spark.createDataFrame(rdd_d, schema)
    df.createOrReplaceTempView("data")

    spark.sessionState.functionRegistry.createOrReplaceTempFunction("add1", Add1)
    spark.sessionState.functionRegistry.createOrReplaceTempFunction("add2", Add2)

    var rst = spark.sql("select * from data")
    rst.queryExecution.debug.codegen()
    rst.show()

    rst = spark.sql("select idx, add1(point1, point2) from data ")
    rst.explain()
    rst.queryExecution.debug.codegen()
    rst.show()

    rst = spark.sql("select idx, add2(point1, point2) from data ")
    rst.explain()
    rst.queryExecution.debug.codegen()
    rst.show()

    spark.stop()
  }
}
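
The expressions can also be called from the DataFrame API instead of going through SQL. A minimal sketch, assuming the same df built in main above (this snippet is my addition, not part of the original listing):

import org.apache.spark.sql.Column

// Wrap the Catalyst expression in a Column so it can be passed to select().
// df, Add2, Point, etc. refer to the definitions in the listing above.
val added = df.select(
  df("idx"),
  new Column(Add2(Seq(df("point1").expr, df("point2").expr))).as("point_sum")
)
added.show()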

The program output is as follows:

Found 1 WholeStageCodegen subtrees.
== Subtree 1 / 1 (maxMethodCodeSize:320; maxConstantPoolSize:135(0.21% used); numInnerClasses:0) ==
*(1) Scan ExistingRDD[idx#3,point1#4,point2#5]

Generated code:
/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIteratorForCodegenStage1(references);
/* 003 */ }
/* 004 */
/* 005 */ // codegenStageId=1
/* 006 */ final class GeneratedIteratorForCodegenStage1 extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 007 */   private Object[] references;
/* 008 */   private scala.collection.Iterator[] inputs;
/* 009 */   private scala.collection.Iterator rdd_input_0;
/* 010 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] rdd_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[1];
/* 011 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter[] rdd_mutableStateArray_1 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter[2];
/* 012 */
/* 013 */   public GeneratedIteratorForCodegenStage1(Object[] references) {
/* 014 */     this.references = references;
/* 015 */   }
/* 016 */
/* 017 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 018 */     partitionIndex = index;
/* 019 */     this.inputs = inputs;
/* 020 */     rdd_input_0 = inputs[0];
/* 021 */     rdd_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(3, 64);
/* 022 */     rdd_mutableStateArray_1[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter(rdd_mutableStateArray_0[0], 8);
/* 023 */     rdd_mutableStateArray_1[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter(rdd_mutableStateArray_0[0], 8);
/* 024 */
/* 025 */   }
/* 026 */
/* 027 */   protected void processNext() throws java.io.IOException {
/* 028 */     while ( rdd_input_0.hasNext()) {
/* 029 */       InternalRow rdd_row_0 = (InternalRow) rdd_input_0.next();
/* 030 */       ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1);
/* 031 */       int rdd_value_0 = rdd_row_0.getInt(0);
/* 032 */       ArrayData rdd_value_1 = rdd_row_0.getArray(1);
/* 033 */       ArrayData rdd_value_2 = rdd_row_0.getArray(2);
/* 034 */       rdd_mutableStateArray_0[0].reset();
/* 035 */
/* 036 */       rdd_mutableStateArray_0[0].write(0, rdd_value_0);
/* 037 */
/* 038 */       // Remember the current cursor so that we can calculate how many bytes are
/* 039 */       // written later.
/* 040 */       final int rdd_previousCursor_0 = rdd_mutableStateArray_0[0].cursor();
/* 041 */
/* 042 */       final ArrayData rdd_tmpInput_0 = rdd_value_1;
/* 043 */       if (rdd_tmpInput_0 instanceof UnsafeArrayData) {
/* 044 */         rdd_mutableStateArray_0[0].write((UnsafeArrayData) rdd_tmpInput_0);
/* 045 */       } else {
/* 046 */         final int rdd_numElements_0 = rdd_tmpInput_0.numElements();
/* 047 */         rdd_mutableStateArray_1[0].initialize(rdd_numElements_0);
/* 048 */
/* 049 */         for (int rdd_index_0 = 0; rdd_index_0 < rdd_numElements_0; rdd_index_0++) {
/* 050 */           rdd_mutableStateArray_1[0].write(rdd_index_0, rdd_tmpInput_0.getDouble(rdd_index_0));
/* 051 */         }
/* 052 */       }
/* 053 */
/* 054 */       rdd_mutableStateArray_0[0].setOffsetAndSizeFromPreviousCursor(1, rdd_previousCursor_0);
/* 055 */
/* 056 */       // Remember the current cursor so that we can calculate how many bytes are
/* 057 */       // written later.
/* 058 */       final int rdd_previousCursor_1 = rdd_mutableStateArray_0[0].cursor();
/* 059 */
/* 060 */       final ArrayData rdd_tmpInput_1 = rdd_value_2;
/* 061 */       if (rdd_tmpInput_1 instanceof UnsafeArrayData) {
/* 062 */         rdd_mutableStateArray_0[0].write((UnsafeArrayData) rdd_tmpInput_1);
/* 063 */       } else {
/* 064 */         final int rdd_numElements_1 = rdd_tmpInput_1.numElements();
/* 065 */         rdd_mutableStateArray_1[1].initialize(rdd_numElements_1);
/* 066 */
/* 067 */         for (int rdd_index_1 = 0; rdd_index_1 < rdd_numElements_1; rdd_index_1++) {
/* 068 */           rdd_mutableStateArray_1[1].write(rdd_index_1, rdd_tmpInput_1.getDouble(rdd_index_1));
/* 069 */         }
/* 070 */       }
/* 071 */
/* 072 */       rdd_mutableStateArray_0[0].setOffsetAndSizeFromPreviousCursor(2, rdd_previousCursor_1);
/* 073 */       append((rdd_mutableStateArray_0[0].getRow()));
/* 074 */       if (shouldStop()) return;
/* 075 */     }
/* 076 */   }
/* 077 */
/* 078 */ }


+---+----------+------------+
|idx|    point1|      point2|
+---+----------+------------+
|  1|(1.0, 1.0)|(10.0, 10.0)|
|  2|(2.0, 2.0)|(20.0, 20.0)|
|  3|(3.0, 3.0)|(30.0, 30.0)|
|  4|(4.0, 4.0)|(40.0, 40.0)|
|  5|(5.0, 5.0)|(50.0, 50.0)|
+---+----------+------------+

== Physical Plan ==
Project [idx#3, add1(point1#4, point2#5) AS add1(point1, point2)#28]
+- *(1) Scan ExistingRDD[idx#3,point1#4,point2#5]


Found 1 WholeStageCodegen subtrees.
== Subtree 1 / 1 (maxMethodCodeSize:320; maxConstantPoolSize:135(0.21% used); numInnerClasses:0) ==
*(1) Scan ExistingRDD[idx#3,point1#4,point2#5]

Generated code:
/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIteratorForCodegenStage1(references);
/* 003 */ }
/* 004 */
/* 005 */ // codegenStageId=1
/* 006 */ final class GeneratedIteratorForCodegenStage1 extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 007 */   private Object[] references;
/* 008 */   private scala.collection.Iterator[] inputs;
/* 009 */   private scala.collection.Iterator rdd_input_0;
/* 010 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] rdd_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[1];
/* 011 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter[] rdd_mutableStateArray_1 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter[2];
/* 012 */
/* 013 */   public GeneratedIteratorForCodegenStage1(Object[] references) {
/* 014 */     this.references = references;
/* 015 */   }
/* 016 */
/* 017 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 018 */     partitionIndex = index;
/* 019 */     this.inputs = inputs;
/* 020 */     rdd_input_0 = inputs[0];
/* 021 */     rdd_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(3, 64);
/* 022 */     rdd_mutableStateArray_1[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter(rdd_mutableStateArray_0[0], 8);
/* 023 */     rdd_mutableStateArray_1[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter(rdd_mutableStateArray_0[0], 8);
/* 024 */
/* 025 */   }
/* 026 */
/* 027 */   protected void processNext() throws java.io.IOException {
/* 028 */     while ( rdd_input_0.hasNext()) {
/* 029 */       InternalRow rdd_row_0 = (InternalRow) rdd_input_0.next();
/* 030 */       ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1);
/* 031 */       int rdd_value_0 = rdd_row_0.getInt(0);
/* 032 */       ArrayData rdd_value_1 = rdd_row_0.getArray(1);
/* 033 */       ArrayData rdd_value_2 = rdd_row_0.getArray(2);
/* 034 */       rdd_mutableStateArray_0[0].reset();
/* 035 */
/* 036 */       rdd_mutableStateArray_0[0].write(0, rdd_value_0);
/* 037 */
/* 038 */       // Remember the current cursor so that we can calculate how many bytes are
/* 039 */       // written later.
/* 040 */       final int rdd_previousCursor_0 = rdd_mutableStateArray_0[0].cursor();
/* 041 */
/* 042 */       final ArrayData rdd_tmpInput_0 = rdd_value_1;
/* 043 */       if (rdd_tmpInput_0 instanceof UnsafeArrayData) {
/* 044 */         rdd_mutableStateArray_0[0].write((UnsafeArrayData) rdd_tmpInput_0);
/* 045 */       } else {
/* 046 */         final int rdd_numElements_0 = rdd_tmpInput_0.numElements();
/* 047 */         rdd_mutableStateArray_1[0].initialize(rdd_numElements_0);
/* 048 */
/* 049 */         for (int rdd_index_0 = 0; rdd_index_0 < rdd_numElements_0; rdd_index_0++) {
/* 050 */           rdd_mutableStateArray_1[0].write(rdd_index_0, rdd_tmpInput_0.getDouble(rdd_index_0));
/* 051 */         }
/* 052 */       }
/* 053 */
/* 054 */       rdd_mutableStateArray_0[0].setOffsetAndSizeFromPreviousCursor(1, rdd_previousCursor_0);
/* 055 */
/* 056 */       // Remember the current cursor so that we can calculate how many bytes are
/* 057 */       // written later.
/* 058 */       final int rdd_previousCursor_1 = rdd_mutableStateArray_0[0].cursor();
/* 059 */
/* 060 */       final ArrayData rdd_tmpInput_1 = rdd_value_2;
/* 061 */       if (rdd_tmpInput_1 instanceof UnsafeArrayData) {
/* 062 */         rdd_mutableStateArray_0[0].write((UnsafeArrayData) rdd_tmpInput_1);
/* 063 */       } else {
/* 064 */         final int rdd_numElements_1 = rdd_tmpInput_1.numElements();
/* 065 */         rdd_mutableStateArray_1[1].initialize(rdd_numElements_1);
/* 066 */
/* 067 */         for (int rdd_index_1 = 0; rdd_index_1 < rdd_numElements_1; rdd_index_1++) {
/* 068 */           rdd_mutableStateArray_1[1].write(rdd_index_1, rdd_tmpInput_1.getDouble(rdd_index_1));
/* 069 */         }
/* 070 */       }
/* 071 */
/* 072 */       rdd_mutableStateArray_0[0].setOffsetAndSizeFromPreviousCursor(2, rdd_previousCursor_1);
/* 073 */       append((rdd_mutableStateArray_0[0].getRow()));
/* 074 */       if (shouldStop()) return;
/* 075 */     }
/* 076 */   }
/* 077 */
/* 078 */ }


+---+--------------------+
|idx|add1(point1, point2)|
+---+--------------------+
|  1|        (11.0, 11.0)|
|  2|        (22.0, 22.0)|
|  3|        (33.0, 33.0)|
|  4|        (44.0, 44.0)|
|  5|        (55.0, 55.0)|
+---+--------------------+

== Physical Plan ==
*(1) Project [idx#3, add2(point1#4, point2#5) AS add2(point1, point2)#42]
+- *(1) Scan ExistingRDD[idx#3,point1#4,point2#5]


Found 1 WholeStageCodegen subtrees.
== Subtree 1 / 1 (maxMethodCodeSize:282; maxConstantPoolSize:147(0.22% used); numInnerClasses:0) ==
*(1) Project [idx#3, add2(point1#4, point2#5) AS add2(point1, point2)#42]
+- *(1) Scan ExistingRDD[idx#3,point1#4,point2#5]

Generated code:
/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIteratorForCodegenStage1(references);
/* 003 */ }
/* 004 */
/* 005 */ // codegenStageId=1
/* 006 */ final class GeneratedIteratorForCodegenStage1 extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 007 */   private Object[] references;
/* 008 */   private scala.collection.Iterator[] inputs;
/* 009 */   private scala.collection.Iterator rdd_input_0;
/* 010 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] rdd_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[2];
/* 011 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter[] rdd_mutableStateArray_1 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter[3];
/* 012 */
/* 013 */   public GeneratedIteratorForCodegenStage1(Object[] references) {
/* 014 */     this.references = references;
/* 015 */   }
/* 016 */
/* 017 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 018 */     partitionIndex = index;
/* 019 */     this.inputs = inputs;
/* 020 */     rdd_input_0 = inputs[0];
/* 021 */     rdd_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(3, 64);
/* 022 */     rdd_mutableStateArray_1[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter(rdd_mutableStateArray_0[0], 8);
/* 023 */     rdd_mutableStateArray_1[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter(rdd_mutableStateArray_0[0], 8);
/* 024 */     rdd_mutableStateArray_0[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(2, 32);
/* 025 */     rdd_mutableStateArray_1[2] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter(rdd_mutableStateArray_0[1], 8);
/* 026 */
/* 027 */   }
/* 028 */
/* 029 */   protected void processNext() throws java.io.IOException {
/* 030 */     while ( rdd_input_0.hasNext()) {
/* 031 */       InternalRow rdd_row_0 = (InternalRow) rdd_input_0.next();
/* 032 */       ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1);
/* 033 */       int rdd_value_0 = rdd_row_0.getInt(0);
/* 034 */       ArrayData rdd_value_1 = rdd_row_0.getArray(1);
/* 035 */       ArrayData rdd_value_2 = rdd_row_0.getArray(2);
/* 036 */
/* 037 */       double project_value_1_x = rdd_value_1.getDouble(0) + rdd_value_2.getDouble(0);
/* 038 */       double project_value_1_y = rdd_value_1.getDouble(1) + rdd_value_2.getDouble(1);
/* 039 */       org.apache.spark.sql.udt.point.Point project_value_1_p = new org.apache.spark.sql.udt.point.Point(project_value_1_x, project_value_1_y);
/* 040 */       org.apache.spark.sql.udt.point.PointUDT project_value_1_u = new org.apache.spark.sql.udt.point.PointUDT();
/* 041 */       ArrayData project_value_1 = project_value_1_u.serialize(project_value_1_p);
/* 042 */       rdd_mutableStateArray_0[1].reset();
/* 043 */
/* 044 */       rdd_mutableStateArray_0[1].write(0, rdd_value_0);
/* 045 */
/* 046 */       // Remember the current cursor so that we can calculate how many bytes are
/* 047 */       // written later.
/* 048 */       final int project_previousCursor_0 = rdd_mutableStateArray_0[1].cursor();
/* 049 */
/* 050 */       final ArrayData project_tmpInput_0 = project_value_1;
/* 051 */       if (project_tmpInput_0 instanceof UnsafeArrayData) {
/* 052 */         rdd_mutableStateArray_0[1].write((UnsafeArrayData) project_tmpInput_0);
/* 053 */       } else {
/* 054 */         final int project_numElements_0 = project_tmpInput_0.numElements();
/* 055 */         rdd_mutableStateArray_1[2].initialize(project_numElements_0);
/* 056 */
/* 057 */         for (int project_index_0 = 0; project_index_0 < project_numElements_0; project_index_0++) {
/* 058 */           rdd_mutableStateArray_1[2].write(project_index_0, project_tmpInput_0.getDouble(project_index_0));
/* 059 */         }
/* 060 */       }
/* 061 */
/* 062 */       rdd_mutableStateArray_0[1].setOffsetAndSizeFromPreviousCursor(1, project_previousCursor_0);
/* 063 */       append((rdd_mutableStateArray_0[1].getRow()));
/* 064 */       if (shouldStop()) return;
/* 065 */     }
/* 066 */   }
/* 067 */
/* 068 */ }


+---+--------------------+
|idx|add2(point1, point2)|
+---+--------------------+
|  1|        (11.0, 11.0)|
|  2|        (22.0, 22.0)|
|  3|        (33.0, 33.0)|
|  4|        (44.0, 44.0)|
|  5|        (55.0, 55.0)|
+---+--------------------+
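
One caveat worth noting: Add2 leaves eval as ???, so it only works while whole-stage code generation is in effect. If Spark falls back to interpreted evaluation (for example, when the generated code fails to compile), the query would fail with a NotImplementedError. A minimal defensive eval that mirrors Add1's logic could look like the following sketch (my own addition, not part of the original post):

override def eval(input: InternalRow): Any = {
  // Interpreted fallback: add the two serialized points and re-serialize the result,
  // exactly as the generated code does.
  val left = inputExpr(0).eval(input).asInstanceOf[ArrayData]
  val right = inputExpr(1).eval(input).asInstanceOf[ArrayData]
  val result = new Point(
    left.getDouble(0) + right.getDouble(0),
    left.getDouble(1) + right.getDouble(1))
  new PointUDT().serialize(result)
}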