Custom UDF, UDAF, and UDTF Functions

Notes:

1. UDF and UDAF functions are registered through sqlContext. If you need to call a Java method or function from Scala, wrap it first: write a Scala method that delegates to the Java call and returns its result (see the sketch after these notes).

2. Registering a UDF in Scala: spark.sqlContext.udf.register("date_splits", date_splits _)

3. To use a UDTF, create a SparkSession and register the function by having the SparkSession execute the SQL statement CREATE TEMPORARY FUNCTION myUDTF as '<fully qualified class name of your UDTF>'.
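A minimal sketch of points 1 and 2. The helper date_splits here is a made-up stand-in; in a real project it would be a thin Scala wrapper around the Java method you want to call:

import org.apache.spark.sql.SparkSession

object RegisterUdfSketch {
  // Stand-in for the Java helper: in practice this Scala method would simply
  // delegate to the Java call and return its result (note 1).
  def date_splits(s: String): String = s.split("-")(0)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("udf_register").master("local[2]").getOrCreate()
    // Register the Scala wrapper (note 2); the trailing underscore lifts the method into a function value.
    spark.sqlContext.udf.register("date_splits", date_splits _)
    spark.sql("select date_splits('2020-01-01') as y").show()
    spark.stop()
  }
}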
Test data

1-174,"121.31583075,30.67559298","121.31583075,30.67784745","121.31848407,30.67784745","121.31848407,30.67559298"
1-175,"121.31848407,30.67559298","121.31848407,30.67784745","121.32113740000001,30.67784745","121.32113740000001,30.67559298"
1-176,"121.32113740000001,30.67559298","121.32113740000001,30.67784745","121.32379073,30.67784745","121.32379073,30.67559298"
1-177,"121.32379073,30.67559298","121.32379073,30.67784745","121.32644406,30.67784745","121.32644406,30.67559298"
1-178,"121.32644406,30.67559298","121.32644406,30.67784745","121.32909739,30.67784745","121.32909739,30.67559298"
1-179,"121.32909739,30.67559298","121.32909739,30.67784745","121.33175072,30.67784745","121.33175072,30.67559298"
1-180,"121.33175072,30.67559298","121.33175072,30.67784745","121.33440404,30.67784745","121.33440404,30.67559298"
1-181,"121.33440404,30.67559298","121.33440404,30.67784745","121.33705737,30.67784745","121.33705737,30.67559298"

Local test code

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SparkSession}

class BaseJob { }
case class LngLatSH1(id:String,lng1:Double,lat1:Double,lng2:Double,lat2:Double,lng3:Double,lat3:Double,lng4:Double,lat4:Double)
object BaseJob{
  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession
      .builder()
      .appName("base_job")
      .enableHiveSupport()
        .master("local[2]")
      .config("spark.sql.warehouse.dir","/user/hive/warehouse")
      .config("spark.sql.shuffle.partitions",100)
      .getOrCreate()
    val sc: SparkContext = spark.sparkContext
    sc.setLogLevel("WARN") // set the log output level
    spark.sql("CREATE TEMPORARY FUNCTION myFloatMap as 'hanghai.UDTF_TEST'")
//    spark.sqlContext.udf.register("myFloatMap",myFloatMap _)
    // Read the test data file (a local path is used here; on a cluster this would typically be an HDFS path)
    val dataRDD: RDD[String] = sc.textFile("E:\\idea-workSpace\\wc\\src\\main\\scala\\hanghai\\test_v1.txt")
    val lineArrayRDD: RDD[Array[String]] = dataRDD.map(_.split(","))
    // split(",") leaves the quote characters attached to the coordinate tokens,
    // so a second split on '"' is used to extract each longitude/latitude value.
    val LngLatSH1RDD: RDD[LngLatSH1] = lineArrayRDD.map(x => LngLatSH1(x(0),
      x(1).split("\"")(1).toDouble, x(2).split("\"")(0).toDouble,
      x(3).split("\"")(1).toDouble, x(4).split("\"")(0).toDouble,
      x(5).split("\"")(1).toDouble, x(6).split("\"")(0).toDouble,
      x(7).split("\"")(1).toDouble, x(8).split("\"")(0).toDouble))
    import spark.implicits._
    val LngLatSH1DF: DataFrame = LngLatSH1RDD.toDF()
    LngLatSH1DF.createOrReplaceTempView("test")

    spark.sql("select myFloatMap(lng1) from test").show(3)


    sc.stop()
    spark.stop()
  }
}
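To make the parsing in the map above concrete, here is a quick sketch using the first test row. Splitting on "," keeps the quote characters attached to the coordinate tokens, and the second split on '"' strips them:

// Quick sketch with the first test row from the data above.
val line = "1-174,\"121.31583075,30.67559298\",\"121.31583075,30.67784745\",\"121.31848407,30.67784745\",\"121.31848407,30.67559298\""
val x = line.split(",")
// x(0) = 1-174, x(1) = "121.31583075 (leading quote kept), x(2) = 30.67559298" (trailing quote kept)
println(x(1).split("\"")(1).toDouble) // 121.31583075
println(x(2).split("\"")(0).toDouble) // 30.67559298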

 

1. UDF function

import com.alibaba.fastjson.JSONObject;
import org.apache.spark.sql.api.java.UDF2;


// Extracts the value of the given key (o2) from a JSON string (o1)
public class MyUDF implements UDF2<String, String, String> {
    @Override
    public String call(String o1, String o2) throws Exception {
        JSONObject jsonObject = JSONObject.parseObject(o1);
        return jsonObject.getString(o2);
    }
}
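A minimal usage sketch, assuming a Spark 2.x SparkSession named spark and MyUDF on the classpath; the SQL function name get_json_field and the JSON string are made up for illustration:

import org.apache.spark.sql.types.DataTypes

// Register the Java UDF under a SQL name together with its return type.
spark.udf.register("get_json_field", new MyUDF(), DataTypes.StringType)
// Pull the "username" field out of a JSON string; this should print a single row containing "tom".
spark.sql("select get_json_field('{\"username\":\"tom\"}', 'username') as username").show()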

UDF test class

import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.sql.types.{DataTypes, StringType, StructField}
import org.apache.spark.sql.{Row, types}
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.{SparkConf, SparkContext}

object Test_UDF {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("Test_UDF").setMaster("local[2]")
    val sc = new SparkContext(conf)
    val ssc = new StreamingContext(sc, Seconds(5))
    val ds = ssc.socketTextStream("**.***.***.***", 8888)
    ds.foreachRDD {
      rdd =>
        val sqlContext = new HiveContext(sc)
        // Register the Java UDF under the SQL name "udf", together with its return type
        sqlContext.udf.register("udf", new MyUDF(), DataTypes.StringType)
        import sqlContext.implicits._
        val schema= types.StructType(
          Seq(
            StructField("id", StringType, true),
            StructField("info", StringType, true)
          )
        )
        val rowRDD = rdd.map(_.split(" ")).map(p=> Row(p(0),p(1)))
        val dataFrame = sqlContext.createDataFrame(rowRDD,schema)
        dataFrame.registerTempTable("test_udf")
        val dataFrame1 = sqlContext.sql("select id,udf(info,'username') from test_udf")
        dataFrame1.show();
    }
    ssc.start()
    ssc.awaitTermination()
  }
}

2. UDAF function

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

// A simple counting UDAF: returns how many rows each group contains.
class MyUDAF extends UserDefinedAggregateFunction{
  // Input schema: a single string column
  override def inputSchema: StructType =
  {
    StructType(Array(StructField("str", StringType, true)))
  }

  // Aggregation buffer: one running count per group
  override def bufferSchema: StructType = {
    StructType(Array(StructField("count", IntegerType, true)))
  }

  // Return type of the aggregation result
  override def dataType: DataType = {
    IntegerType
  }

  // The same input always produces the same result
  override def deterministic: Boolean = true

  // Start each group's count at 0
  override def initialize(buffer: MutableAggregationBuffer): Unit ={
    buffer(0) = 0
  }

  // Add 1 for every input row in the group
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getAs[Int](0) + 1
  }

  // Combine partial counts computed on different partitions
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getAs[Int](0) + buffer2.getAs[Int](0)
  }

  // Final result: the accumulated count for the group
  override def evaluate(buffer: Row): Any = {
    buffer.getAs[Int](0)
  }
}
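Besides SQL registration (shown in the test class below), a UserDefinedAggregateFunction can also be applied directly through the DataFrame API. A minimal sketch, assuming an existing DataFrame df with a string column name:

import org.apache.spark.sql.functions.col

// Apply the UDAF as a Column expression inside agg().
val countUdaf = new MyUDAF
df.groupBy(col("name"))
  .agg(countUdaf(col("name")).as("count"))
  .show()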

UDAF test class

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.sql.{Row, types}
import org.apache.spark.sql.types.{DataTypes, IntegerType, StringType, StructField}
import org.apache.spark.streaming.{Seconds, StreamingContext}

object Test_UDAF {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("Test_UDAF").setMaster("local[2]")
    val sc = new SparkContext(conf)
    val ssc = new StreamingContext(sc, Seconds(5))
    val ds = ssc.socketTextStream("**.***.***.***", 8888)
    ds.foreachRDD {
      rdd =>
        val sqlContext = new HiveContext(sc)
        // Register the UDAF under the SQL name "udaf"
        sqlContext.udf.register("udaf", new MyUDAF)
        import sqlContext.implicits._
        val schema = types.StructType(
          Seq(
            StructField("name", StringType, true)
          )
        )
        val rowRDD = rdd.map(name => Row(name))
        val dataFrame = sqlContext.createDataFrame(rowRDD, schema)
        dataFrame.registerTempTable("test_udaf")
        dataFrame.show()
        val dataFrame1 = sqlContext.sql("select name, udaf(name) as count from test_udaf group by name")
        dataFrame1.show();
    }
    ssc.start()
    ssc.awaitTermination()
  }

}

3. UDTF function

Notes on UDTFs:

A UDTF can no longer be registered with sqlContext.udf.register; instead, create a SparkSession and register it by executing the SQL statement CREATE TEMPORARY FUNCTION myUDTF as '<fully qualified class name of your UDTF>' (see the Test_UDTF class below).
Also, the registration statement must not be written inside foreachRDD.
 

import java.util

import org.apache.hadoop.hive.ql.exec.{UDFArgumentException, UDFArgumentLengthException}
import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory, StructObjectInspector}

// UDTF that expands a string like "name1:age1;name2:age2" into multiple (name, age) rows.
class MyUDTF extends GenericUDTF {

  // Validate the single string argument and declare the output columns (name, age)
  override def initialize(args: Array[ObjectInspector]): StructObjectInspector = {
    if (args.length != 1) {
      throw new UDFArgumentLengthException("UserDefinedUDTF takes only one argument")
    }
    if (args(0).getCategory() != ObjectInspector.Category.PRIMITIVE) {
      throw new UDFArgumentException("UserDefinedUDTF takes string as a parameter")
    }

    val fieldNames = new util.ArrayList[String]()
    val fieldOIs = new util.ArrayList[ObjectInspector]

    fieldNames.add("name")
    fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector)
    fieldNames.add("age")
    fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector)

    ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs)

  }


  override def close(): Unit = {}

  // Split the input on ';' into records and each record on ':' into fields,
  // then emit every field array as a separate output row via forward().
  override def process(args: Array[AnyRef]): Unit = {
    val strings = args(0).toString.split(";")
    for (string <- strings) {
      val strings1 = string.split(":")
      forward(strings1)
    }
  }
}
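To see what process emits, here is a quick sketch of the same splitting applied to a made-up input string:

// Made-up input: two records separated by ';', fields separated by ':'.
val input = "tom:18;jack:20"
val rows = input.split(";").map(_.split(":"))
// rows(0) = Array("tom", "18"), rows(1) = Array("jack", "20");
// each array is handed to forward(), so the UDTF emits the rows (tom, 18) and (jack, 20).
rows.foreach(r => println(r.mkString(", ")))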

UDTF test class

import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.sql.types.{StringType, StructField}
import org.apache.spark.sql.{Row, SparkSession, types}
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.{SparkConf, SparkContext}

object Test_UDTF {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("Test_UDTF").setMaster("local[2]")
    val sc = new SparkContext(conf)
    val ssc = new StreamingContext(sc, Seconds(5))
    val ds = ssc.socketTextStream("10.**.**.**", 8888)

    val session = SparkSession.builder()
      .master("local[2]")
      .appName("Test_UDTF")
      .enableHiveSupport()
      .getOrCreate()

    // This approach cannot be used here
    //session.sql("CREATE TEMPORARY FUNCTION udtf AS 'MyUDTF'")

    // Register the UDTF through the HiveContext instead, outside foreachRDD
    val sqlContext = new HiveContext(sc)
    sqlContext.sql("CREATE TEMPORARY FUNCTION udtf AS 'MyUDTF'")
    ds.foreachRDD {
      rdd =>

        val schema = types.StructType(
          Seq(
            StructField("info", StringType, true)
          )
        )
        val rowRDD = rdd.map(p => Row(p))
        //val dataFrame = session.createDataFrame(rowRDD, schema)
        val dataFrame = sqlContext.createDataFrame(rowRDD, schema)

        dataFrame.registerTempTable("test_udtf")
        val dataFrame1 = session.sql("select udtf(info) as (name, age) from test_udtf")
        dataFrame1.show();
    }
    ssc.start()
    ssc.awaitTermination()
  }
}

 
