This post walks through the record linkage code from the book Advanced Analytics with Spark.
The full code can be obtained and opened as follows:
miaofudeMacBook-Pro:code miaofu$ git clone https://github.com/sryza/aas.git
Cloning into 'aas'...
remote: Counting objects: 2490, done.
remote: Compressing objects: 100% (17/17), done.
remote: Total 2490 (delta 4), reused 0 (delta 0), pack-reused 2473
Receiving objects: 100% (2490/2490), 477.02 KiB | 149.00 KiB/s, done.
Resolving deltas: 100% (695/695), done.
Checking connectivity... done.
miaofudeMacBook-Pro:code miaofu$ cd aas/
miaofudeMacBook-Pro:aas miaofu$ ls
LICENSE ch03-recommender ch06-lsa ch09-risk common
README.md ch04-rdf ch07-graph ch10-genomics pom.xml
ch02-intro ch05-kmeans ch08-geotime ch11-neuro simplesparkproject
miaofudeMacBook-Pro:aas miaofu$ vi ch02-intro/src/main/scala/com/cloudera/datascience/intro/RunIntro.scala
.....
    ct.filter(s => s.score >= 4.0).
      map(s => s.md.matched).countByValue().foreach(println)
    ct.filter(s => s.score >= 2.0).
      map(s => s.md.matched).countByValue().foreach(println)
  }

  def statsWithMissing(rdd: RDD[Array[Double]]): Array[NAStatCounter] = {
    val nastats = rdd.mapPartitions((iter: Iterator[Array[Double]]) => {
      val nas: Array[NAStatCounter] = iter.next().map(d => NAStatCounter(d))
      iter.foreach(arr => {
        nas.zip(arr).foreach { case (n, d) => n.add(d) }
      })
      Iterator(nas)
    })
    nastats.reduce((n1, n2) => {
      n1.zip(n2).map { case (a, b) => a.merge(b) }
    })
  }
}

class NAStatCounter extends Serializable {
  val stats: StatCounter = new StatCounter()
  var missing: Long = 0

  def add(x: Double): NAStatCounter = {
    if (x.isNaN) {
      missing += 1
    } else {
      stats.merge(x)
    }
    this
  }

  def merge(other: NAStatCounter): NAStatCounter = {
    stats.merge(other.stats)
    missing += other.missing
    this
  }

  override def toString: String = {
    "stats: " + stats.toString + " NaN: " + missing
  }
}

object NAStatCounter extends Serializable {
  def apply(x: Double) = new NAStatCounter().add(x)
}
(1) First, analyze the class
import org.apache.spark.util.StatCounter

class NAStatCounter extends Serializable {
  val stats: StatCounter = new StatCounter()
  var missing: Long = 0

  def add(x: Double): NAStatCounter = {
    if (x.isNaN) {
      missing += 1
    } else {
      stats.merge(x)
    }
    this
  }

  def merge(other: NAStatCounter): NAStatCounter = {
    stats.merge(other.stats)
    missing += other.missing
    this
  }

  override def toString: String = {
    "stats: " + stats.toString + " NaN: " + missing
  }
}

object NAStatCounter extends Serializable {
  def apply(x: Double) = new NAStatCounter().add(x)
}
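The x.isNaN branch in add is what separates NAStatCounter from a plain StatCounter. The following is a minimal plain-Scala sketch (no Spark needed; the object name NaNDemo and its sample values are made up for illustration) of why missing values have to be counted separately instead of being fed into the statistics:

```scala
object NaNDemo {
  // One missing value among three observations.
  val values: Seq[Double] = Seq(1.0, Double.NaN, 3.0)

  // Feeding NaN into ordinary arithmetic poisons the result: this mean is NaN.
  val naiveMean: Double = values.sum / values.size

  // Counting missing values separately, as NAStatCounter.add does with x.isNaN.
  val missing: Int = values.count(_.isNaN)
  val present: Seq[Double] = values.filterNot(_.isNaN)
  val mean: Double = present.sum / present.size

  // NaN is not equal to anything, itself included, so == cannot detect it.
  private val nan: Double = Double.NaN
  val equalsItself: Boolean = nan == nan
}
```

With these values the naive mean is NaN, while the NaN-aware version reports one missing value and a mean of 2.0 over the remaining two.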
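Note that this defines a Scala class, NAStatCounter. It does not inherit from StatCounter (the class Spark provides for tracking summary statistics); it extends only Serializable and wraps a StatCounter in its stats field. On top of that field it adds handling for NaN (missing) values. NAStatCounter defines its own merge method, which combines two NAStatCounter instances. Both add and merge return the instance itself (this); this matters later, when the author processes the whole RDD of records with mapPartitions and reduce. Why mapPartitions is used is discussed further below. For now, note that stats has a merge as well: it is the method StatCounter defines to update the summary statistics incrementally. To see how it works, we can open the source file of that class (https://github.com/apache/spark/blob/v2.0.0/core/src/main/scala/org/apache/spark/util/StatCounter.scala):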
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.util

/**
 * A class for tracking the statistics of a set of numbers (count, mean and variance) in a
 * numerically robust way. Includes support for merging two StatCounters. Based on Welford
 * and Chan's [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance algorithms]]
 * for running variance.
 *
 * @constructor Initialize the StatCounter with the given values.
 */
class StatCounter(values: TraversableOnce[Double]) extends Serializable {
  private var n: Long = 0 // Running count of our values
  private var mu: Double = 0 // Running mean of our values
  private var m2: Double = 0 // Running variance numerator (sum of (x - mean)^2)
  private var maxValue: Double = Double.NegativeInfinity // Running max of our values
  private var minValue: Double = Double.PositiveInfinity // Running min of our values

  merge(values)

  /** Initialize the StatCounter with no values. */
  def this() = this(Nil)

  /** Add a value into this StatCounter, updating the internal statistics. */
  def merge(value: Double): StatCounter = {
    val delta = value - mu
    n += 1
    mu += delta / n
    m2 += delta * (value - mu)
    maxValue = math.max(maxValue, value)
    minValue = math.min(minValue, value)
    this
  }

  /** Add multiple values into this StatCounter, updating the internal statistics. */
  def merge(values: TraversableOnce[Double]): StatCounter = {
    values.foreach(v => merge(v))
    this
  }

  /** Merge another StatCounter into this one, adding up the internal statistics. */
  def merge(other: StatCounter): StatCounter = {
    if (other == this) {
      merge(other.copy()) // Avoid overwriting fields in a weird order
    } else {
      if (n == 0) {
        mu = other.mu
        m2 = other.m2
        n = other.n
        maxValue = other.maxValue
        minValue = other.minValue
      } else if (other.n != 0) {
        val delta = other.mu - mu
        if (other.n * 10 < n) {
          mu = mu + (delta * other.n) / (n + other.n)
        } else if (n * 10 < other.n) {
          mu = other.mu - (delta * n) / (n + other.n)
        } else {
          mu = (mu * n + other.mu * other.n) / (n + other.n)
        }
        m2 += other.m2 + (delta * delta * n * other.n) / (n + other.n)
        n += other.n
        maxValue = math.max(maxValue, other.maxValue)
        minValue = math.min(minValue, other.minValue)
      }
      this
    }
  }

  /** Clone this StatCounter */
  def copy(): StatCounter = {
    val other = new StatCounter
    other.n = n
    other.mu = mu
    other.m2 = m2
    other.maxValue = maxValue
    other.minValue = minValue
    other
  }

  def count: Long = n

  def mean: Double = mu

  def sum: Double = n * mu

  def max: Double = maxValue

  def min: Double = minValue

  /** Return the variance of the values. */
  def variance: Double = {
    if (n == 0) {
      Double.NaN
    } else {
      m2 / n
    }
  }

  /**
   * Return the sample variance, which corrects for bias in estimating the variance by dividing
   * by N-1 instead of N.
   */
  def sampleVariance: Double = {
    if (n <= 1) {
      Double.NaN
    } else {
      m2 / (n - 1)
    }
  }

  /** Return the standard deviation of the values. */
  def stdev: Double = math.sqrt(variance)

  /**
   * Return the sample standard deviation of the values, which corrects for bias in estimating the
   * variance by dividing by N-1 instead of N.
   */
  def sampleStdev: Double = math.sqrt(sampleVariance)

  override def toString: String = {
    "(count: %d, mean: %f, stdev: %f, max: %f, min: %f)".format(count, mean, stdev, max, min)
  }
}

object StatCounter {
  /** Build a StatCounter from a list of values. */
  def apply(values: TraversableOnce[Double]): StatCounter = new StatCounter(values)

  /** Build a StatCounter from a list of values passed as variable-length arguments. */
  def apply(values: Double*): StatCounter = new StatCounter(values)
}
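Studying this class also teaches a few techniques for defining Scala classes. It first declares private variables that hold the core state and initializes them. It then defines three overloaded merge methods (for a single value, for a collection of values, and for another StatCounter). It defines a copy method that clones the instance. It exposes a series of parameterless accessor methods (count, mean, sum, max, min, variance, and so on) and overrides toString to report the current state. Finally, it uses a companion object to provide apply factory methods.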
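To make the incremental update and the merge formula concrete, here is a small plain-Scala sketch of the same two ideas: the single-value update and the merge of two partial summaries. Running is a hypothetical, immutable stand-in written for this post, not Spark's class, and it always uses the weighted-average form of the mean, whereas StatCounter chooses between three algebraically equivalent forms depending on the relative sizes of the two counts.

```scala
// Hypothetical stand-in for the state StatCounter keeps: count, mean, and m2
// (the sum of squared deviations from the mean).
final case class Running(n: Long, mean: Double, m2: Double) {

  // Single-value update: fold one value into the running summary.
  def add(x: Double): Running = {
    val n1 = n + 1
    val delta = x - mean
    val mean1 = mean + delta / n1
    Running(n1, mean1, m2 + delta * (x - mean1))
  }

  // Merge two partial summaries without revisiting the underlying data.
  def merge(o: Running): Running =
    if (n == 0) o
    else if (o.n == 0) this
    else {
      val total = n + o.n
      val delta = o.mean - mean
      Running(
        total,
        (mean * n + o.mean * o.n) / total,
        m2 + o.m2 + delta * delta * n * o.n / total)
    }

  // Population variance, matching StatCounter.variance (divide by n).
  def variance: Double = if (n == 0) Double.NaN else m2 / n
}

object Running {
  val empty: Running = Running(0L, 0.0, 0.0)

  // Summarize a sequence one value at a time.
  def of(xs: Seq[Double]): Running = xs.foldLeft(empty)((acc, x) => acc.add(x))
}
```

Summarizing Seq(1.0, 2.0, 3.0) and Seq(4.0, 5.0, 6.0) separately and merging them should give the same count, mean, and variance (up to floating-point rounding) as summarizing all six values in one pass. That property is what lets each partition be summarized independently.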
(2) Use RDD operators to process all records uniformly
import org.apache.spark.rdd.RDD

def statsWithMissing(rdd: RDD[Array[Double]]): Array[NAStatCounter] = {
  val nastats = rdd.mapPartitions((iter: Iterator[Array[Double]]) => {
    val nas: Array[NAStatCounter] = iter.next().map(d => NAStatCounter(d))
    iter.foreach(arr => {
      nas.zip(arr).foreach { case (n, d) => n.add(d) }
    })
    Iterator(nas)
  })
  nastats.reduce((n1, n2) => {
    n1.zip(n2).map { case (a, b) => a.merge(b) }
  })
}
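The first statement is a mapPartitions operation: its input is the rdd variable and its output is nastats. mapPartitions takes a function as its argument; that function's input is an Iterator[Array[Double]] and its output is an Iterator[Array[NAStatCounter]]. We only write this per-partition function. Applying it to every partition and assembling the results is handled by Spark's distributed runtime through the mapPartitions interface (see the figure below). At this point it becomes clear why Spark is described as a distributed programming framework.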
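The per-partition function can be studied on its own, because it is ordinary Scala that maps one iterator to another. The sketch below makes two assumptions that are not in the book's code: MiniCounter is a simplified, hypothetical stand-in for NAStatCounter (it tracks only count, sum, and missing, so that no Spark dependency is needed), and the function checks hasNext first. The original calls iter.next() unconditionally, which throws NoSuchElementException if a partition happens to be empty.

```scala
object PartitionStats {
  // Hypothetical, simplified stand-in for NAStatCounter.
  final class MiniCounter extends Serializable {
    var count: Long = 0L
    var sum: Double = 0.0
    var missing: Long = 0L

    // Count NaN as missing; otherwise update count and sum. Returns this for chaining.
    def add(x: Double): MiniCounter = {
      if (x.isNaN) missing += 1 else { count += 1; sum += x }
      this
    }
  }

  // Same shape as the function passed to mapPartitions:
  // one iterator of rows in, one iterator holding a single array of counters out.
  def summarize(iter: Iterator[Array[Double]]): Iterator[Array[MiniCounter]] =
    if (!iter.hasNext) {
      Iterator.empty // an empty partition contributes nothing
    } else {
      // The first row creates one counter per column...
      val counters = iter.next().map(d => new MiniCounter().add(d))
      // ...and every later row updates those counters column by column.
      iter.foreach(row => counters.zip(row).foreach { case (c, d) => c.add(d) })
      Iterator(counters)
    }
}
```

Calling PartitionStats.summarize(Iterator(Array(1.0, Double.NaN), Array(3.0, 4.0))) yields one array with two counters: the first column has two values, the second has one value and one missing entry.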
Since mapPartitions has come up, let's look at it in more detail.
mapPartitions
def mapPartitions[U](f: (Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false)(implicit arg0: ClassTag[U]): RDD[U]
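This function is similar to map, except that the mapping function receives an iterator over each partition of the RDD instead of each individual element. If the mapping needs to create expensive helper objects, mapPartitions is much more efficient than map.
For example, to write all the data in an RDD to a database over JDBC: with map you might create a connection for every element, which is costly; with mapPartitions you only need one connection per partition.
The preservesPartitioning parameter indicates whether to keep the partitioner of the parent RDD.
An example: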
scala> var rdd1 = sc.makeRDD(1 to 5, 2)
// rdd1 has two partitions
scala> var rdd3 = rdd1.mapPartitions { x => {
     |   var result = List[Int]()
     |   var i = 0
     |   while (x.hasNext) {
     |     i += x.next()
     |   }
     |   result.::(i).iterator
     | }}
rdd3: org.apache.spark.rdd.RDD[Int] = MapPartitionsRDD[84] at mapPartitions at :23
// rdd3 holds the sum of the values in each partition of rdd1
scala> rdd3.collect
res65: Array[Int] = Array(3, 12)
scala> rdd3.partitions.size
res66: Int = 2
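The cost argument behind the JDBC remark can be sketched without a cluster or a database. In the plain-Scala simulation below, everything is hypothetical: FakeConnection stands in for an expensive resource and simply counts how often it is created, and the nested Seq mimics the two partitions of rdd1.

```scala
import java.util.concurrent.atomic.AtomicInteger

object PartitionDemo {
  // Hypothetical stand-in for an expensive resource such as a JDBC connection.
  // Each construction bumps the shared counter.
  final class FakeConnection(opened: AtomicInteger) {
    opened.incrementAndGet()
    def write(x: Int): Int = x
  }

  // Two "partitions", mirroring sc.makeRDD(1 to 5, 2).
  val partitions: Seq[Seq[Int]] = Seq(Seq(1, 2), Seq(3, 4, 5))

  // map-style processing: the setup cost is paid once per element.
  def connectionsPerElement(): Int = {
    val opened = new AtomicInteger(0)
    partitions.foreach(part => part.foreach(x => new FakeConnection(opened).write(x)))
    opened.get()
  }

  // mapPartitions-style processing: the setup cost is paid once per partition.
  def connectionsPerPartition(): Int = {
    val opened = new AtomicInteger(0)
    partitions.foreach { part =>
      val conn = new FakeConnection(opened)
      part.iterator.foreach(x => conn.write(x))
    }
    opened.get()
  }

  // The per-partition sums computed by the while loop in the session above.
  val partitionSums: Seq[Int] = partitions.map(part => part.iterator.sum)
}
```

Here the per-element variant opens five connections and the per-partition variant opens two. On a real RDD the gap grows with the number of records per partition.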
mapValues(function)
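The keys of the original RDD stay unchanged and are paired with the new values to form the elements of the new RDD. This function therefore applies only to RDDs whose elements are key-value pairs.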
scala> val a = sc.parallelize(List("dog", "tiger", "lion", "cat", "panther", " eagle"), 2)
a: org.apache.spark.rdd.RDD[String] = ParallelCollectionRDD[0] at parallelize at <console>:21
scala>
scala> val b = a.map(x => (x.length, x))
b: org.apache.spark.rdd.RDD[(Int, String)] = MapPartitionsRDD[1] at map at <console>:23
scala>
scala> b.mapValues("x" + _ + "x").collect
res0: Array[(Int, String)] = Array((3,xdogx), (5,xtigerx), (4,xlionx), (3,xcatx), (7,xpantherx), (6,x eaglex))
// "x" + _ + "x" is equivalent to everyInput => "x" + everyInput + "x"
// Result
Array(
  (3,xdogx),
  (5,xtigerx),
  (4,xlionx),
  (3,xcatx),
  (7,xpantherx),
  (6,x eaglex)
)
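For intuition, the effect of mapValues on the data can be reproduced on an ordinary Scala list of pairs: it behaves like a map that rewrites only the second component. This is a local sketch on plain collections, not Spark code, and the object name is made up. Note that " eagle" starts with a space, which is why its key is 6.

```scala
object MapValuesDemo {
  val words: List[String] = List("dog", "tiger", "lion", "cat", "panther", " eagle")

  // Key each word by its length, as a.map(x => (x.length, x)) does.
  val pairs: List[(Int, String)] = words.map(w => (w.length, w))

  // The local equivalent of mapValues("x" + _ + "x"): keys are left untouched.
  val wrapped: List[(Int, String)] = pairs.map { case (k, v) => (k, "x" + v + "x") }
}
```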
flatMap(function)
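Similar to map. The difference is that with map each element of the original RDD produces exactly one element, whereas with flatMap each element can produce multiple elements (or none).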
scala> val a = sc.parallelize(1 to 4, 2)
a: org.apache.spark.rdd.RDD[Int] = ParallelCollectionRDD[0] at parallelize at <console>:27
scala> a
res0: org.apache.spark.rdd.RDD[Int] = ParallelCollectionRDD[0] at parallelize at <console>:27
scala> val b = a.flatMap(x => 1 to x) // expand each element x into 1 to x
b: org.apache.spark.rdd.RDD[Int] = MapPartitionsRDD[1] at flatMap at <console>:29
scala> b
res1: org.apache.spark.rdd.RDD[Int] = MapPartitionsRDD[1] at flatMap at <console>:29
scala> b.collect
res2: Array[Int] = Array(1, 1, 2, 1, 2, 3, 1, 2, 3, 4)
scala>
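The contrast with map is easiest to see side by side. The sketch below uses a plain Scala List in place of the RDD (an assumption made so it runs without Spark): map keeps one nested result per input element, while flatMap concatenates those results into a single flat sequence.

```scala
object FlatMapDemo {
  val xs: List[Int] = List(1, 2, 3, 4)

  // map: exactly one output element (here, a list) per input element.
  val nested: List[List[Int]] = xs.map(x => (1 to x).toList)

  // flatMap: the per-element results are flattened into one sequence.
  val flat: List[Int] = xs.flatMap(x => 1 to x)
}
```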
Next, let's analyze the second statement of the source code:
nastats.reduce((n1, n2) => {
  n1.zip(n2).map { case (a, b) => a.merge(b) }
})
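The key here is reduce, a Spark action. Recall that nastats, the output of mapPartitions, has type RDD[Array[NAStatCounter]]. Combining all of those per-partition results into the single Array[NAStatCounter] we want is the job of reduce; it is simply an aggregation.
The reduce function:
Its argument is a function that takes two records of the RDD and returns one aggregated record. Here that means combining two Array[NAStatCounter] values into a single Array[NAStatCounter] of summary statistics.
First, the two Array[NAStatCounter] values n1 and n2 are zipped: Array[NAStatCounter] × Array[NAStatCounter] => (zip) Array[(NAStatCounter, NAStatCounter)]. Then map combines each (NAStatCounter, NAStatCounter) pair with merge, returning one NAStatCounter. This is why the merge method defined in NAStatCounter returns this. Note that zip and map here operate on plain Scala arrays; they are standard Scala collection methods and have nothing to do with Spark. The enclosing reduce is what Spark executes. The locally computed results are serialized (and possibly compressed), transferred over the network, and deserialized on the driver, which produces the final result. This is why NAStatCounter has to extend Serializable.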
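The shape of the combining function can be tried locally. In the sketch below the per-partition results are made-up (count, sum) pairs per column instead of NAStatCounter objects, and an ordinary Scala Seq plays the role of the RDD; only the zip-then-map pattern is taken from the original code. Spark's reduce expects the combining function to be associative and commutative, which column-wise addition is.

```scala
object ReduceDemo {
  // Hypothetical per-column summary: (count, sum).
  type Summary = (Long, Double)

  // One Array[Summary] per partition, two columns each (made-up numbers).
  val perPartition: Seq[Array[Summary]] = Seq(
    Array((2L, 3.0), (1L, 10.0)),
    Array((3L, 12.0), (3L, 6.0)),
    Array((1L, 1.0), (0L, 0.0)))

  // Combine the summaries of one column from two partitions.
  def mergeColumn(a: Summary, b: Summary): Summary = (a._1 + b._1, a._2 + b._2)

  // Same pattern as nastats.reduce: pair up the columns, then merge each pair.
  val total: Array[Summary] = perPartition.reduce { (n1, n2) =>
    n1.zip(n2).map { case (a, b) => mergeColumn(a, b) }
  }
}
```

One caveat worth keeping in mind: zip truncates to the shorter array, so this pattern assumes every partition produced the same number of columns.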