Spark GraphX之Dijkstra(單源最短路徑)、Prime(最小生成樹)、FloydWarshall(多源最短路徑)

在這裏插入圖片描述
ShortestPaths的源碼如下:

package org.apache.spark.graphx.lib

import scala.reflect.ClassTag

import org.apache.spark.graphx._

/**
 * Computes shortest paths to the given set of landmark vertices, returning a graph where each
 * vertex attribute is a map containing the shortest-path distance to each reachable landmark.
 */
object ShortestPaths {
  /** Stores a map from the vertex id of a landmark to the distance to that landmark. */
  type SPMap = Map[VertexId, Int]

  private def makeMap(x: (VertexId, Int)*) = Map(x: _*)

  private def incrementMap(spmap: SPMap): SPMap = spmap.map { case (v, d) => v -> (d + 1) }

  private def addMaps(spmap1: SPMap, spmap2: SPMap): SPMap =
    (spmap1.keySet ++ spmap2.keySet).map {
      k => k -> math.min(spmap1.getOrElse(k, Int.MaxValue), spmap2.getOrElse(k, Int.MaxValue))
    }.toMap

  /**
   * Computes shortest paths to the given set of landmark vertices.
   *
   * @tparam ED the edge attribute type (not used in the computation)
   *
   * @param graph the graph for which to compute the shortest paths
   * @param landmarks the list of landmark vertex ids. Shortest paths will be computed to each
   * landmark.
   *
   * @return a graph where each vertex attribute is a map containing the shortest-path distance to
   * each reachable landmark vertex.
   */
  def run[VD, ED: ClassTag](graph: Graph[VD, ED], landmarks: Seq[VertexId]): Graph[SPMap, ED] = {
    val spGraph = graph.mapVertices { (vid, attr) =>
      if (landmarks.contains(vid)) makeMap(vid -> 0) else makeMap()
    }

    val initialMessage = makeMap()

    def vertexProgram(id: VertexId, attr: SPMap, msg: SPMap): SPMap = {
      addMaps(attr, msg)
    }

    def sendMessage(edge: EdgeTriplet[SPMap, _]): Iterator[(VertexId, SPMap)] = {
      val newAttr = incrementMap(edge.dstAttr)
      if (edge.srcAttr != addMaps(newAttr, edge.srcAttr)) Iterator((edge.srcId, newAttr))
      else Iterator.empty
    }

    Pregel(spGraph, initialMessage)(vertexProgram, sendMessage, addMaps)
  }
}

關於單源最短路徑,我們可以調用 ShortestPaths .run(graph, landmarks) 得到graph中的頂點到landmarks的“距離”,但是這個“距離”只是“跳數”。換句話說,只在graph中每條邊的權重都爲1的情況下,才能保證結果的正確性。而現實情況中,往往都不滿足這個條件。那麼問題來了,我們該如何做呢?學過圖論的朋友都知道,Dijkstra算法可以解決這個問題。遺憾的是,GraphX目前(Spark2.0.2)並未提供這樣的API,所以基於GraphX實現Dijkstra算法變得很有必要。

Dijkstra(單源最短路徑)

  //單源最短路徑
  def dijkstra[VD: ClassTag](g : Graph[VD, Double], origin: VertexId) = {
    //初始化,其中屬性爲(boolean, double,Long)類型,boolean用於標記是否訪問過,double爲頂點距離原點的距離,Long是上一個頂點的id
    var g2 = g.mapVertices((vid, _) => (false, if(vid == origin) 0 else Double.MaxValue, -1L))

    for(i <- 1L to g.vertices.count()) {
      //從沒有訪問過的頂點中找出距離原點最近的點
      val currentVertexId = g2.vertices.filter(! _._2._1).reduce((a,b) => if (a._2._2 < b._2._2) a else b)._1
      //更新currentVertexId鄰接頂點的‘double’值
      val newDistances = g2.aggregateMessages[(Double, Long)](
        triplet => if(triplet.srcId == currentVertexId && !triplet.dstAttr._1) {    //只給未確定的頂點發送消息
          triplet.sendToDst((triplet.srcAttr._2 + triplet.attr, triplet.srcId))
        },
        (x, y) => if(x._1 < y._1) x else y ,
        TripletFields.All
      )
      //newDistances.foreach(x => println("currentVertexId\t"+currentVertexId+"\t->\t"+x))
      //更新圖形
      g2 = g2.outerJoinVertices(newDistances) {
        case (vid, vd, Some(newSum)) => (vd._1 || vid == currentVertexId, math.min(vd._2, newSum._1), if(vd._2 <= newSum._1) vd._3 else newSum._2 )
        case (vid, vd, None) => (vd._1|| vid == currentVertexId, vd._2, vd._3)
      }
      //g2.vertices.foreach(x => println("currentVertexId\t"+currentVertexId+"\t->\t"+x))
    }

    //g2
    g.outerJoinVertices(g2.vertices)( (vid, srcAttr, dist) => (srcAttr, dist.getOrElse(false, Double.MaxValue, -1)._2, dist.getOrElse(false, Double.MaxValue, -1)._3) )
  }

Prime(最小生成樹)

知道Dijkstra算法的人也一定知道Prime算法。

  //最小生成樹
  def prime[VD: ClassTag](g : Graph[VD, Double], origin: VertexId) = {
    //初始化,其中屬性爲(boolean, double,Long)類型,boolean用於標記是否訪問過,double爲加入當前頂點的代價,Long是上一個頂點的id
    var g2 = g.mapVertices((vid, _) => (false, if(vid == origin) 0 else Double.MaxValue, -1L))

    for(i <- 1L to g.vertices.count()) {
      //從沒有訪問過的頂點中找出 代價最小 的點
      val currentVertexId = g2.vertices.filter(! _._2._1).reduce((a,b) => if (a._2._2 < b._2._2) a else b)._1
      //更新currentVertexId鄰接頂點的‘double’值
      val newDistances = g2.aggregateMessages[(Double, Long)](
        triplet => if(triplet.srcId == currentVertexId && !triplet.dstAttr._1) {    //只給未確定的頂點發送消息
          triplet.sendToDst((triplet.attr, triplet.srcId))
        },
        (x, y) => if(x._1 < y._1) x else y ,
        TripletFields.All
      )
      //newDistances.foreach(x => println("currentVertexId\t"+currentVertexId+"\t->\t"+x))
      //更新圖形
      g2 = g2.outerJoinVertices(newDistances) {
        case (vid, vd, Some(newSum)) => (vd._1 || vid == currentVertexId, math.min(vd._2, newSum._1), if(vd._2 <= newSum._1) vd._3 else newSum._2 )
        case (vid, vd, None) => (vd._1|| vid == currentVertexId, vd._2, vd._3)
      }
      //g2.vertices.foreach(x => println("currentVertexId\t"+currentVertexId+"\t->\t"+x))
    }

    //g2
    g.outerJoinVertices(g2.vertices)( (vid, srcAttr, dist) => (srcAttr, dist.getOrElse(false, Double.MaxValue, -1)._2, dist.getOrElse(false, Double.MaxValue, -1)._3) )
  }

FloydWarshall(多源最短路徑)

  //多源最短路徑
  def floydWarshall[VD: ClassTag](g: Graph[VD, Double]) = {
    def mergeMaps(a: Map[VertexId, Double], b: Map[VertexId, Double]) = {
      (a.keySet ++ b.keySet).map{ k => (k, math.min(a.getOrElse(k, Double.MaxValue), b.getOrElse(k, Double.MaxValue))) }.toMap
    }

    val N = g.vertices.count()    //圖頂點的個數
    var n = -1
    //初始化圖
    var g2 = g.mapVertices( (vid, _) => Map(vid -> 0.0) )

    //當n = N*N時,退出循環。注:不難發現最終結果是一個實對稱矩陣
    while(n < N * N) {
      val newVertices = g2.aggregateMessages[Map[VertexId, Double]](
        triplet =>{
          val dstPlus = triplet.dstAttr.map{ case (vid, distance) => (vid, triplet.attr+distance) }
          if(dstPlus != triplet.srcAttr) { triplet.sendToSrc(dstPlus) }
        },
        (a, b) => mergeMaps(a, b) ,
        TripletFields.Dst
      )

      g2 = g2.outerJoinVertices(newVertices)( (_, oldAttr, opt) => mergeMaps(oldAttr, opt.get) )

      n = g2.vertices.map{ case (vid, srcAttr) => srcAttr.size }.reduce(_ + _)
      //println("number\t" + n)
    }

    g2
  }

紙上得來終覺淺,絕知此事要躬行。下面開始實戰、實戰、實戰,重要的事情說三遍!!!

    val myVertices = sc.makeRDD(Array((1L, "A"), (2L, "B"), (3L, "C"), (4L, "D"), (5L, "E"), (6L, "F"), (7L, "G")))
    val initialEdges = sc.makeRDD(Array(Edge(1L, 2L, 7.0), Edge(1L, 4L, 5.0),
                                   Edge(2L, 3L, 8.0), Edge(2L, 4L, 9.0), Edge(2L, 5L, 7.0),
                                   Edge(3L, 5L, 5.0),
                                   Edge(4L, 5L, 15.0), Edge(4L, 6L, 6.0),
                                   Edge(5L, 6L, 8.0), Edge(5L, 7L, 9.0),
                                   Edge(6L, 7L, 11.0)))
    val myEdges = initialEdges.filter(e => e.srcId != e.dstId).flatMap(e => Array(e, Edge(e.dstId, e.srcId, e.attr))).distinct()  //去掉自循環邊,有向圖變爲無向圖,去除重複邊
    val myGraph = Graph(myVertices, myEdges).cache()

    println(ShortestPaths.run(myGraph, Seq(3)).vertices.collect().mkString(","))
    println(dijkstra(myGraph, 3L).vertices.map(x => (x._1, x._2)).collect().mkString(" | "))
    println(prime(myGraph, 3L).vertices.map(x => (x._1, x._2)).collect().mkString(" | "))
    floydWarshall(myGraph).vertices.foreach(println)

輸出依次如下:

ShortestPaths:
(1,Map(3 -> 2)) | (2,Map(3 -> 1)) | (3,Map(3 -> 0)) | (4,Map(3 -> 2)) | (5,Map(3 -> 1)) | (6,Map(3 -> 2)) | (7,Map(3 -> 2))
Dijkstra:
(1,(A,15.0,2)) | (2,(B,8.0,3)) | (3,(C,0.0,-1)) | (4,(D,17.0,2)) | (5,(E,5.0,3)) | (6,(F,13.0,5)) | (7,(G,14.0,5))
Prime:
(1,(A,7.0,2)) | (2,(B,7.0,5)) | (3,(C,0.0,-1)) | (4,(D,5.0,1)) | (5,(E,5.0,3)) | (6,(F,6.0,4)) | (7,(G,9.0,5))
FloydWarshall:
(4,Map(5 -> 14.0, 1 -> 5.0, 6 -> 6.0, 2 -> 9.0, 7 -> 17.0, 3 -> 17.0, 4 -> 0.0))
(2,Map(5 -> 7.0, 1 -> 7.0, 6 -> 15.0, 2 -> 0.0, 7 -> 16.0, 3 -> 8.0, 4 -> 9.0))
(7,Map(5 -> 9.0, 1 -> 22.0, 6 -> 11.0, 2 -> 16.0, 7 -> 0.0, 3 -> 14.0, 4 -> 17.0))
(5,Map(5 -> 0.0, 1 -> 14.0, 6 -> 8.0, 2 -> 7.0, 7 -> 9.0, 3 -> 5.0, 4 -> 14.0))
(3,Map(5 -> 5.0, 1 -> 15.0, 6 -> 13.0, 2 -> 8.0, 7 -> 14.0, 3 -> 0.0, 4 -> 17.0))
(1,Map(5 -> 14.0, 1 -> 0.0, 6 -> 11.0, 2 -> 7.0, 7 -> 22.0, 3 -> 15.0, 4 -> 5.0))
(6,Map(5 -> 8.0, 1 -> 11.0, 6 -> 0.0, 2 -> 15.0, 7 -> 11.0, 3 -> 13.0, 4 -> 6.0))

友情鏈接

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章