PySpark RDD: partitionBy with a custom partitionFunc

partitionBy(self, numPartitions, partitionFunc=portable_hash) takes two main parameters. The first is numPartitions, the number of partitions, which everyone already knows.

The other is partitionFunc, the partitioning function, which defaults to the hash function portable_hash. We can, of course, also define our own:

data = sc.parallelize(['1', '2', '3']).map(lambda x: (x, x))
wp = data.partitionBy(data.count(), lambda k: int(k))
print wp.map(lambda t: t[0]).glom().collect()

The custom function here is the simplest possible one, lambda k: int(k), which partitions each key by its own integer value. Since partitionBy applies partitionFunc(k) % numPartitions (see the source below), key '3' lands in partition 0, '1' in partition 1, and '2' in partition 2. We can also define other, more elaborate partition functions as needed; a sketch follows.
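
For instance, a partitioner could send every key that starts with the same character to the same partition. This is a minimal sketch, assuming non-empty string keys; first_char_partitioner is just an illustrative name, not part of the PySpark API:

# Hypothetical partitioner: group keys by their first character
# (assumes non-empty string keys).
def first_char_partitioner(key):
    return ord(key[0])

pairs = sc.parallelize([('apple', 1), ('avocado', 2), ('banana', 3), ('cherry', 4)])
parted = pairs.partitionBy(3, first_char_partitioner)
print parted.glom().collect()

Here 'apple' and 'avocado' end up in the same partition because they share the same first character; the returned value does not need to fall within [0, numPartitions), since it is taken modulo the partition count.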

The source code of partitionBy (from the PySpark RDD implementation) is given below:
def partitionBy(self, numPartitions, partitionFunc=portable_hash):
    """
    Return a copy of the RDD partitioned using the specified partitioner.

    >>> pairs = sc.parallelize([1, 2, 3, 4, 2, 4, 1]).map(lambda x: (x, x))
    >>> sets = pairs.partitionBy(2).glom().collect()
    >>> set(sets[0]).intersection(set(sets[1]))
    set([])
    """
    if numPartitions is None:
        numPartitions = self._defaultReducePartitions()

    # Transferring O(n) objects to Java is too expensive.
    # Instead, we'll form the hash buckets in Python,
    # transferring O(numPartitions) objects to Java.
    # Each object is a (splitNumber, [objects]) pair.
    # In order to avoid too huge objects, the objects are
    # grouped into chunks.
    outputSerializer = self.ctx._unbatched_serializer

    limit = (_parse_memory(self.ctx._conf.get(
        "spark.python.worker.memory", "512m")) / 2)

    def add_shuffle_key(split, iterator):

        buckets = defaultdict(list)
        c, batch = 0, min(10 * numPartitions, 1000)

        for (k, v) in iterator:
            buckets[partitionFunc(k) % numPartitions].append((k, v))
            c += 1

            # check used memory and avg size of chunk of objects
            if (c % 1000 == 0 and get_used_memory() > limit
                    or c > batch):
                n, size = len(buckets), 0
                for split in buckets.keys():
                    yield pack_long(split)
                    d = outputSerializer.dumps(buckets[split])
                    del buckets[split]
                    yield d
                    size += len(d)

                avg = (size / n) >> 20
                # let 1M < avg < 10M
                if avg < 1:
                    batch *= 1.5
                elif avg > 10:
                    batch = max(batch / 1.5, 1)
                c = 0

        for (split, items) in buckets.iteritems():
            yield pack_long(split)
            yield outputSerializer.dumps(items)

    keyed = self.mapPartitionsWithIndex(add_shuffle_key)
    keyed._bypass_serializer = True
    with _JavaStackTrace(self.context) as st:
        pairRDD = self.ctx._jvm.PairwiseRDD(
            keyed._jrdd.rdd()).asJavaPairRDD()
        partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
                                                      id(partitionFunc))
    jrdd = pairRDD.partitionBy(partitioner).values()
    rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer))
    # This is required so that id(partitionFunc) remains unique,
    # even if partitionFunc is a lambda:
    rdd._partitionFunc = partitionFunc
    return rdd
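
The key line in the source above is buckets[partitionFunc(k) % numPartitions].append((k, v)): the custom function's return value is reduced modulo numPartitions on the Python side before anything is shipped to the JVM, which is why lambda k: int(k) is valid even when int(k) is not smaller than the number of partitions. The snippet below isolates just that bucketing step as a plain-Python sketch; assign_buckets is an illustrative name, not part of PySpark:

from collections import defaultdict

# Sketch of the bucketing step inside add_shuffle_key:
# each (k, v) pair goes to bucket partitionFunc(k) % numPartitions.
def assign_buckets(pairs, numPartitions, partitionFunc):
    buckets = defaultdict(list)
    for (k, v) in pairs:
        buckets[partitionFunc(k) % numPartitions].append((k, v))
    return dict(buckets)

print assign_buckets([('1', '1'), ('2', '2'), ('3', '3')], 3, lambda k: int(k))
# key '3' falls in bucket 0 because int('3') % 3 == 0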