Top N Explained in Depth

This post covers:

  1. Basic Top N in practice

  2. Grouped Top N in practice

  3. Inside the sorting machinery: RangePartitioner

1. Basic Top N in practice

     Top N means sorting first; take simply grabs a few elements without any sorting.
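     A rough illustration of the difference (assuming nums is an RDD built from the sample numbers shown below; the arrays are illustrative):

nums.take(3)   // e.g. Array(1, 4, 2) -- the first three elements in partition order, unsorted
nums.top(3)    // e.g. Array(9, 7, 7) -- a true Top N: the three largest values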

 

 

 

  Create the input file (basicTopN.txt) with the following content:

 

1
4
2
5
7
3
2
7
9
1
4
5

  Going by the source code: take returns an array, not an RDD, and collect likewise gathers all of the RDD's elements into an array on the driver.

/**
 * Return an array that contains all of the elements in this RDD.
 */
def collect(): Array[T] = withScope {
  val results = sc.runJob(this, (iter: Iterator[T]) => iter.toArray)
  Array.concat(results: _*)
}

And here is take:
/**
 * Take the first num elements of the RDD. It works by first scanning one partition, and use the
 * results from that partition to estimate the number of additional partitions needed to satisfy
 * the limit.
 *
 * @note due to complications in the internal implementation, this method will raise
 * an exception if called on an RDD of `Nothing` or `Null`.
 */
def take(num: Int): Array[T] = withScope {
  if (num == 0) {
    new Array[T](0)
  } else {
    val buf = new ArrayBuffer[T]
    val totalParts = this.partitions.length
    var partsScanned = 0
    while (buf.size < num && partsScanned < totalParts) {
      // The number of partitions to try in this iteration. It is ok for this number to be
      // greater than totalParts because we actually cap it at totalParts in runJob.
      var numPartsToTry = 1
      if (partsScanned > 0) {
        // If we didn't find any rows after the previous iteration, quadruple and retry.
        // Otherwise, interpolate the number of partitions we need to try, but overestimate
        // it by 50%. We also cap the estimation in the end.
        if (buf.size == 0) {
          numPartsToTry = partsScanned * 4
        } else {
          // the left side of max is >=1 whenever partsScanned >= 2
          numPartsToTry = Math.max((1.5 * num * partsScanned / buf.size).toInt - partsScanned, 1)
          numPartsToTry = Math.min(numPartsToTry, partsScanned * 4)
        }
      }

      val left = num - buf.size
      val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
      val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p)

      res.foreach(buf ++= _.take(num - buf.size))
      partsScanned += numPartsToTry
    }

    buf.toArray
  }
}
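To make the scanning strategy concrete, here is one worked pass through the loop (numbers chosen purely for illustration): suppose take(5) is called on an RDD with 4 partitions and the first partition yields only 2 elements. On the second iteration partsScanned = 1 and buf.size = 2, so numPartsToTry = max((1.5 * 5 * 1 / 2).toInt - 1, 1) = max(3 - 1, 1) = 2, capped at partsScanned * 4 = 4; Spark therefore scans partitions 1 and 2 next, asking each for the remaining left = 3 elements. If the first partition had returned nothing at all, numPartsToTry would instead quadruple to 4.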
   So the example code looks like this:

package com.zhouls.spark.cores

import org.apache.spark.{SparkConf, SparkContext}

/**
  * Basic Top N in practice
  * Created by Administrator on 2016/10/9.
  */
object TopNBasic {
  def main(args: Array[String]) {
    val conf = new SparkConf()
    conf.setAppName("Top N Basically!").setMaster("local")
    val sc = new SparkContext(conf)
    val lines = sc.textFile("D://SoftWare//spark-1.5.2-bin-hadoop2.6//basicTopN.txt")
    val pairs = lines.map(line => (line.toInt, line)) // build key-value pairs so sortByKey can sort on the numeric key
    val sortedPairs = pairs.sortByKey(false) // sort in descending order
    val sortedData = sortedPairs.map(pair => pair._2) // map is the usual way to transform each record; here it drops the key and keeps the original content
    val top5 = sortedData.take(5) // take the top 5 elements
    top5.foreach(println)
  }
}
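As a side note, the same result can be sketched without an explicit sortByKey by using the RDD API's built-in top (or takeOrdered), which collects a bounded priority queue per partition instead of shuffling the whole data set just to pick 5 elements:

val top5Alt = lines.map(_.toInt).top(5)   // the 5 largest values, already in descending order
top5Alt.foreach(println)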

 

 

 

While we are here, one more thing worth learning: setLogLevel. Look at its source:

/** Control our logLevel. This overrides any user-defined log settings.
 * @param logLevel The desired log level as a string.
 * Valid log levels include: ALL, DEBUG, ERROR, FATAL, INFO, OFF, TRACE, WARN
 */
def setLogLevel(logLevel: String) {
  val validLevels = Seq("ALL", "DEBUG", "ERROR", "FATAL", "INFO", "OFF", "TRACE", "WARN")
  if (!validLevels.contains(logLevel)) {
    throw new IllegalArgumentException(
      s"Supplied level $logLevel did not match one of: ${validLevels.mkString(",")}")
  }
  Utils.setLogLevel(org.apache.log4j.Level.toLevel(logLevel))
}
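For example, to quiet the output of the basic Top N job, a single call right after the SparkContext is created is enough (a sketch against the example above):

val sc = new SparkContext(conf)
sc.setLogLevel("WARN")   // from here on, only WARN and above are printed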

 





setLogLevel("ALL")

The corresponding console output:

"C:\Program Files\Java\jdk1.8.0_66\bin\java" -Didea.launcher.port=7533 "-Didea.launcher.bin.path=D:\SoftWare\IntelliJ IDEA\IntelliJ IDEA Community Edition 2016.1.4\bin" -Dfile.encoding=UTF-8 -classpath "C:\Program Files\Java\jdk1.8.0_66\jre\lib\charsets.jar;C:\Program 
artitions
d size 1814.0 B, free 976.2 MB)
16/10/09 09:15:38 DEBUG AkkaRpcEnv$$anonfun$actorRef$lzycompute$1$1$$anon$1: [actor] received message AkkaMessage(UpdateBlockInfo(BlockManagerId(driver, localhost, 52833),broadcast_2_piece0,StorageLevel(false, true, false, false, 1),1814,0,0),true) from Actor[akka://sparkDriver/temp/$g]
16/10/09 09:15:38 DEBUG AkkaRpcEnv$$anonfun$actorRef$lzycompute$1$1$$anon$1: Received RPC message: AkkaMessage(UpdateBlockInfo(BlockManagerId(driver, localhost, 52833),broadcast_2_piece0,StorageLevel(false, true, false, false, 1),1814,0,0),true)
16/10/09 09:15:38 INFO BlockManagerInfo: Added broadcast_2_piece0 in memory on localhost:52833 (size: 1814.0 B, free: 976.3 MB)
16/10/09 09:15:38 DEBUG AkkaRpcEnv$$anonfun$actorRef$lzycompute$1$1$$anon$1: [actor] handled message (3.09051 ms) AkkaMessage(UpdateBlockInfo(BlockManagerId(driver, localhost, 52833),broadcast_2_piece0,StorageLevel(false, true, false, false, 1),1814,0,0),true) from Actor[akka://sparkDriver/temp/$g]
16/10/09 09:15:38 DEBUG BlockManagerMaster: Updated info of block broadcast_2_piece0
16/10/09 09:15:38 DEBUG BlockManager: Told master about block broadcast_2_piece0
16/10/09 09:15:38 DEBUG BlockManager: Put block broadcast_2_piece0 locally took 8 ms
16/10/09 09:15:38 DEBUG BlockManager: Putting block broadcast_2_piece0 without replication took 9 ms
16/10/09 09:15:38 INFO SparkContext: Created broadcast 2 from broadcast at DAGScheduler.scala:861
 bytes)
16/10/09 09:15:39 TRACE DAGScheduler: failed: Set()
16/10/09 09:15:39 INFO DAGScheduler: Job 0 finished: take at TopNBasic.scala:20, took 1.022280 s
9
7
7
5
5
16/10/09 09:15:39 INFO SparkContext: Invoking stop() from shutdown hook
age (5.094032 ms) AkkaMessage(StopCoordinator,false) from Actor[akka://sparkDriver/deadLetters]
16/10/09 09:15:39 INFO ShutdownHookManager: Deleting directory C:\Users\Administrator\AppData\Local\Temp\spark-3656d24c-bfdb-4def-b751-8d7fc84150cb

Process finished with exit code 0

 

 

setLogLevel("DEBUG")

The corresponding console output:

"C:\Program Files\Java\jdk1.8.0_66\bin\java" -Didea.launcher.port=7534 "-Didea.launcher.bin.path=D:\SoftWare\IntelliJ IDEA\IntelliJ IDEA Community Edition 2016.1.4\bin" -Dfile.encoding=UTF-8 -classpath "C:\Program Files\Java\jdk1.8.0_66\jre\lib\charsets.jar;C:\Program Files\Java\jdk1.8.0_66\jre\lib\deploy.jar;C:\Program Files\Java\jdk1.8.0_66\jre\lib\ext\access-bridge-64.jar;C:\Program Files\Java\jdk1.8.0_66\jre\lib\ext\cldrdata.jar;C:\Program Files\Java\jdk1.8.0_66\jre\lib\ext\dnsns.jar;C:\Program Files\Java\jdk1.8.0_66\jre\lib\ext\jaccess.jar;C:\Program Files\Java\jdk1.8.0_66\jre\lib\ext\jfxrt.jar;C:\Program Files\Java\jdk1.8.0_66\jre\lib\ext\localedata.jar;C:\Program Files\Java\jdk1.8.0_66\jre\lib\ext\nashorn.jar;C:\Program fun$28
16/10/09 09:18:05 DEBUG AkkaRpcEnv$$anonfun$actorRef$lzycompute$1$1$$anon$1: [actor] handled message (2.022709 ms) AkkaMessage(StatusUpdate(1,FINISHED,java.nio.HeapByteBuffer[pos=0 lim=1185 cap=1185]),false) from Actor[akka://sparkDriver/deadLetters]
16/10/09 09:18:05 INFO TaskSetManager: Finished task 0.0 in stage 1.0 (TID 1) in 153 ms on localhost (1/1)
16/10/09 09:18:05 INFO TaskSchedulerImpl: Removed TaskSet 1.0, whose tasks have all completed, from pool 
16/10/09 09:18:05 INFO DAGScheduler: ResultStage 1 (take at TopNBasic.scala:20) finished in 0.163 s
16/10/09 09:18:05 DEBUG DAGScheduler: After removal of stage 1, remaining stages = 1
16/10/09 09:18:05 DEBUG DAGScheduler: After removal of stage 0, remaining stages = 0
16/10/09 09:18:05 INFO DAGScheduler: Job 0 finished: take at TopNBasic.scala:20, took 0.985550 s
9
7
7
5
5
16/10/09 09:18:05 INFO SparkContext: Invoking stop() from shutdown hook
16/10/09 09:18:05 INFO SparkUI: Stopped Spark web UI at http://192.168.56.1:4040
16/10/09 09:18:05 INFO ShutdownHookManager: Deleting directory C:\Users\Administrator\AppData\Local\Temp\spark-c9f238f3-9210-4f3a-a248-11f6f610163e

Process finished with exit code 0

 

 

setLogLevel("ERROR")

The corresponding console output:

"C:\Program Files\Java\jdk1.8.0_66\bin\java" -Didea.launcher.port=7535 "-Didea.launcher.bin.path=D:\SoftWare\IntelliJ IDEA\IntelliJ IDEA Community Edition 2016.1.4\bin" -Dfile.encoding=UTF-8 -classpath "C:\Program Files\Java\jdk1.8.0_66\jre\lib\charsets.jar;C:\Program Files\Java\jdk1.8.0_66\jre\lib\deploy.jar;C:\Program Files\Java\jdk1.8.0_66\jre\lib\ext\access-bridge-64.jar;C:\Program 
16/10/09 09:18:43 INFO BlockManagerMasterEndpoint: Registering block manager localhost:52966 with 976.3 MB RAM, BlockManagerId(driver, localhost, 52966)
16/10/09 09:18:43 INFO BlockManagerMaster: Registered BlockManager
9
7
7
5
5
16/10/09 09:18:50 WARN QueuedThreadPool: 3 threads could not be stopped

Process finished with exit code 0




setLogLevel("FATAL")

The corresponding console output:

"C:\Program Files\Java\jdk1.8.0_66\bin\java" -Didea.launcher.port=7536 "-Didea.launcher.bin.path=D:\SoftWare\IntelliJ IDEA\IntelliJ IDEA Community Edition 2016.1.4\bin" -Dfile.encoding=UTF-8 -classpath "C:\Program Files\Java\jdk1.8.0_66\jre\lib\charsets.jar;C:\Program Files\Java\jdk1.8.0_66\jre\lib\deploy.jar;C:\Program Files\Java\jdk1.8.0_66\jre\lib\ext\access-bridge-64.jar;C:\Program 
16/10/09 09:20:17 INFO BlockManagerMasterEndpoint: Registering block manager localhost:53014 with 976.3 MB RAM, BlockManagerId(driver, localhost, 53014)
16/10/09 09:20:17 INFO BlockManagerMaster: Registered BlockManager
9
7
7
5
5

Process finished with exit code 0

 

 

 

setLogLevel("INFO")

The corresponding console output:

"C:\Program Files\Java\jdk1.8.0_66\bin\java" -Didea.launcher.port=7537 "-Didea.launcher.bin.path=D:\SoftWare\IntelliJ IDEA\IntelliJ IDEA Community Edition 2016.1.4\bin" -Dfile.encoding=UTF-8 -classpath "C:\Program Files\Java\jdk1.8.0_66\jre\lib\charsets.jar;C:\Program Files\Java\jdk1.8.0_66\jre\lib\deploy.jar;C:\Program Files\Java\jdk1.8.0_66\jre\lib\ext\access-bridge-64.jar;C:\Program 
16/10/09 09:21:17 INFO DAGScheduler: Job 0 finished: take at TopNBasic.scala:20, took 1.085930 s
9
7
7
5
5
16/10/09 09:21:17 INFO SparkContext: Invoking stop() from shutdown hook
16/10/09 09:21:17 INFO SparkUI: Stopped Spark web UI at http://192.168.56.1:4040
16/10/09 09:21:17 INFO ShutdownHookManager: Deleting directory C:\Users\Administrator\AppData\Local\Temp\spark-de03b369-fec4-4785-abec-563c502d0bd7

Process finished with exit code 0

 

 

 

setLogLevel("OFF")

The corresponding console output:

"C:\Program Files\Java\jdk1.8.0_66\bin\java" -Didea.launcher.port=7538 "-Didea.launcher.bin.path=D:\SoftWare\IntelliJ IDEA\IntelliJ IDEA Community Edition 2016.1.4\bin" -Dfile.encoding=UTF-8 -classpath "C:\Program Files\Java\jdk1.8.0_66\jre\lib\charsets.jar;C:\Program 
16/10/09 09:22:10 INFO BlockManagerMasterEndpoint: Registering block manager localhost:53098 with 976.3 MB RAM, BlockManagerId(driver, localhost, 53098)
16/10/09 09:22:10 INFO BlockManagerMaster: Registered BlockManager
9
7
7
5
5

Process finished with exit code 0

 

 

setLogLevel("TRACE")

The corresponding console output:

"C:\Program Files\Java\jdk1.8.0_66\bin\java" -Didea.launcher.port=7539 "-Didea.launcher.bin.path=D:\SoftWare\IntelliJ IDEA\IntelliJ IDEA Community Edition 2016.1.4\bin" -Dfile.encoding=UTF-8 -classpath "C:\Program Files\Java\jdk1.8.0_66\jre\lib\charsets.jar;C:\Program Files\Java\jdk1.8.0_66\jre\lib\deploy.jar;C:\Program Files\Java\jdk1.8.0_66\jre\lib\ext\access-bridge-64.jar;C:\Program 
16/10/09 09:23:15 TRACE DAGScheduler: running: Set()
16/10/09 09:23:15 TRACE DAGScheduler: waiting: Set()
16/10/09 09:23:15 TRACE DAGScheduler: failed: Set()
16/10/09 09:23:15 INFO DAGScheduler: Job 0 finished: take at TopNBasic.scala:20, took 0.985096 s
9
7
7
5
5
16/10/09 09:23:15 INFO SparkContext: Invoking stop() from shutdown hook
16/10/09 09:23:15 INFO SparkUI: Stopped Spark web UI at http://192.168.56.1:4040
16/10/09 09:23:15 INFO ShutdownHookManager: Deleting directory C:\Users\Administrator\AppData\Local\Temp\spark-d3604805-b6e2-4873-a8aa-10cabda4f329

Process finished with exit code 0

 

 

 

setLogLevel("WARN")

The corresponding console output:

"C:\Program Files\Java\jdk1.8.0_66\bin\java" -Didea.launcher.port=7532 "-Didea.launcher.bin.path=D:\SoftWare\IntelliJ IDEA\IntelliJ IDEA Community Edition 2016.1.4\bin" -Dfile.encoding=UTF-8 -classpath "C:\Program Files\Java\jdk1.8.0_66\jre\lib\charsets.jar;C:\Program fe80:0:0:0:0:5efe:c0a8:bf02%net11, but we couldn't find any external IP address!
9
7
7
5
5

Process finished with exit code 0

To sum up: that concludes the basic Top N example.





2. Grouped Top N in practice

Let's start with a Java implementation.

The input data (groupTopN.txt):

Spark 100
Hadoop 65
Spark 99
Hadoop 61
Spark 195
Hadoop 60
Spark 98
Hadoop 69
Spark 91
Hadoop 64
Spark 89
Hadoop 98
Spark 88 
Hadoop 99
Spark 68
Hadoop 60
Spark 79
Hadoop 97
Spark 69
Hadoop 96

The Java code:
package com.zhouls.spark.SparkApps.cores;

import java.util.Arrays;
import java.util.Iterator;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.api.java.function.VoidFunction;
import scala.Tuple2;

public class TopNGroup {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("TopNGroup").setMaster("local");
        JavaSparkContext sc = new JavaSparkContext(conf); // under the hood this wraps Scala's SparkContext

        JavaRDD<String> lines = sc.textFile("D://SoftWare//spark-1.5.2-bin-hadoop2.6//groupTopN.txt");

        JavaPairRDD<String, Integer> pairs = lines.mapToPair(new PairFunction<String, String, Integer>() {
            private static final long serialVersionUID = 1L;

            @Override
            public Tuple2<String, Integer> call(String line) throws Exception {
                String[] splitedLine = line.split(" ");
                System.out.println(splitedLine[0]);
                return new Tuple2<String, Integer>(splitedLine[0], Integer.valueOf(splitedLine[1]));
            }
        });

        JavaPairRDD<String, Iterable<Integer>> groupedPairs = pairs.groupByKey();

        JavaPairRDD<String, Iterable<Integer>> top5 = groupedPairs.mapToPair(
                new PairFunction<Tuple2<String, Iterable<Integer>>, String, Iterable<Integer>>() {
            private static final long serialVersionUID = 1L;

            @Override
            public Tuple2<String, Iterable<Integer>> call(Tuple2<String, Iterable<Integer>> groupedData)
                    throws Exception {
                Integer[] top5 = new Integer[5]; // fixed-size buffer holding the 5 largest scores seen so far
                String groupedKey = groupedData._1;
                Iterator<Integer> groupedValue = groupedData._2.iterator();

                while (groupedValue.hasNext()) {
                    Integer value = groupedValue.next();
                    for (int i = 0; i < 5; i++) {
                        if (top5[i] == null) {
                            top5[i] = value;
                            break;
                        } else if (value > top5[i]) {
                            // shift the smaller entries down one slot and insert the new value
                            for (int j = 4; j > i; j--) {
                                top5[j] = top5[j - 1];
                            }
                            top5[i] = value;
                            break;
                        }
                    }
                }

                return new Tuple2<String, Iterable<Integer>>(groupedKey, Arrays.asList(top5));
            }
        });

        // print the Top N of each group
        top5.foreach(new VoidFunction<Tuple2<String, Iterable<Integer>>>() {
            @Override
            public void call(Tuple2<String, Iterable<Integer>> topped) throws Exception {
                System.out.println("Group key :" + topped._1); // the group key
                Iterator<Integer> toppedValue = topped._2.iterator(); // the group's values
                while (toppedValue.hasNext()) { // print each group's Top N
                    Integer value = toppedValue.next();
                    System.out.println(value);
                }
                System.out.println("**************************************************");
            }
        });
    }
}
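For comparison, here is a Scala sketch of the same grouped Top N (TopNGroupScala is a hypothetical name; the input file is the same). Instead of the hand-rolled insertion loop it simply sorts each group's values in memory and keeps the first 5, which is fine as long as each group is small:

package com.zhouls.spark.cores

import org.apache.spark.{SparkConf, SparkContext}

// a minimal Scala sketch of the grouped Top N example
object TopNGroupScala {
  def main(args: Array[String]) {
    val conf = new SparkConf().setAppName("Top N Group (Scala)").setMaster("local")
    val sc = new SparkContext(conf)
    val lines = sc.textFile("D://SoftWare//spark-1.5.2-bin-hadoop2.6//groupTopN.txt")
    val pairs = lines.map { line =>
      val splited = line.split(" ")
      (splited(0), splited(1).toInt)                      // (name, score)
    }
    val top5 = pairs.groupByKey().map { case (key, values) =>
      (key, values.toList.sortWith(_ > _).take(5))        // sort each group's scores descending, keep 5
    }
    top5.collect().foreach { case (key, values) =>
      println("Group key :" + key)
      values.foreach(println)
      println("**************************************************")
    }
    sc.stop()
  }
}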
 Thanks to the following blogger:

http://www.it610.com/article/5193051.htm

  If groupTopN.txt instead contains:

Spark 100
Hadoop 62
Flink 77
Kafka 91
Hadoop 93
Spark 78
Hadoop 69
Spark 98
Hadoop 62
Spark 99
Hadoop 61
Spark 70
Hadoop 75
Spark 88
Hadoop 68
Spark 90
Hadoop 61

 

 

 

  then the grouped Top N output changes accordingly.

 

 

 

 Summary of the grouped Top N example:

  1. Read each line of input: JavaRDD<String> lines.
  2. Build K,V pairs: JavaPairRDD<String, Integer> pairs — the input is one line, the output key is the name and the value is the score.
  3. Group by name with groupByKey: JavaPairRDD<String, Iterable<Integer>> groupedPairs = pairs.groupByKey();
  4. Sort within each group. The input is the grouped data, where the key is the group name and the value is the collection of scores; the output key is the group name and the output value is the sorted collection of scores, keeping 5 values:

    JavaPairRDD<String, Iterable<Integer>> top5 = groupedPairs.mapToPair(new
  PairFunction<Tuple2<String, Iterable<Integer>>, String, Iterable<Integer>>() { ... });   (see the full code above)

3. Inside the sorting machinery: RangePartitioner

/**
 * Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling
 * `collect` or `save` on the resulting RDD will return or output an ordered list of records
 * (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in
 * order of the keys).
 */
// TODO: this currently doesn't work on P other than Tuple2!
def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.length)
    : RDD[(K, V)] = self.withScope
{
  val part = new RangePartitioner(numPartitions, self, ascending)
  new ShuffledRDD[K, V, V](self, part)
    .setKeyOrdering(if (ascending) ordering else ordering.reverse)
}
RangePartitioner splits the RDD's data into different ranges; the key property is that the ranges themselves are ordered.
Besides being the foundation of a globally ordered result, RangePartitioner also does its best to keep the amount of data in each partition balanced!
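As a rough sketch of using it directly (outside of sortByKey), a RangePartitioner can be constructed for any pair RDD with an ordered key and applied with partitionBy; here pairs is assumed to be the (Int, String) RDD from the basic example above:

import org.apache.spark.RangePartitioner

// a minimal sketch: 3 range partitions whose boundaries are estimated by sampling `pairs`
val rangePart = new RangePartitioner(3, pairs, ascending = true)
val ranged = pairs.partitionBy(rangePart)   // keys now land in ordered, roughly equal ranges
println(ranged.partitioner)                 // Some(RangePartitioner@...)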

 

 

 


 A Google interview classic: how do you sort within a data set whose size is not known up front?

Two key ingredients of the sorting machinery:

1. Binary search, which places each key into its corresponding partition.

   Before learning binary search, the most common approach is to scan the array and compare against every element, which takes O(n). Binary search does better, with O(log n) lookups. For example, to find the element 6 in the array {1, 2, 3, 4, 5, 6, 7, 8, 9}:
    1. Look at the middle element, 5. Since 5 < 6, 6 must lie to the right of 5, so search within {6, 7, 8, 9}.
    2. The middle of {6, 7, 8, 9} is 7. Since 7 > 6, 6 must lie to the left of 7; only 6 remains, so it has been found.

 

2. Reservoir sampling (suited to data sets too large to fit in memory). Why the multiplication by 3 below: an RDD's partitions may be skewed. sampleSize is the desired total sample size, but some partitions may contribute fewer than sampleSize/PartitionNumber items; multiplying by 3 lets the other partitions over-sample so that the total reaches or exceeds sampleSize.
     // This is the sample size we need to have roughly balanced output partitions, capped at 1M.
      val sampleSize = math.min(20.0 * partitions, 1e6)

      // Assume the input partitions are roughly balanced and over-sample a little bit.
      val sampleSizePerPartition = math.ceil(3.0 * sampleSize / rdd.partitions.size).toInt
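Plugging in numbers: with partitions = 6, sampleSize = min(20.0 * 6, 1e6) = 120; if the RDD itself also has 6 partitions, sampleSizePerPartition = ceil(3.0 * 120 / 6) = 60, i.e. each partition is asked for up to 60 samples even though 20 per partition would be enough on average, precisely to tolerate skew.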

 

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark

import java.io.{IOException, ObjectInputStream, ObjectOutputStream}

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.reflect.{ClassTag, classTag}
import scala.util.hashing.byteswap32

import org.apache.spark.rdd.{PartitionPruningRDD, RDD}
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.util.{CollectionsUtils, Utils}
import org.apache.spark.util.random.{XORShiftRandom, SamplingUtils}

/**
 * An object that defines how the elements in a key-value pair RDD are partitioned by key.
 * Maps each key to a partition ID, from 0 to `numPartitions - 1`.
 */
abstract class Partitioner extends Serializable {
  def numPartitions: Int
  def getPartition(key: Any): Int
}

object Partitioner {
  /**
   * Choose a partitioner to use for a cogroup-like operation between a number of RDDs.
   *
   * If any of the RDDs already has a partitioner, choose that one.
   *
   * Otherwise, we use a default HashPartitioner. For the number of partitions, if
   * spark.default.parallelism is set, then we'll use the value from SparkContext
   * defaultParallelism, otherwise we'll use the max number of upstream partitions.
   *
   * Unless spark.default.parallelism is set, the number of partitions will be the
   * same as the number of partitions in the largest upstream RDD, as this should
   * be least likely to cause out-of-memory errors.
   *
   * We use two method parameters (rdd, others) to enforce callers passing at least 1 RDD.
   */
  def defaultPartitioner(rdd: RDD[_], others: RDD[_]*): Partitioner = {
    val bySize = (Seq(rdd) ++ others).sortBy(_.partitions.size).reverse
    for (r <- bySize if r.partitioner.isDefined && r.partitioner.get.numPartitions > 0) {
      return r.partitioner.get
    }
    if (rdd.context.conf.contains("spark.default.parallelism")) {
      new HashPartitioner(rdd.context.defaultParallelism)
    } else {
      new HashPartitioner(bySize.head.partitions.size)
    }
  }
}

/**
 * A [[org.apache.spark.Partitioner]] that implements hash-based partitioning using
 * Java's `Object.hashCode`.
 *
 * Java arrays have hashCodes that are based on the arrays' identities rather than their contents,
 * so attempting to partition an RDD[Array[_]] or RDD[(Array[_], _)] using a HashPartitioner will
 * produce an unexpected or incorrect result.
 */
class HashPartitioner(partitions: Int) extends Partitioner {
  require(partitions >= 0, s"Number of partitions ($partitions) cannot be negative.")

  def numPartitions: Int = partitions

  def getPartition(key: Any): Int = key match {
    case null => 0
    case _ => Utils.nonNegativeMod(key.hashCode, numPartitions)
  }

  override def equals(other: Any): Boolean = other match {
    case h: HashPartitioner =>
      h.numPartitions == numPartitions
    case _ =>
      false
  }

  override def hashCode: Int = numPartitions
}

/**
 * A [[org.apache.spark.Partitioner]] that partitions sortable records by range into roughly
 * equal ranges. The ranges are determined by sampling the content of the RDD passed in.
 *
 * Note that the actual number of partitions created by the RangePartitioner might not be the same
 * as the `partitions` parameter, in the case where the number of sampled records is less than
 * the value of `partitions`.
 */
class RangePartitioner[K : Ordering : ClassTag, V](
    @transient partitions: Int,
    @transient rdd: RDD[_ <: Product2[K, V]],
    private var ascending: Boolean = true)
  extends Partitioner {

  // We allow partitions = 0, which happens when sorting an empty RDD under the default settings.
  require(partitions >= 0, s"Number of partitions cannot be negative but found $partitions.")

  private var ordering = implicitly[Ordering[K]]

  // An array of upper bounds for the first (partitions - 1) partitions
  private var rangeBounds: Array[K] = {
    if (partitions <= 1) {
      Array.empty
    } else {
      // This is the sample size we need to have roughly balanced output partitions, capped at 1M.
      val sampleSize = math.min(20.0 * partitions, 1e6)
      // Assume the input partitions are roughly balanced and over-sample a little bit.
      val sampleSizePerPartition = math.ceil(3.0 * sampleSize / rdd.partitions.size).toInt
      val (numItems, sketched) = RangePartitioner.sketch(rdd.map(_._1), sampleSizePerPartition)
      if (numItems == 0L) {
        Array.empty
      } else {
        // If a partition contains much more than the average number of items, we re-sample from it
        // to ensure that enough items are collected from that partition.
        val fraction = math.min(sampleSize / math.max(numItems, 1L), 1.0)
        val candidates = ArrayBuffer.empty[(K, Float)]
        val imbalancedPartitions = mutable.Set.empty[Int]
        sketched.foreach { case (idx, n, sample) =>
          if (fraction * n > sampleSizePerPartition) {
            imbalancedPartitions += idx
          } else {
            // The weight is 1 over the sampling probability.
            val weight = (n.toDouble / sample.size).toFloat
            for (key <- sample) {
              candidates += ((key, weight))
            }
          }
        }
        if (imbalancedPartitions.nonEmpty) {
          // Re-sample imbalanced partitions with the desired sampling probability.
          val imbalanced = new PartitionPruningRDD(rdd.map(_._1), imbalancedPartitions.contains)
          val seed = byteswap32(-rdd.id - 1)
          val reSampled = imbalanced.sample(withReplacement = false, fraction, seed).collect()
          val weight = (1.0 / fraction).toFloat
          candidates ++= reSampled.map(x => (x, weight))
        }
        RangePartitioner.determineBounds(candidates, partitions)
      }
    }
  }

  def numPartitions: Int = rangeBounds.length + 1

  private var binarySearch: ((Array[K], K) => Int) = CollectionsUtils.makeBinarySearch[K]

  def getPartition(key: Any): Int = {
    val k = key.asInstanceOf[K]
    var partition = 0
    if (rangeBounds.length <= 128) {
      // If we have less than 128 partitions naive search
      while (partition < rangeBounds.length && ordering.gt(k, rangeBounds(partition))) {
        partition += 1
      }
    } else {
      // Determine which binary search method to use only once.
      partition = binarySearch(rangeBounds, k)
      // binarySearch either returns the match location or -[insertion point]-1
      if (partition < 0) {
        partition = -partition-1
      }
      if (partition > rangeBounds.length) {
        partition = rangeBounds.length
      }
    }
    if (ascending) {
      partition
    } else {
      rangeBounds.length - partition
    }
  }

  override def equals(other: Any): Boolean = other match {
    case r: RangePartitioner[_, _] =>
      r.rangeBounds.sameElements(rangeBounds) && r.ascending == ascending
    case _ =>
      false
  }

  override def hashCode(): Int = {
    val prime = 31
    var result = 1
    var i = 0
    while (i < rangeBounds.length) {
      result = prime * result + rangeBounds(i).hashCode
      i += 1
    }
    result = prime * result + ascending.hashCode
    result
  }

  @throws(classOf[IOException])
  private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
    val sfactory = SparkEnv.get.serializer
    sfactory match {
      case js: JavaSerializer => out.defaultWriteObject()
      case _ =>
        out.writeBoolean(ascending)
        out.writeObject(ordering)
        out.writeObject(binarySearch)

        val ser = sfactory.newInstance()
        Utils.serializeViaNestedStream(out, ser) { stream =>
          stream.writeObject(scala.reflect.classTag[Array[K]])
          stream.writeObject(rangeBounds)
        }
    }
  }

  @throws(classOf[IOException])
  private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
    val sfactory = SparkEnv.get.serializer
    sfactory match {
      case js: JavaSerializer => in.defaultReadObject()
      case _ =>
        ascending = in.readBoolean()
        ordering = in.readObject().asInstanceOf[Ordering[K]]
        binarySearch = in.readObject().asInstanceOf[(Array[K], K) => Int]

        val ser = sfactory.newInstance()
        Utils.deserializeViaNestedStream(in, ser) { ds =>
          implicit val classTag = ds.readObject[ClassTag[Array[K]]]()
          rangeBounds = ds.readObject[Array[K]]()
        }
    }
  }
}

private[spark] object RangePartitioner {

  /**
   * Sketches the input RDD via reservoir sampling on each partition.
   *
   * @param rdd the input RDD to sketch
   * @param sampleSizePerPartition max sample size per partition
   * @return (total number of items, an array of (partitionId, number of items, sample))
   */
  def sketch[K : ClassTag](
      rdd: RDD[K],
      sampleSizePerPartition: Int): (Long, Array[(Int, Int, Array[K])]) = {
    val shift = rdd.id
    // val classTagK = classTag[K] // to avoid serializing the entire partitioner object
    val sketched = rdd.mapPartitionsWithIndex { (idx, iter) =>
      val seed = byteswap32(idx ^ (shift << 16))
      val (sample, n) = SamplingUtils.reservoirSampleAndCount(
        iter, sampleSizePerPartition, seed)
      Iterator((idx, n, sample))
    }.collect()
    val numItems = sketched.map(_._2.toLong).sum
    (numItems, sketched)
  }

  /**
   * Determines the bounds for range partitioning from candidates with weights indicating how many
   * items each represents. Usually this is 1 over the probability used to sample this candidate.
   *
   * @param candidates unordered candidates with weights
   * @param partitions number of partitions
   * @return selected bounds
   */
  def determineBounds[K : Ordering : ClassTag](
      candidates: ArrayBuffer[(K, Float)],
      partitions: Int): Array[K] = {
    val ordering = implicitly[Ordering[K]]
    val ordered = candidates.sortBy(_._1)
    val numCandidates = ordered.size
    val sumWeights = ordered.map(_._2.toDouble).sum
    val step = sumWeights / partitions
    var cumWeight = 0.0
    var target = step
    val bounds = ArrayBuffer.empty[K]
    var i = 0
    var j = 0
    var previousBound = Option.empty[K]
    while ((i < numCandidates) && (j < partitions - 1)) {
      val (key, weight) = ordered(i)
      cumWeight += weight
      if (cumWeight > target) {
        // Skip duplicate values.
        if (previousBound.isEmpty || ordering.gt(key, previousBound.get)) {
          bounds += key
          target += step
          j += 1
          previousBound = Some(key)
        }
      }
      i += 1
    }
    bounds.toArray
  }
}
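A tiny worked example of determineBounds (hypothetical keys, all with weight 1): candidates (a,1) (b,1) (c,1) (d,1) (e,1) (f,1) and partitions = 3 give sumWeights = 6 and step = 2. Scanning in key order, the cumulative weight first exceeds the target 2 at key c (cumWeight = 3), so c becomes a bound and the target moves to 4; it next exceeds 4 at key e, so e becomes the second bound, and with partitions - 1 = 2 bounds collected the loop stops. The resulting rangeBounds Array(c, e) means keys <= c go to partition 0, keys in (c, e] go to partition 1, and keys > e go to partition 2.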

 




The sketch method's source:

/**
 * Sketches the input RDD via reservoir sampling on each partition.
 *
 * @param rdd the input RDD to sketch
 * @param sampleSizePerPartition max sample size per partition
 * @return (total number of items, an array of (partitionId, number of items, sample))
 */
def sketch[K : ClassTag](
    rdd: RDD[K],
    sampleSizePerPartition: Int): (Long, Array[(Int, Int, Array[K])]) = {
  val shift = rdd.id
  // val classTagK = classTag[K] // to avoid serializing the entire partitioner object
  val sketched = rdd.mapPartitionsWithIndex { (idx, iter) =>
    val seed = byteswap32(idx ^ (shift << 16))
    val (sample, n) = SamplingUtils.reservoirSampleAndCount(
      iter, sampleSizePerPartition, seed)
    Iterator((idx, n, sample))
  }.collect()
  val numItems = sketched.map(_._2.toLong).sum
  (numItems, sketched)
}

 




 

   The reservoirSampleAndCount source:

 

/**
 * Reservoir sampling implementation that also returns the input size.
 *
 * @param input input size
 * @param k reservoir size
 * @param seed random seed
 * @return (samples, input size)
 */
def reservoirSampleAndCount[T: ClassTag](
    input: Iterator[T],
    k: Int,
    seed: Long = Random.nextLong())
  : (Array[T], Int) = {
  val reservoir = new Array[T](k)
  // Put the first k elements in the reservoir.
  var i = 0
  while (i < k && input.hasNext) {
    val item = input.next()
    reservoir(i) = item
    i += 1
  }

  // If we have consumed all the elements, return them. Otherwise do the replacement.
  if (i < k) {
    // If input size < k, trim the array to return only an array of input size.
    val trimReservoir = new Array[T](i)
    System.arraycopy(reservoir, 0, trimReservoir, 0, i)
    (trimReservoir, i)
  } else {
    // If input size > k, continue the sampling process.
    val rand = new XORShiftRandom(seed)
    while (input.hasNext) {
      val item = input.next()
      val replacementIndex = rand.nextInt(i)
      if (replacementIndex < k) {
        reservoir(replacementIndex) = item
      }
      i += 1
    }
    (reservoir, i)
  }
}
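The same reservoir-sampling idea, stripped of Spark internals, can be sketched in a few lines of plain Scala (scala.util.Random stands in for XORShiftRandom here; this is a textbook variant, not Spark's exact code):

import scala.util.Random

// minimal sketch: keep a uniform random sample of size k from an iterator of unknown length
def reservoirSample[T: scala.reflect.ClassTag](input: Iterator[T], k: Int, seed: Long = 42L): (Array[T], Int) = {
  val rand = new Random(seed)
  val reservoir = new Array[T](k)
  var i = 0
  while (input.hasNext) {
    val item = input.next()
    if (i < k) {
      reservoir(i) = item                 // fill the reservoir with the first k elements
    } else {
      val j = rand.nextInt(i + 1)         // keep the new item with probability k/(i+1)
      if (j < k) reservoir(j) = item
    }
    i += 1
  }
  (if (i < k) reservoir.take(i) else reservoir, i)   // also return the total count, like Spark does
}

// usage: sample 5 numbers out of 1..1000
println(reservoirSample((1 to 1000).iterator, 5)._1.mkString(", "))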

 




 

 


The getPartition source:

def getPartition(key: Any): Int = {
  val k = key.asInstanceOf[K]
  var partition = 0
  if (rangeBounds.length <= 128) {
    // If we have less than 128 partitions naive search
    while (partition < rangeBounds.length && ordering.gt(k, rangeBounds(partition))) {
      partition += 1
    }
  } else {
    // Determine which binary search method to use only once.
    partition = binarySearch(rangeBounds, k)
    // binarySearch either returns the match location or -[insertion point]-1
    if (partition < 0) {
      partition = -partition-1
    }
    if (partition > rangeBounds.length) {
      partition = rangeBounds.length
    }
  }
  if (ascending) {
    partition
  } else {
    rangeBounds.length - partition
  }
}
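To see how rangeBounds drives this, a small hand-computed sketch (hypothetical bounds): with rangeBounds = Array(10, 20) and ascending = true, a key of 7 stops at the first bound (7 <= 10) and lands in partition 0, a key of 15 passes the first bound and lands in partition 1, and a key of 99 passes both bounds and lands in partition 2. With ascending = false the result is mirrored as rangeBounds.length - partition, so the key 99 would map to partition 0.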

 




 

    As explained above, binary search is what places each key into its corresponding partition: when there are at most 128 bounds, getPartition uses the naive O(n) linear scan, and beyond that it switches to an O(log n) binary search over rangeBounds.




Binary search determines which partition a given key belongs to; with the bounds established, RangePartitioner can route every key to its range.
For more, see http://www.it610.com/article/5193051.htm