Spark GraphX Study Notes: Dijkstra's Shortest Path Algorithm

1. Dijkstra's shortest path algorithm in Scala

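import org.apache.spark.graphx._
def dijkstra[VD](g:Graph[VD,Double], origin:VertexId): Graph[(VD,Double), Double] = {

	/**
	 * 1. Initialization
	 * Map every vertex to (visited = false, distance = Double.MaxValue),
	 * where the second field is the initial tentative distance.
	 * The origin vertex starts with distance 0.
	 */
	var g2 = g.mapVertices(
		(vid,vd) => (false, if (vid == origin) 0.0 else Double.MaxValue))

	// Run n - 1 rounds; each round settles one more vertex
	for (i <- 1L to g.vertices.count-1) {

		/**
		 * 2. Among the unvisited vertices, find the one with the smallest
		 * tentative distance and make it the current vertex
		 */
		val currentVertexId =
			g2.vertices.filter(!_._2._1)
				.fold((0L,(false,Double.MaxValue)))((a,b) =>
					if (a._2._2 < b._2._2) a else b)
				._1

		// 3. Send messages to the neighbors of the current vertex, then aggregate
		// them, keeping the smaller value as the candidate shortest distance
		val newDistances: VertexRDD[Double] = g2.aggregateMessages[Double](

			// sendMsg: along each outgoing edge of the current vertex, send the sum of
			// the current vertex's shortest distance and the edge weight
			ctx => if (ctx.srcId == currentVertexId)
				ctx.sendToDst(ctx.srcAttr._2 + ctx.attr),
			// mergeMsg: when a neighbor receives several messages, keep the smallest
			(a,b) => math.min(a,b))

		// 4. Build the updated graph: mark the current vertex as visited and keep
		// the smaller of the old and the newly proposed distance
		g2 = g2.outerJoinVertices(newDistances)((vid, vd, newSum) =>
			(vd._1 || vid == currentVertexId,
			math.min(vd._2, newSum.getOrElse(Double.MaxValue))))
	}

	// Attach the final distance to each original vertex attribute
	g.outerJoinVertices(g2.vertices)((vid, vd, dist) =>
		(vd, dist.getOrElse((false,Double.MaxValue))._2))
}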
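To follow the four steps without a Spark cluster, here is a minimal plain-Scala sketch that runs the same loop on local collections. It is an illustration under assumptions, not the book's code: the `LocalDijkstra` object, its `Edge` case class (a stand-in for GraphX's `Edge[Double]`) and `sampleEdges` are hypothetical names, and Maps replace the RDD operations (`minBy` plays the role of the `fold`, `groupBy` plus `min` plays the role of `aggregateMessages`).

```scala
// A local, Spark-free sketch of the same four steps (hypothetical names throughout)
object LocalDijkstra {
  // Stand-in for GraphX's Edge[Double]: source id, destination id, weight
  final case class Edge(src: Long, dst: Long, attr: Double)

  // The same edges as myEdges in section 2
  val sampleEdges: Seq[Edge] = Seq(
    Edge(1L, 2L, 7.0), Edge(1L, 4L, 5.0), Edge(2L, 3L, 8.0), Edge(2L, 4L, 9.0),
    Edge(2L, 5L, 7.0), Edge(3L, 5L, 5.0), Edge(4L, 5L, 15.0), Edge(4L, 6L, 6.0),
    Edge(5L, 6L, 8.0), Edge(5L, 7L, 9.0), Edge(6L, 7L, 11.0))

  def dijkstra(vertices: Seq[Long], edges: Seq[Edge], origin: Long): Map[Long, Double] = {
    // Step 1: every vertex is (visited = false, distance = MaxValue); origin starts at 0.0
    var state: Map[Long, (Boolean, Double)] =
      vertices.map(v => (v, (false, if (v == origin) 0.0 else Double.MaxValue))).toMap

    // n - 1 rounds, like `1L to g.vertices.count-1`
    for (_ <- 1 until vertices.size) {
      // Step 2: the unvisited vertex with the smallest tentative distance
      val current = state
        .filter { case (_, (visited, _)) => !visited }
        .minBy { case (_, (_, dist)) => dist }
        ._1
      val currentDist = state(current)._2

      // Step 3: "send" currentDist + weight along each outgoing edge of current
      // and keep the smallest value per destination (the mergeMsg step)
      val newDistances: Map[Long, Double] = edges
        .filter(_.src == current)
        .groupBy(_.dst)
        .map { case (dst, es) => (dst, es.map(e => currentDist + e.attr).min) }

      // Step 4: mark current as visited and keep the smaller distance
      state = state.map { case (v, (visited, dist)) =>
        (v, (visited || v == current, math.min(dist, newDistances.getOrElse(v, Double.MaxValue))))
      }
    }
    // Drop the visited flag, keep only the distance
    state.map { case (v, (_, dist)) => (v, dist) }
  }

  def main(args: Array[String]): Unit =
    println(dijkstra(1L to 7L, sampleEdges, 1L).toSeq.sortBy(_._1).mkString(", "))
}
```

With origin 1L it yields the same distances as the Spark run in section 2 (A = 0.0, B = 7.0, ..., G = 22.0). Note that vertex 7L is never marked visited because only n - 1 rounds run, but its distance is already final by then.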


2. Running the shortest path distance algorithm

val myVertices = sc.makeRDD(Array((1L, "A"), (2L, "B"), (3L, "C"),(4L, "D"), (5L, "E"), (6L, "F"), (7L, "G")))

val myEdges = sc.makeRDD(Array(Edge(1L, 2L, 7.0), Edge(1L, 4L, 5.0),Edge(2L, 3L, 8.0), Edge(2L, 4L, 9.0), Edge(2L, 5L, 7.0),Edge(3L, 5L, 5.0), Edge(4L, 5L, 15.0), Edge(4L, 6L, 6.0),Edge(5L, 6L, 8.0), Edge(5L, 7L, 9.0), Edge(6L, 7L, 11.0)))

val myGraph = Graph(myVertices, myEdges)
dijkstra(myGraph, 1L).vertices.map(_._2).collect

Output:
res0: Array[(String, Double)] = Array((D,5.0), (A,0.0), (F,11.0), (C,15.0), (G,22.0), (E,14.0), (B,7.0))
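One way to sanity-check res0 uses a property every correct set of shortest distances must satisfy when all weights are positive: for each vertex v other than the origin, its distance equals the minimum of dist(u) + w over its incoming edges (u, v, w). The sketch below is a hypothetical plain-Scala check, with the distances copied from res0 and the edges of myEdges written by vertex name; `CheckDistances` is an illustrative name, not part of the original code.

```scala
// Hypothetical consistency check of the distances in res0
object CheckDistances {
  // Distances copied from res0, keyed by vertex name
  val dist: Map[String, Double] = Map(
    "A" -> 0.0, "B" -> 7.0, "C" -> 15.0, "D" -> 5.0, "E" -> 14.0, "F" -> 11.0, "G" -> 22.0)

  // The edges of myEdges, with ids 1L..7L replaced by names A..G
  val edges: Seq[(String, String, Double)] = Seq(
    ("A", "B", 7.0), ("A", "D", 5.0), ("B", "C", 8.0), ("B", "D", 9.0), ("B", "E", 7.0),
    ("C", "E", 5.0), ("D", "E", 15.0), ("D", "F", 6.0), ("E", "F", 8.0), ("E", "G", 9.0),
    ("F", "G", 11.0))

  // For each non-origin vertex v: dist(v) == min over incoming edges (u, v, w) of dist(u) + w
  def consistent(origin: String): Boolean =
    dist.keys.filter(_ != origin).forall { v =>
      // `v` in backticks matches edges whose destination equals the current v
      val candidates = edges.collect { case (u, `v`, w) => dist(u) + w }
      candidates.nonEmpty && candidates.min == dist(v)
    }

  def main(args: Array[String]): Unit = println(consistent("A"))
}
```

For example, E receives candidates 7.0 + 7.0 = 14.0 (from B), 15.0 + 5.0 = 20.0 (from C) and 5.0 + 15.0 = 20.0 (from D), and the minimum 14.0 matches res0.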


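3. Dijkstra's shortest path algorithm with path recording
	Building on the algorithm in section 1, this version records the path to each vertex in a List.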

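import org.apache.spark.graphx._
def dijkstra[VD](g:Graph[VD,Double], origin:VertexId) = {
	// Vertex state: (visited, shortest distance so far, path of vertices leading here)
	var g2 = g.mapVertices(
		(vid,vd) => (false, if (vid == origin) 0.0 else Double.MaxValue, List[VertexId]()))

	for (i <- 1L to g.vertices.count-1) {
		// Pick the unvisited vertex with the smallest tentative distance
		val currentVertexId =
			g2.vertices.filter(!_._2._1)
				.fold((0L,(false,Double.MaxValue,List[VertexId]())))((a,b) =>
					if (a._2._2 < b._2._2) a else b)._1

		// Each message carries (candidate distance, sender's path with the sender appended);
		// mergeMsg keeps the message with the smaller distance
		val newDistances = g2.aggregateMessages[(Double,List[VertexId])](
			ctx => if (ctx.srcId == currentVertexId)
				ctx.sendToDst((ctx.srcAttr._2 + ctx.attr,ctx.srcAttr._3 :+ ctx.srcId)),
			(a,b) => if (a._1 < b._1) a else b)

		// Mark the current vertex as visited; keep the old path only if the old
		// distance is strictly smaller, otherwise adopt the new distance and path
		g2 = g2.outerJoinVertices(newDistances)((vid, vd, newSum) => {
			val newSumVal = newSum.getOrElse((Double.MaxValue,List[VertexId]()))
			(vd._1 || vid == currentVertexId,
			math.min(vd._2, newSumVal._1),
			if (vd._2 < newSumVal._1) vd._3 else newSumVal._2)})
	}

	// Result attribute: (original attribute, List(distance, path))
	g.outerJoinVertices(g2.vertices)((vid, vd, dist) =>
		(vd, dist.getOrElse((false,Double.MaxValue,List[VertexId]())).productIterator.toList.tail))
}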
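The path bookkeeping can also be sketched without Spark. In this hypothetical plain-Scala version (the names `LocalDijkstraWithPath`, `Edge` and `sampleEdges` are illustrative only), each vertex carries (visited, distance, path), a message appends the sender to the sender's own path, and the update keeps the old path only when the old distance is strictly smaller, mirroring the `if (vd._2 < newSumVal._1) vd._3 else newSumVal._2` rule above.

```scala
// A local, Spark-free sketch of Dijkstra with path recording (hypothetical names)
object LocalDijkstraWithPath {
  // Stand-in for GraphX's Edge[Double]
  final case class Edge(src: Long, dst: Long, attr: Double)

  // The same edges as myEdges in section 2
  val sampleEdges: Seq[Edge] = Seq(
    Edge(1L, 2L, 7.0), Edge(1L, 4L, 5.0), Edge(2L, 3L, 8.0), Edge(2L, 4L, 9.0),
    Edge(2L, 5L, 7.0), Edge(3L, 5L, 5.0), Edge(4L, 5L, 15.0), Edge(4L, 6L, 6.0),
    Edge(5L, 6L, 8.0), Edge(5L, 7L, 9.0), Edge(6L, 7L, 11.0))

  // Returns vertexId -> (distance, vertices visited before reaching it)
  def dijkstra(vertices: Seq[Long], edges: Seq[Edge], origin: Long): Map[Long, (Double, List[Long])] = {
    var state: Map[Long, (Boolean, Double, List[Long])] =
      vertices.map(v => (v, (false, if (v == origin) 0.0 else Double.MaxValue, List.empty[Long]))).toMap

    for (_ <- 1 until vertices.size) {
      // Unvisited vertex with the smallest tentative distance
      val current = state
        .filter { case (_, (visited, _, _)) => !visited }
        .minBy { case (_, (_, dist, _)) => dist }
        ._1
      val (_, currentDist, currentPath) = state(current)

      // Message = (distance via current, current's path with current appended);
      // per destination keep the message with the smaller distance
      val messages: Map[Long, (Double, List[Long])] = edges
        .filter(_.src == current)
        .groupBy(_.dst)
        .map { case (dst, es) => (dst, es.map(e => (currentDist + e.attr, currentPath :+ current)).minBy(_._1)) }

      // Keep the old path only if the old distance is strictly smaller
      state = state.map { case (v, (visited, dist, path)) =>
        val (newDist, newPath) = messages.getOrElse(v, (Double.MaxValue, List.empty[Long]))
        (v, (visited || v == current, math.min(dist, newDist), if (dist < newDist) path else newPath))
      }
    }
    state.map { case (v, (_, dist, path)) => (v, (dist, path)) }
  }

  def main(args: Array[String]): Unit =
    dijkstra(1L to 7L, sampleEdges, 1L).toSeq.sortBy(_._1).foreach(println)
}
```

With origin 1L, vertex 7L (G) ends up with (22.0, List(1, 4, 6)), consistent with res1 in section 4.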

4. Running Dijkstra's algorithm with path recording

dijkstra(myGraph, 1L).vertices.map(_._2).collect

	Output:
	res1: Array[(String, List[Any])] = Array((D,List(5.0, List(1))), (A,List(0.0, List())), (F,List(11.0, List(1, 4))), (C,List(15.0, List(1, 2))), (G,List(22.0, List(1, 4, 6))), (E,List(14.0, List(1, 2))), (B,List(7.0, List(1))))

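	Interpreting the result: (G,List(22.0, List(1, 4, 6))) means that the shortest path from vertex 1L (A) to G passes through vertices 1, 4 and 6 (A, D and F) in that order before reaching G, for a total distance of 22.0. The recorded list holds the vertices before the destination, not the destination itself.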
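	The total of 22.0 can be checked by adding up the weights from myEdges along the recorded path plus the final hop to G (7L). A minimal sketch, where the `PathLength` object and its weight map are illustrative names, not part of the original code:

```scala
// Hypothetical check: sum the edge weights along a path of vertex ids
object PathLength {
  // Edge weights from myEdges, keyed by (srcId, dstId)
  val weights: Map[(Long, Long), Double] = Map(
    (1L, 2L) -> 7.0, (1L, 4L) -> 5.0, (2L, 3L) -> 8.0, (2L, 4L) -> 9.0, (2L, 5L) -> 7.0,
    (3L, 5L) -> 5.0, (4L, 5L) -> 15.0, (4L, 6L) -> 6.0, (5L, 6L) -> 8.0, (5L, 7L) -> 9.0,
    (6L, 7L) -> 11.0)

  // Pair each vertex with its successor and look up that edge's weight
  def length(path: List[Long]): Double =
    path.zip(path.tail).map(weights).sum

  // Recorded path List(1, 4, 6) with the destination 7L appended: 5.0 + 6.0 + 11.0
  def main(args: Array[String]): Unit = println(length(List(1L, 4L, 6L, 7L))) // prints 22.0
}
```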

Reference: Spark GraphX 實戰 (Spark GraphX in Action)

 
