引入
圖算法指利用特製的線條算圖求得答案的一種簡便算法。無向圖、有向圖和網絡能運用很多常用的圖算法,這些算法包括:各種遍歷算法(這些遍歷類似於樹的遍歷),尋找最短路徑的算法,尋找網絡中最低代價路徑的算法,回答一些簡單相關問題(例如,圖是否是連通的,圖中兩個頂點間的最短路徑是什麼,等等)的算法。圖算法可應用到多種場合,例如:優化管道、路由表、快遞服務、通信網站等。
GraphFrames提供與GraphX相同的標準圖形算法套件以及一些新的算法。
目前,某些算法由GraphX的API實現的,因此在GraphFrames中可能沒有比GraphX更可擴展的功能。
目前,我們的業務涉及到企業知識圖譜,需要做路徑搜索、社區發現、標籤傳播等基於圖計算的應用,雖然neo4j也可以做,但是neo4j的分佈式版本價格很高。於是考慮使用spark做分佈式的圖計算。
本文不介紹太多算法細節,主要展示官網和實際案例的代碼實現。
廣度優先搜索
廣度優先搜索(Breadth-first search,簡稱BFS),是查找一個頂點到另外一個頂點的算法。
這裏是用pyspark自帶的friends數據集,實現路徑搜索。
我們先看看friends數據集長啥樣。
from pyspark import SparkContext
from pyspark.sql import SQLContext
from graphframes.examples import Graphs
# spark
sc = SparkContext("local", appName="mysqltest")
sqlContext = SQLContext(sc)
g = Graphs(sqlContext).friends()
g.vertices.show()
g.edges.show()
有7個節點代表7個人,然後用7個關係展示他們的人際關係,有friend和follow兩種關係。
+---+-------+---+
| id| name|age|
+---+-------+---+
| a| Alice| 34|
| b| Bob| 36|
| c|Charlie| 30|
| d| David| 29|
| e| Esther| 32|
| f| Fanny| 36|
+---+-------+---+
+---+---+------------+
|src|dst|relationship|
+---+---+------------+
| a| b| friend|
| b| c| follow|
| c| b| follow|
| f| c| follow|
| e| f| follow|
| e| d| friend|
| d| a| friend|
+---+---+------------+
然後用BFS的API來做路徑搜索,分別定義起點和終點的條件。
paths = g.bfs("name = 'Esther'", "age < 32")
paths.show()
+---------------+--------------+--------------+
| from| e0| to|
+---------------+--------------+--------------+
|[e, Esther, 32]|[e, d, friend]|[d, David, 29]|
+---------------+--------------+--------------+
可以看到,滿足節點名稱爲Esther的有兩條關係,但指向節點age小於32的只有David了,如圖所示。
另外還可以使用edgeFilter限制邊的條件,maxPathLength來定義鏈路的長度。
paths = g.bfs("name = 'Esther'", "age < 32", edgeFilter="relationship != 'friend'", maxPathLength=3)
paths.show()
+---------------+--------------+--------------+--------------+----------------+
| from| e0| v1| e1| to|
+---------------+--------------+--------------+--------------+----------------+
|[e, Esther, 32]|[e, f, follow]|[f, Fanny, 36]|[f, c, follow]|[c, Charlie, 30]|
+---------------+--------------+--------------+--------------+----------------+
同樣起始條件是Esther,maxPathLength長度爲3,且relationship必須不能是friend,找到下面的路徑。
連通分量
連通分量(Connected Components),基於搜索算法,計算節點和節點之間能否雙向抵達。
先看我們要分析的圖結構。
原始數據是這樣子:
people.csv
4,Dave,25
6,Faith,21
8,Harvey,47
2,Bob,18
1,Alice,20
3,Charlie,30
7,George,34
9,Ivy,21
5,Eve,30
10,Lily,35
11,Helen,35
12,Ann,35
links.csv
1,2,friend
1,3,sister
2,4,brother
3,2,boss
4,5,client
1,9,friend
6,7,cousin
7,9,coworker
8,9,father
10,11,colleague
10,12,colleague
11,12,colleague
先讀取csv數據,由於csv中沒有列名稱也無法指定列數據類型,我們使用withColumnRenamed和withColumn來操作,也可以封裝成函數方便後續使用。
from pyspark import SparkContext
from pyspark.sql import SQLContext
from graphframes import GraphFrame
from graphframes.examples import Graphs
# spark
sc = SparkContext("local", appName="mysqltest")
sc.setCheckpointDir("./ccpoint")
sqlContext = SQLContext(sc)
# g = Graphs(sqlContext).friends() # Get example graph
# 讀取數據
links_df = sqlContext.read.csv("links.csv")
# 修改列名稱
links_df = links_df.withColumnRenamed("_c0", "src")\
.withColumnRenamed("_c1", "dst")\
.withColumnRenamed("_c2", "relationship")
# 修改列類型
links_df = links_df.withColumn("src", links_df["src"].astype("int"))\
.withColumn("dst", links_df["dst"].astype("int"))
links_df.show()
# 讀取數據
nodes_df = sqlContext.read.csv("people.csv")
# 修改列名稱
nodes_df = nodes_df.withColumnRenamed("_c0", "id")\
.withColumnRenamed("_c1", "name")\
.withColumnRenamed("_c2", "age")
# 修改列類型
nodes_df = nodes_df.withColumn("id", nodes_df["id"].astype("int"))\
.withColumn("age", nodes_df["age"].astype("int"))
nodes_df.show()
+---+---+------------+
|src|dst|relationship|
+---+---+------------+
| 1| 2| friend|
...
+---+-------+---+
| id| name|age|
+---+-------+---+
| 4| Dave| 25|
...
然後使用內置Connected Components算法計算。
g = GraphFrame(nodes_df, links_df)
result = g.connectedComponents()
result.select("id", "component").orderBy("component").show()
輸出節點的分組信息:
+---+---------+
| id|component|
+---+---------+
| 2| 1|
| 4| 1|
| 8| 1|
| 7| 1|
| 9| 1|
| 5| 1|
| 1| 1|
| 6| 1|
| 3| 1|
| 10| 10|
| 11| 10|
| 12| 10|
+---+---------+
在圖中展示的話,就是對圖的可連通性做了標記。
強連通分量
from pyspark import SparkContext
from pyspark.sql import SQLContext
from graphframes import GraphFrame
from graphframes.examples import Graphs
# spark
sc = SparkContext("local", appName="mysqltest")
sc.setCheckpointDir("./ccpoint")
sqlContext = SQLContext(sc)
g = Graphs(sqlContext).friends()
result = g.stronglyConnectedComponents(maxIter=10)
result.select("id", "component").orderBy("component").show()
標籤傳播
標籤傳播算法(Label Propagation Algorithm,簡稱LPA),用來檢測網絡中的社區。
LPA不能保證會收斂,也可能會使每個節點都被識別爲一個社區,但是計算消耗的資源很低廉。
from pyspark import SparkContext
from pyspark.sql import SQLContext
from graphframes import GraphFrame
from graphframes.examples import Graphs
# spark
sc = SparkContext("local", appName="mysqltest")
sc.setCheckpointDir("./ccpoint")
sqlContext = SQLContext(sc)
g = Graphs(sqlContext).friends()
result = g.labelPropagation(maxIter=5)
result.select("id", "label").show()
PageRank
PageRank的可以計算節點的權重排名。應用場景有很多,最著名的就是網頁排名,還有以下應用場景:
- Twitter:個性化PageRank算法用於向用戶推薦他們可能希望關注的賬號。
- 用於給公共場所和街道排名,同時預測街道和人類的活動趨勢。
- 在醫療和保險行業,作爲欺詐探測系統,幫助醫生或者供應商發現異常行爲,並將其作爲負樣本給機器學習算法進行訓練。
from pyspark import SparkContext
from pyspark.sql import SQLContext
from graphframes import GraphFrame
from graphframes.examples import Graphs
# spark
sc = SparkContext("local", appName="mysqltest")
sc.setCheckpointDir("./ccpoint")
sqlContext = SQLContext(sc)
g = Graphs(sqlContext).friends()
results = g.pageRank(resetProbability=0.15, tol=0.01)
results.vertices.select("id", "pagerank").show()
results.edges.select("src", "dst", "weight").show()
results2 = g.pageRank(resetProbability=0.15, maxIter=10)
results3 = g.pageRank(resetProbability=0.15, maxIter=10, sourceId="a")
results4 = g.parallelPersonalizedPageRank(resetProbability=0.15, sourceIds=["a", "b", "c", "d"], maxIter=10)
最短路徑算法
最短路徑,顧名思義,兩點之間最短可到達方式。
shortestPaths計算的是圖中每個頂點到landmarks中給定的頂點的最短路徑。
from pyspark import SparkContext
from pyspark.sql import SQLContext
from graphframes import GraphFrame
from graphframes.examples import Graphs
# spark
sc = SparkContext("local", appName="mysqltest")
sc.setCheckpointDir("./ccpoint")
sqlContext = SQLContext(sc)
g = Graphs(sqlContext).friends()
results = g.shortestPaths(landmarks=["a", "d"])
results.select("id", "distances").show()
+---+----------------+
| id| distances|
+---+----------------+
| b| []|
| e|[d -> 1, a -> 2]|
| a| [a -> 0]|
| f| []|
| d|[d -> 0, a -> 1]|
| c| []|
+---+----------------+
三角形計數
三角形計數(Triangle count)一般用來分析社交網絡。
通過Triangle Count能夠提供集羣的度,是進行聚類分析的重要依據和指標。
from pyspark import SparkContext
from pyspark.sql import SQLContext
from graphframes import GraphFrame
from graphframes.examples import Graphs
# spark
sc = SparkContext("local", appName="mysqltest")
sc.setCheckpointDir("./ccpoint")
sqlContext = SQLContext(sc)
g = Graphs(sqlContext).friends()
results = g.triangleCount()
results.select("id", "count").show()