Spark_The Role of Cache in Spark, with a Concrete Example

Reference article:

https://blog.csdn.net/qq_20641565/article/details/76216417

 

While simulating a broadcast join today I suddenly became aware of this point, so here is a summary of caching in Spark.

 

Problem

   

In Spark the same RDD is often used in several places. With the default behavior, every action that touches that RDD recomputes its entire lineage from scratch, so the same operators run multiple times and efficiency suffers.

 

Common transformation and action operators =>

https://blog.csdn.net/u010003835/article/details/106341908
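
As a quick illustration, here is a minimal self-contained sketch (the object name and the toy data are invented for this example, not taken from the article): transformations such as map and filter only build the lineage, while each action submits its own job that walks that lineage again.

import org.apache.spark.sql.SparkSession

object LazyVsActionDemo {

  def main(args: Array[String]): Unit = {

    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("LazyVsActionDemo")
      .getOrCreate()
    val sc = spark.sparkContext

    val rdd = sc.parallelize(1 to 10)

    // Transformations: lazily build the lineage, nothing is executed yet
    val doubled = rdd.map(_ * 2)
    val even = doubled.filter(_ % 4 == 0)

    // Actions: each one submits a job and re-walks the lineage above
    println(even.count())                  // first job
    println(even.collect().mkString(","))  // second job, recomputed because nothing is cached

    spark.stop()
  }
}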

 

For example:

val rdd1 = sc.textFile("xxx")

rdd1.xxxxx.xxxx.collect

rdd1.xxx.xxx.collect

 

 

Solution


Both snippets above use the same RDD rdd1. When the program runs, sc.textFile("xxx") and everything built on top of it is executed twice. To avoid this, the result of rdd1 can be cached in memory as follows:

val rdd1 = sc.textFile("xxx")

val rdd2 = rdd1.cache()   // cache() marks the RDD for in-memory storage and returns the same RDD

rdd2.xxxxx.xxxx.collect

rdd2.xxx.xxx.collect
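
To make this concrete, here is a minimal sketch assuming a spark-shell session where sc already exists and "xxx" is still just a placeholder path: the first action reads the file and fills the cache, and later actions are served from the cached partitions.

val rdd1 = sc.textFile("xxx")
rdd1.cache()                        // marked for caching; materialized by the first action

println(rdd1.count())               // first action: reads the file and fills the cache
println(rdd1.map(_.length).sum())   // second action: reuses the cached partitions

println(rdd1.getStorageLevel)       // MEMORY_ONLY, the default for RDD.cache()
println(rdd1.toDebugString)         // lineage description, annotated with the storage level once cached

For an RDD, cache() is simply persist(StorageLevel.MEMORY_ONLY); partitions that do not fit in memory are dropped and recomputed when needed.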

 

Example

For example, the following demo:

package com.spark.test.offline.skewed_data

import org.apache.spark.SparkConf
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.types.{StructField, _}
import org.apache.spark.sql.{Row, SparkSession}

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.util.Random

/**
  * Created by szh on 2020/6/5.
  */
object JOINSkewedData2 {

  def main(args: Array[String]): Unit = {

    val sparkConf = new SparkConf

    sparkConf
      .setAppName("JOINSkewedData")
      .set("spark.sql.autoBroadcastJoinThreshold", "1048576") //1M broadcastJOIN
      //.set("spark.sql.autoBroadcastJoinThreshold", "104857600") //100M broadcastJOIN
      .set("spark.sql.shuffle.partitions", "3")

    if (args.length > 0 && args(0).equals("ide")) {
      sparkConf
        .setMaster("local[3]")
    }

    val spark = SparkSession.builder()
      .config(sparkConf)
      .getOrCreate()


    val sparkContext = spark.sparkContext
    sparkContext.setLogLevel("WARN")
    //sparkContext.setCheckpointDir("")

    // Build the large side of the join: 1,000,000 (id, name) rows,
    // where roughly 90% of the rows share id = 1 to create a skewed key
    val userArr = new ArrayBuffer[(Int, String)]()
    val nameArr = Array[String]("sun", "zhen", "hua", "kk", "cc")

    val threshold = 1000000

    for (i <- 1 to threshold) {

      val id = if (i < (threshold * 0.9)) 1 else i
      val name = nameArr(Random.nextInt(5))

      userArr += ((id, name))
    }

    // The large, skewed RDD
    val rddA = sparkContext
      .parallelize(userArr)

    //spark.sql("CACHE TABLE userA")

    //-----------------------------------------
    //---------------------------------------

    // Build the small side of the join: 100,000 (id, salary) rows
    val arrList = new ArrayBuffer[(Int, Int)]

    for (i <- 1 to (threshold * 0.1).toInt) {
      val id = i
      val salary = Random.nextInt(100)

      arrList += ((id, salary))
    }


    // The small RDD: collect it to the driver and broadcast it to every executor,
    // so the join below can be done map-side without a shuffle
    val rddB = sparkContext
      .parallelize(arrList)

    val broadData: Broadcast[Array[(Int, Int)]] = sparkContext.broadcast(rddB.collect())


    import scala.util.control._

    // Manual broadcast join: for every row of the large RDD, scan the broadcast
    // array for a matching id and emit Row(uid, name, salary) when one is found
    val resultRdd = rddA
      .mapPartitions(arr => {

        val broadVal = broadData.value
        val rowArr = new ArrayBuffer[Row]()

        while (arr.hasNext) {

          val x = arr.next()
          val loop = new Breaks
          var rRow: Row = null
          //var rRow: Option[Row] = None

          loop.breakable(
            for (tmpVal <- broadVal) {
              if (tmpVal._1 == x._1) {
                rRow = Row(tmpVal._1, x._2, tmpVal._2)
                loop.break()
              }
            }
          )
          if (rRow != null) {
            rowArr += rRow
          }
        }

        println(rowArr.size)

        rowArr.iterator
      })
    //      .filter(x => {
    //        x match {
    //          case None => false
    //          case _ => true
    //        }
    //      })


    val resultStruct = StructType(
      Array(
        StructField("uid", IntegerType, nullable = true)
        , StructField("name", StringType, nullable = true)
        , StructField("salary", IntegerType, nullable = true)
      )
    )

    spark
      .createDataFrame(resultRdd, resultStruct)
      .createOrReplaceTempView("resultB")

    val resultDF = spark
      .sql("SELECT uid, name, salary FROM resultB")

    //resultDF.checkpoint()
    // Cache the joined result so that the three actions below (foreach, count, show)
    // reuse it instead of re-running the mapPartitions join each time
    resultDF.cache()

    // A no-op action that only materializes the cache
    resultDF.foreach(x => {
      val i = 1
    })

    println(resultDF.count())

    resultDF.show()
    resultDF.explain(true)

    // Keep the application alive for 10 minutes so the Spark UI can still be inspected
    Thread.sleep(60 * 10 * 1000)

    sparkContext.stop()
  }

}

Note the following part:

    resultDF.foreach(x => {
      val i = 1
    })

    println(resultDF.count())

    resultDF.show()

foreach, count, and show are three separate action operations.
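
One rough way to see the effect is to wrap each action in a simple timer and compare the numbers with resultDF.cache() enabled and disabled. This is only a sketch; the time helper is a hypothetical utility, not part of the original demo.

def time[T](label: String)(body: => T): T = {
  val start = System.nanoTime()
  val result = body                 // run the action
  println(s"$label took ${(System.nanoTime() - start) / 1e9} s")
  result
}

// With cache(), only the first action runs the mapPartitions join;
// without it, all three actions walk the full lineage again.
time("foreach") { resultDF.foreach(x => { val i = 1 }) }   // same no-op action as in the demo
time("count")   { println(resultDF.count()) }
time("show")    { resultDF.show() }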

Without caching resultDF, the total execution time of the job looks like this (figure below):

 

 

With resultDF cached, the total execution time of the job looks like this (figure below):

Comparing the two figures, it is clear that without the cache, count recomputes the whole upstream lineage again, which adds more than 20 seconds.
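
A few related knobs are worth knowing; the sketch below assumes the same resultDF as in the demo. Dataset.cache() is persist() with the MEMORY_AND_DISK storage level (RDD.cache() uses MEMORY_ONLY), an explicit level can be passed to persist(), and unpersist() releases the memory once the data is no longer reused.

import org.apache.spark.storage.StorageLevel

resultDF.persist(StorageLevel.MEMORY_AND_DISK_SER)  // explicit storage level instead of cache()
resultDF.count()                                    // the first action materializes the cache

println(resultDF.storageLevel)                      // confirms the level; the cached data also
                                                    // appears on the "Storage" tab of the Spark UI

// ... further actions reuse the cached data ...

resultDF.unpersist()                                // free executor memory when finished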

 
