Spark Learning: The Record Linkage Problem

This article walks through the record linkage code from the book Advanced Analytics with Spark.

The full code can be obtained and opened as follows:

miaofudeMacBook-Pro:code miaofu$ git clone https://github.com/sryza/aas.git
Cloning into 'aas'...
remote: Counting objects: 2490, done.
remote: Compressing objects: 100% (17/17), done.
remote: Total 2490 (delta 4), reused 0 (delta 0), pack-reused 2473
Receiving objects: 100% (2490/2490), 477.02 KiB | 149.00 KiB/s, done.
Resolving deltas: 100% (695/695), done.
Checking connectivity... done.
miaofudeMacBook-Pro:code miaofu$ cd aas/
miaofudeMacBook-Pro:aas miaofu$ ls
LICENSE			ch03-recommender	ch06-lsa		ch09-risk		common
README.md		ch04-rdf		ch07-graph		ch10-genomics		pom.xml
ch02-intro		ch05-kmeans		ch08-geotime		ch11-neuro		simplesparkproject
miaofudeMacBook-Pro:aas miaofu$ vi ch02-intro/
pom.xml  src/     
miaofudeMacBook-Pro:aas miaofu$ vi ch02-intro/src/main/scala/com/cloudera/datascience/intro/RunIntro.scala 
.....
    ct.filter(s => s.score >= 4.0).
      map(s => s.md.matched).countByValue().foreach(println)
    ct.filter(s => s.score >= 2.0).
      map(s => s.md.matched).countByValue().foreach(println)
  }

  def statsWithMissing(rdd: RDD[Array[Double]]): Array[NAStatCounter] = {
    val nastats = rdd.mapPartitions((iter: Iterator[Array[Double]]) => {
      val nas: Array[NAStatCounter] = iter.next().map(d => NAStatCounter(d))
      iter.foreach(arr => {
        nas.zip(arr).foreach { case (n, d) => n.add(d) }
      })
      Iterator(nas)
    })
    nastats.reduce((n1, n2) => {
      n1.zip(n2).map { case (a, b) => a.merge(b) }
    })
  }
}

class NAStatCounter extends Serializable {
  val stats: StatCounter = new StatCounter()
  var missing: Long = 0

  def add(x: Double): NAStatCounter = {
    if (x.isNaN) {
      missing += 1
    } else {
      stats.merge(x)
    }
    this
  }

  def merge(other: NAStatCounter): NAStatCounter = {
    stats.merge(other.stats)
    missing += other.missing
    this
  }

  override def toString: String = {
    "stats: " + stats.toString + " NaN: " + missing
  }
}

object NAStatCounter extends Serializable {
  def apply(x: Double) = new NAStatCounter().add(x)
}



(1) First, analyzing the NAStatCounter class

import org.apache.spark.util.StatCounter
class NAStatCounter extends Serializable {
  val stats: StatCounter = new StatCounter()
  var missing: Long = 0

  def add(x: Double): NAStatCounter = {
    if (x.isNaN) {
      missing += 1
    } else {
      stats.merge(x)
    }
    this
  }

  def merge(other: NAStatCounter): NAStatCounter = {
    stats.merge(other.stats)
    missing += other.missing
    this
  }

  override def toString: String = {
    "stats: " + stats.toString + " NaN: " + missing
  }
}

object NAStatCounter extends Serializable {
  def apply(x: Double) = new NAStatCounter().add(x)
}

Note the Scala class defined here. NAStatCounter does not inherit from StatCounter; it wraps one. StatCounter is the class Spark provides for tracking the summary statistics of a set of numbers, and NAStatCounter builds on it by additionally handling NaN values. It is worth noting that NAStatCounter defines its own merge method, and that StatCounter defines a merge as well, which recomputes the summary statistics incrementally. Both add and merge return the NAStatCounter instance itself (this), which matters because later the author applies mapPartitions over the RDD of records; why mapPartitions is used is discussed in detail below. Also notice that the wrapped stats field has its own merge, which is StatCounter's implementation of the incremental statistics update. To see exactly what it does, we can open the class's source file (https://github.com/apache/spark/blob/v2.0.0/core/src/main/scala/org/apache/spark/util/StatCounter.scala):

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.util

/**
 * A class for tracking the statistics of a set of numbers (count, mean and variance) in a
 * numerically robust way. Includes support for merging two StatCounters. Based on Welford
 * and Chan's [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance algorithms]]
 * for running variance.
 *
 * @constructor Initialize the StatCounter with the given values.
 */
class StatCounter(values: TraversableOnce[Double]) extends Serializable {
  private var n: Long = 0     // Running count of our values
  private var mu: Double = 0  // Running mean of our values
  private var m2: Double = 0  // Running variance numerator (sum of (x - mean)^2)
  private var maxValue: Double = Double.NegativeInfinity // Running max of our values
  private var minValue: Double = Double.PositiveInfinity // Running min of our values

  merge(values)

  /** Initialize the StatCounter with no values. */
  def this() = this(Nil)

  /** Add a value into this StatCounter, updating the internal statistics. */
  def merge(value: Double): StatCounter = {
    val delta = value - mu
    n += 1
    mu += delta / n
    m2 += delta * (value - mu)
    maxValue = math.max(maxValue, value)
    minValue = math.min(minValue, value)
    this
  }

  /** Add multiple values into this StatCounter, updating the internal statistics. */
  def merge(values: TraversableOnce[Double]): StatCounter = {
    values.foreach(v => merge(v))
    this
  }

  /** Merge another StatCounter into this one, adding up the internal statistics. */
  def merge(other: StatCounter): StatCounter = {
    if (other == this) {
      merge(other.copy())  // Avoid overwriting fields in a weird order
    } else {
      if (n == 0) {
        mu = other.mu
        m2 = other.m2
        n = other.n
        maxValue = other.maxValue
        minValue = other.minValue
      } else if (other.n != 0) {
        val delta = other.mu - mu
        if (other.n * 10 < n) {
          mu = mu + (delta * other.n) / (n + other.n)
        } else if (n * 10 < other.n) {
          mu = other.mu - (delta * n) / (n + other.n)
        } else {
          mu = (mu * n + other.mu * other.n) / (n + other.n)
        }
        m2 += other.m2 + (delta * delta * n * other.n) / (n + other.n)
        n += other.n
        maxValue = math.max(maxValue, other.maxValue)
        minValue = math.min(minValue, other.minValue)
      }
      this
    }
  }

  /** Clone this StatCounter */
  def copy(): StatCounter = {
    val other = new StatCounter
    other.n = n
    other.mu = mu
    other.m2 = m2
    other.maxValue = maxValue
    other.minValue = minValue
    other
  }

  def count: Long = n

  def mean: Double = mu

  def sum: Double = n * mu

  def max: Double = maxValue

  def min: Double = minValue

  /** Return the variance of the values. */
  def variance: Double = {
    if (n == 0) {
      Double.NaN
    } else {
      m2 / n
    }
  }

  /**
   * Return the sample variance, which corrects for bias in estimating the variance by dividing
   * by N-1 instead of N.
   */
  def sampleVariance: Double = {
    if (n <= 1) {
      Double.NaN
    } else {
      m2 / (n - 1)
    }
  }

  /** Return the standard deviation of the values. */
  def stdev: Double = math.sqrt(variance)

  /**
   * Return the sample standard deviation of the values, which corrects for bias in estimating the
   * variance by dividing by N-1 instead of N.
   */
  def sampleStdev: Double = math.sqrt(sampleVariance)

  override def toString: String = {
    "(count: %d, mean: %f, stdev: %f, max: %f, min: %f)".format(count, mean, stdev, max, min)
  }
}

object StatCounter {
  /** Build a StatCounter from a list of values. */
  def apply(values: TraversableOnce[Double]): StatCounter = new StatCounter(values)

  /** Build a StatCounter from a list of values passed as variable-length arguments. */
  def apply(values: Double*): StatCounter = new StatCounter(values)
}

Studying this class also illustrates a few idioms of Scala class design: it declares private fields to hold the core running statistics and initializes them; it defines three overloads of merge (one for a single value, one for a collection of values, one for another StatCounter), making good use of method overloading; it provides a copy method that clones the instance; it exposes a set of accessor methods (count, mean, variance, and so on); it overrides toString to report the current state; and finally it defines a companion object with apply factory methods.
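
As a quick sanity check, the two classes can be exercised together in spark-shell or a plain Scala REPL. The following is a hypothetical snippet, not from the book; the commented values follow directly from the definitions above:

// Hypothetical usage of the NAStatCounter defined above.
val n = NAStatCounter(1.0)   // stats now holds 1.0, missing = 0
n.add(Double.NaN)            // NaN is counted separately: missing = 1
n.add(2.0)                   // stats now holds 1.0 and 2.0
println(n)                   // stats: (count: 2, mean: 1.500000, stdev: 0.500000, max: 2.000000, min: 1.000000) NaN: 1

val m = NAStatCounter(4.0)
n.merge(m)                   // combine the two counters: count becomes 3, missing stays 1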


(2) Processing the whole dataset with RDD operators

import org.apache.spark.rdd.RDD

def statsWithMissing(rdd: RDD[Array[Double]]): Array[NAStatCounter] = {
  val nastats = rdd.mapPartitions((iter: Iterator[Array[Double]]) => {
    // Initialize one NAStatCounter per column from the partition's first row...
    val nas: Array[NAStatCounter] = iter.next().map(d => NAStatCounter(d))
    // ...then fold the remaining rows of the partition into those counters.
    iter.foreach(arr => {
      nas.zip(arr).foreach { case (n, d) => n.add(d) }
    })
    Iterator(nas)
  })
  // Merge the per-partition arrays of counters into a single array.
  nastats.reduce((n1, n2) => {
    n1.zip(n2).map { case (a, b) => a.merge(b) }
  })
}


The above defines a function whose input is an RDD[Array[Double]], which corresponds to the record linkage data (each row's attribute fields converted into an Array[Double]), and whose output is a plain Array[NAStatCounter].
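
Before dissecting it, here is a minimal usage sketch: a hypothetical spark-shell session with a made-up three-row dataset, assuming NAStatCounter and statsWithMissing have already been defined as above:

// Toy data: 3 rows with 3 "fields" each, spread over 2 partitions.
val vectors = sc.parallelize(Seq(
  Array(1.0, Double.NaN, 3.0),
  Array(2.0, 5.0, Double.NaN),
  Array(Double.NaN, 4.0, 6.0)
), 2)

// One NAStatCounter per column, each reporting its own NaN count.
val colStats: Array[NAStatCounter] = statsWithMissing(vectors)
colStats.foreach(println)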

The first statement applies mapPartitions to the rdd argument and assigns the result to nastats. mapPartitions takes a function whose input is an Iterator[Array[Double]] and whose output is an Iterator[Array[NAStatCounter]]. Everything between the function we define and the final result is handled by Spark's distributed execution, behind this one mapPartitions interface. At this point it becomes clear why Spark is described as a distributed programming framework.


Since we have brought up mapPartitions, let's explore this function in a bit more detail.

mapPartitions

def mapPartitions[U](f: (Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false)(implicit arg0: ClassTag[U]): RDD[U]

This function is similar to map, except that the mapping function's argument changes from a single element of the RDD to an iterator over an entire partition of the RDD. If the mapping needs to create extra objects repeatedly, mapPartitions is considerably more efficient than map.

For example, suppose all the data in an RDD is to be written to a database over a JDBC connection. Using map might create a connection for every single element, which is very expensive; with mapPartitions, only one connection per partition is needed.
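
A minimal sketch of that pattern is shown below. The JDBC URL and the records(value) table are hypothetical and not part of the book's code; foreachPartition, the action counterpart of mapPartitions, is used here because nothing needs to be returned:

import java.sql.DriverManager
import org.apache.spark.rdd.RDD

def saveToDb(rdd: RDD[String], url: String): Unit = {
  rdd.foreachPartition { iter =>
    // One connection per partition instead of one per element.
    val conn = DriverManager.getConnection(url)
    try {
      val stmt = conn.prepareStatement("INSERT INTO records (value) VALUES (?)")
      iter.foreach { v =>
        stmt.setString(1, v)
        stmt.executeUpdate()
      }
      stmt.close()
    } finally {
      conn.close()
    }
  }
}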

The preservesPartitioning parameter indicates whether the partitioner of the parent RDD should be preserved.
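
For illustration (a hypothetical snippet, not from the book), preservesPartitioning is safe to set to true when the keys of a pair RDD are left untouched, so the parent's partitioner can be reused:

import org.apache.spark.HashPartitioner

val kv = sc.parallelize(Seq((1, "a"), (2, "b"), (3, "c"))).partitionBy(new HashPartitioner(2))
val upper = kv.mapPartitions(
  iter => iter.map { case (k, v) => (k, v.toUpperCase) },  // keys unchanged
  preservesPartitioning = true)
upper.partitioner  // the HashPartitioner is retained, which can save a shuffle later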

Here is an example of mapPartitions in action:

scala> var rdd1 = sc.makeRDD(1 to 5, 2)
// rdd1 has two partitions
scala> var rdd3 = rdd1.mapPartitions{ x => {
     |   var result = List[Int]()
     |   var i = 0
     |   while (x.hasNext) {
     |     i += x.next()
     |   }
     |   result.::(i).iterator
     | }}
rdd3: org.apache.spark.rdd.RDD[Int] = MapPartitionsRDD[84] at mapPartitions at :23

// rdd3 sums the values within each partition of rdd1
scala> rdd3.collect
res65: Array[Int] = Array(3, 12)

scala> rdd3.partitions.size
res66: Int = 2



While we are on the subject, let's compare a few of the different map-like operators.

mapValues(function)
The keys of the original RDD are kept unchanged and are combined with the new values to form the elements of the new RDD. This operator therefore only applies to RDDs whose elements are key-value pairs.

scala> val a = sc.parallelize(List("dog", "tiger", "lion", "cat", "panther", " eagle"), 2)
a: org.apache.spark.rdd.RDD[String] = ParallelCollectionRDD[0] at parallelize at <console>:21

scala> 

scala> val b = a.map(x => (x.length, x))
b: org.apache.spark.rdd.RDD[(Int, String)] = MapPartitionsRDD[1] at map at <console>:23

scala> 

scala> b.mapValues("x" + _ + "x").collect
res0: Array[(Int, String)] = Array((3,xdogx), (5,xtigerx), (4,xlionx), (3,xcatx), (7,xpantherx), (6,x eaglex))


//"x" + _ + "x"等同於everyInput =>"x" + everyInput + "x" 
//結果 
Array( 
(3,xdogx), 
(5,xtigerx), 
(4,xlionx), 
(3,xcatx), 
(7,xpantherx), 
(5,xeaglex) 
)

flatMap(function)
Similar to map, except that each element of the original RDD produces exactly one element after map, whereas after flatMap each element can produce multiple (zero or more) elements.

scala> val a = sc.parallelize(1 to 4, 2)
a: org.apache.spark.rdd.RDD[Int] = ParallelCollectionRDD[0] at parallelize at <console>:27

scala> a
res0: org.apache.spark.rdd.RDD[Int] = ParallelCollectionRDD[0] at parallelize at <console>:27

scala> val b = a.flatMap(x => 1 to x)  // expand each element x into the range 1 to x
b: org.apache.spark.rdd.RDD[Int] = MapPartitionsRDD[1] at flatMap at <console>:29

scala> b
res1: org.apache.spark.rdd.RDD[Int] = MapPartitionsRDD[1] at flatMap at <console>:29

scala> b.collect
res2: Array[Int] = Array(1, 1, 2, 1, 2, 3, 1, 2, 3, 4)                          

scala> 


Next, let's analyze the second statement of the source code:

 nastats.reduce((n1, n2) => {
      n1.zip(n2).map { case (a, b) => a.merge(b) }
    })

The key here is reduce, a Spark API method. Recall that nastats is a variable of type RDD[Array[NAStatCounter]], the output of mapPartitions. Aggregating all of those per-partition results into the Array[NAStatCounter] we ultimately want is the job of reduce; it is really just a simple aggregation.
The reduce function:
Its argument is a function that takes two records of the RDD and returns a single aggregated record. Here that means combining two Array[NAStatCounter] values into one Array[NAStatCounter] of summary statistics.
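As the simplest possible illustration of that contract (a hypothetical spark-shell one-liner, unrelated to the book's data), reducing a numeric RDD with addition repeatedly takes two values and returns one:

scala> sc.parallelize(1 to 5).reduce((a, b) => a + b)
res0: Int = 15
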
First, the two Array[NAStatCounter] values n1 and n2 are zipped: Array[NAStatCounter] zip Array[NAStatCounter] => Array[(NAStatCounter, NAStatCounter)]. Then map is applied, combining each (NAStatCounter, NAStatCounter) pair with merge, which returns a single NAStatCounter; this is exactly why the merge defined in NAStatCounter returns this. Note that the zip and map here are local operations from the Scala collections library and have nothing to do with Spark; only the outer reduce is a Spark computation. The locally computed results are serialized, compressed, sent over the network, decompressed, and deserialized at the driver, which produces the final result. That is also why NAStatCounter must extend Serializable.
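
To make that point concrete, the body passed to reduce can be run as plain local Scala with no SparkContext at all. A small sketch, reusing the NAStatCounter class defined above:

// Plain Scala: merge two arrays of NAStatCounter element-wise.
val n1 = Array(NAStatCounter(1.0), NAStatCounter(Double.NaN))
val n2 = Array(NAStatCounter(2.0), NAStatCounter(3.0))

// zip yields Array[(NAStatCounter, NAStatCounter)]; map merges each pair into one counter.
val merged: Array[NAStatCounter] = n1.zip(n2).map { case (a, b) => a.merge(b) }
merged.foreach(println)  // first column: 2 values, 0 NaN; second column: 1 value, 1 NaN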









